You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1415 lines
48 KiB
1415 lines
48 KiB
#!/usr/bin/env python |
|
# encoding: utf-8 |
|
""" |
|
PlotHelpers.py |
|
|
|
Routines to help use matplotlib and make cleaner plots |
|
as well as get plots ready for publication. |
|
|
|
Modified to allow us to use a list of axes, and operate on all of those, |
|
or to use just one axis if that's all that is passed. |
|
Therefore, the first argument to these calls can either be an axes object, |
|
or a list of axes objects. 2/10/2012 pbm. |
|
|
|
Plotter class: a simple class for managing figures with multiple plots. |
|
Uses gridspec to build sets of axes. |
|
|
|
Created by Paul Manis on 2010-03-09. |
|
Copyright 2010-2016 Paul Manis |
|
Distributed under MIT/X11 license. See license.txt for more information. |
|
|
|
""" |
|
|
|
import sys |
|
import os |
|
import string |
|
from collections import OrderedDict |
|
|
|
# Default font name; used as the default value of the ``font`` argument of update_font().
stdFont = "Arial"
|
from matplotlib.ticker import FormatStrFormatter |
|
from matplotlib.font_manager import FontProperties |
|
from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, DrawingArea, HPacker |
|
from scipy.stats import gaussian_kde |
|
import numpy as np |
|
import matplotlib.pyplot as mpl |
|
import matplotlib.gridspec as gridspec |
|
from matplotlib.patches import Circle |
|
from matplotlib.patches import Rectangle |
|
from matplotlib.patches import Ellipse |
|
from matplotlib.collections import PatchCollection |
|
import matplotlib |
|
|
|
# Global matplotlib configuration applied as a side effect of importing this module.
rcParams = matplotlib.rcParams
rcParams["svg.fonttype"] = "none"  # No text as paths. Assume font installed.
rcParams["pdf.fonttype"] = 42
rcParams["ps.fonttype"] = 42
# rcParams['font.serif'] = ['Times New Roman']
from matplotlib import rc

rc("font", **{"family": "sans-serif", "sans-serif": ["Arial"]})
# rcParams['font.sans-serif'] = ['Arial']
# rcParams['font.family'] = 'sans-serif'

# check for LaTeX install -
try:
    from distutils.spawn import find_executable
except ImportError:
    # distutils is no longer part of the standard library (Python 3.12+);
    # shutil.which performs the same executable lookup on PATH.
    from shutil import which as find_executable

latex_avail = bool(find_executable("latex"))
rc("text", usetex=latex_avail)
try:
    rcParams["text.latex.unicode"] = latex_avail
except KeyError:
    # Newer Matplotlib releases no longer define this rcParam and reject
    # the assignment; the setting is simply not needed there.
    pass
|
|
|
|
|
def _ax_tolist(ax): |
|
if isinstance(ax, list): |
|
return ax |
|
elif isinstance(ax, dict): |
|
axlist = list(axl.keys()) |
|
return [ax for ax in axl[axlist]] |
|
else: |
|
return [ax] |
|
|
|
|
|
def nice_plot(
    axl, spines=["left", "bottom"], position=10, direction="inward", axesoff=False
):
    """Adjust a plot so that it looks nicer than the default matplotlib plot.

    The requested spines are colored black and repositioned, all other
    spines are hidden, ticks are kept only on the sides that retain a spine,
    and the tick direction is set.

    Parameters
    ----------
    axl : axes object, list of axes objects, or dict of axes objects
        A single axes object is wrapped in a list; for a dict the values are
        used. Entries that are None are skipped.

    spines : list of strings (default : ['left', 'bottom'])
        Sets whether spines will occur on particular axes. Choices are 'left',
        'right', 'bottom', and 'top'. Chosen spines will be displayed, others
        are not. The default list is only read, never modified.

    position : number or dict (default : 10)
        Passed to ``spine.set_position(("axes", value))`` for each displayed
        spine. A dict is indexed by spine name, e.g.
        ``{'left': -0.05, 'bottom': -0.05}``. Python and numpy integers and
        floats are accepted.

    direction : string (default : 'inward')
        'inward' sets the tick direction to "in"; any other value sets "out".

    axesoff : boolean (default : False)
        If true, forces the axes to be turned completely off (via noaxes).

    Returns
    -------
    Nothing.

    Raises
    ------
    ValueError
        If position is neither a number nor a dict.
    """
    if isinstance(axl, dict):
        axl = list(axl.values())
    elif not isinstance(axl, list):
        axl = [axl]
    tick_direction = "in" if direction == "inward" else "out"
    for ax in axl:
        if ax is None:
            continue
        for loc, spine in ax.spines.items():
            if loc not in spines:
                spine.set_color("none")
                continue
            spine.set_color("k")
            if isinstance(position, dict):
                spine.set_position(("axes", position[loc]))
            elif isinstance(position, (int, float, np.integer, np.floating)):
                # isinstance (rather than an exact type match) lets numpy
                # scalars through as well.
                spine.set_position(("axes", position))
            else:
                raise ValueError(
                    "position must be int, float or dict [ex: {'left': -0.05, 'bottom': -0.05}]"
                )
        if axesoff is True:
            noaxes(ax)

        # turn off ticks where there is no spine, if there are axes
        if "left" in spines and not axesoff:
            ax.yaxis.set_ticks_position("left")
            ax.yaxis.set_tick_params(color="k")
        else:
            ax.yaxis.set_ticks([])  # no yaxis ticks

        if "bottom" in spines and not axesoff:
            ax.xaxis.set_ticks_position("bottom")
            ax.xaxis.set_tick_params(color="k")
        else:
            ax.xaxis.set_ticks([])  # no xaxis ticks

        ax.tick_params(axis="y", direction=tick_direction)
        ax.tick_params(axis="x", direction=tick_direction)
|
|
|
|
|
def noaxes(axl, whichaxes="xy"):
    """Take away the axis ticks and, when both axes are selected, the frame.

    Parameters
    ----------
    axl : axes object, list of axes objects, or dict of axes objects
        A single axes object is wrapped in a list; for a dict the values are
        used. Entries that are None are skipped.

    whichaxes : string (default : 'xy')
        Sets which axes are turned off. The presence of an 'x' in
        the string removes the x ticks, the presence of 'y' removes the
        y ticks. When both are present (in either order) the whole axes
        frame is switched off with ``set_axis_off``.

    Returns
    -------
    Nothing
    """
    if isinstance(axl, dict):
        axl = list(axl.values())
    elif not isinstance(axl, list):
        axl = [axl]
    hide_x = "x" in whichaxes
    hide_y = "y" in whichaxes
    for ax in axl:
        if ax is None:
            continue
        if hide_x:
            ax.xaxis.set_ticks([])
        if hide_y:
            ax.yaxis.set_ticks([])
        if hide_x and hide_y:
            ax.set_axis_off()
|
|
|
|
|
def setY(ax1, ax2):
    """
    Copy the Y limits of a source axes onto one or more target axes.

    Parameters
    ----------
    ax1 : axis object
        The source axis object. A list is not accepted: a message is
        printed and nothing is changed.
    ax2 : axes object or collection of axes objects
        The target axes; normalized to a list with _ax_tolist. Each target
        takes on the Y limits of the source.

    Returns
    -------
    Nothing
    """
    source_is_list = type(ax1) is list
    if source_is_list:
        print("PlotHelpers: cannot use list as source to set Y axis")
        return
    targets = _ax_tolist(ax2)
    ylimits = ax1.get_ylim()
    for target in targets:
        target.set_ylim(ylimits)
|
|
|
|
|
def setX(ax1, ax2):
    """
    Set the X limits for an axes from a source axes to
    the target axes.

    Parameters
    ----------
    ax1 : axis object
        The source axis object. A list is not accepted: a message is
        printed and nothing is changed.
    ax2 : axes object or collection of axes objects
        The target axes; normalized to a list with _ax_tolist. Each target
        takes on the X limits of the source.

    Returns
    -------
    Nothing
    """
    if type(ax1) is list:
        # The message previously named the Y axis (copied from setY).
        print("PlotHelpers: cannot use list as source to set X axis")
        return
    ax2 = _ax_tolist(ax2)
    refx = ax1.get_xlim()
    for ax in ax2:
        ax.set_xlim(refx)
|
|
|
|
|
def labelPanels(
    axl,
    axlist=None,
    font="Arial",
    fontsize=18,
    weight="normal",
    xy=(-0.05, 1.05),
    horizontalalignment="right",
    verticalalignment="bottom",
    rotation=0.0,
):
    """
    Provide labeling of panels in a figure with multiple subplots (axes)

    Parameters
    ----------
    axl : axes object, list of axes objects, or dict of axes objects
        A single axes object is converted to a list. For a dict, the values
        are the axes and the keys are used as the labels (replacing any
        axlist that was passed).

    axlist : list of string labels (default : None)
        Contains a list of the string labels. If the default value is provided,
        the axes will be lettered in alphabetical sequence.

    font : string (default : 'Arial')
        Accepted for backward compatibility; the labels are always drawn
        with family "sans-serif".

    fontsize : float (default : 18, in points)
        Font size to use for the panel labels

    weight : string (default : 'normal')
        Font weight passed to the annotation

    xy : tuple (default : (-0.05, 1.05))
        A tuple (x,y) indicating where the label should go relative to the axis frame.
        Values are normalized as a fraction of the frame size.

    horizontalalignment, verticalalignment : string
        Alignment of the label text relative to xy.

    rotation : float (default : 0.0)
        Rotation passed to the annotation.

    Returns
    -------
    list of the annotations

    Raises
    ------
    ValueError
        If there are more labels than axes.
    """
    if isinstance(axl, dict):
        # Split the dict here so the labels and axes stay paired.
        axlist = list(axl.keys())
        axl = list(axl.values())
    else:
        axl = _ax_tolist(axl)
    if axlist is None:
        # assume we wish to go in sequence
        axlist = string.ascii_uppercase[0 : len(axl)]
    if len(axlist) > len(axl):
        raise ValueError(
            "axl must have at least as many entries as axlist: "
            "got axl=%d and axlist=%d for axlist: %s"
            % (len(axl), len(axlist), list(axlist))
        )
    labels = []
    # zip stops at the shorter sequence, so axes without a label are skipped.
    for label, ax in zip(axlist, axl):
        if ax is None:
            continue
        if isinstance(ax, list):
            ax = ax[0]
        ann = ax.annotate(
            label,
            xy=xy,
            xycoords="axes fraction",
            annotation_clip=False,
            color="k",
            verticalalignment=verticalalignment,
            weight=weight,
            horizontalalignment=horizontalalignment,
            fontsize=fontsize,
            family="sans-serif",
            rotation=rotation,
        )
        labels.append(ann)
    return labels
|
|
|
|
|
def listAxes(axd):
    """
    make a list of the axes from the dictionary

    Parameters
    ----------
    axd : dict or list
        A list is returned unchanged; for a dict the values are returned
        as a new list in the dict's iteration order.

    Returns
    -------
    list

    Raises
    ------
    TypeError
        If axd is neither a dict nor a list.
    """
    if type(axd) is list:
        return axd
    if type(axd) is not dict:
        # A bare ``raise`` was used here before; with no active exception it
        # only produced an unrelated RuntimeError.
        raise TypeError(
            "listAxes expects dictionary or list; got %s" % type(axd).__name__
        )
    return [axd[x] for x in axd]
|
|
|
|
|
def cleanAxes(axl):
    """Show only the left and bottom spines and refresh the tick fonts.

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist. Entries that are None are
        skipped.

    Raises
    ------
    ValueError
        If an axes has a spine whose name is not 'left', 'bottom', 'right'
        or 'top'.
    """
    axl = _ax_tolist(axl)
    for ax in axl:
        if ax is None:
            continue
        for loc, spine in ax.spines.items():
            if loc in ["left", "bottom"]:
                spine.set_visible(True)
            elif loc in ["right", "top"]:
                spine.set_visible(False)
                # spine.set_color('none')
                # do not draw the spine
            else:
                raise ValueError("Unknown spine location: %s" % loc)
            # turn off ticks when there is no spine
            ax.xaxis.set_ticks_position("bottom")
            # pdb.set_trace()
            ax.yaxis.set_ticks_position("left")  # stopped working in matplotlib 1.10
        # apply the module's default tick-label font settings to this axes
        update_font(ax)
|
|
|
|
|
def setTicks(axl, axis="x", ticks=np.arange(0, 1.1, 1.0)):
    """Set the tick locations on one axis of every axes in axl.

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    axis : string (default : 'x')
        'x' sets the x ticks, 'y' sets the y ticks; any other value does
        nothing.
    ticks : array-like (default : np.arange(0, 1.1, 1.0))
        Tick locations passed to set_xticks / set_yticks.
    """
    targets = [candidate for candidate in _ax_tolist(axl) if candidate is not None]
    for target in targets:
        if axis == "x":
            target.set_xticks(ticks)
        elif axis == "y":
            target.set_yticks(ticks)
|
|
|
|
|
def formatTicks(axl, axis="xy", fmt="%d", font="Arial"):
    """
    Apply a printf-style format to the major tick labels.

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    axis : string (default : 'xy')
        The x axis is formatted when the string contains 'x', the y axis
        when it contains 'y'.
    fmt : string (default : '%d')
        Format string handed to FormatStrFormatter; the default renders
        the tick values as integers.
    font : string (default : 'Arial')
        Not used by this function.
    """
    targets = _ax_tolist(axl)
    formatter = FormatStrFormatter(fmt)
    do_x = "x" in axis
    do_y = "y" in axis
    for target in targets:
        if target is None:
            continue
        if do_x:
            target.xaxis.set_major_formatter(formatter)
        if do_y:
            target.yaxis.set_major_formatter(formatter)
|
|
|
|
|
def autoFormatTicks(axl, axis="xy", font="Arial"):
    """Choose a major tick format for each axes from its current limits.

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    axis : string (default : 'xy')
        The x axis is formatted when the string contains 'x', the y axis
        when it contains 'y'. The format itself is selected by setFormatter
        from the span of the axis limits.
    font : string (default : 'Arial')
        Not used by this function.
    """
    for ax in _ax_tolist(axl):
        if ax is None:
            continue
        if "x" in axis:
            x0, x1 = ax.get_xlim()
            setFormatter(ax, x0, x1, axis="x")
        if "y" in axis:
            # Previously this unpacked the uncalled ``ax.get_xlim`` method,
            # which raised TypeError; the y format must use the y limits.
            y0, y1 = ax.get_ylim()
            setFormatter(ax, y0, y1, axis="y")
|
|
|
|
|
def setFormatter(axl, x0, x1, axis="x"):
    """Install a major tick formatter chosen from the size of a data range.

    With m = ceil(log10(|x0 - x1|)) the format is:
        0 < m <= 4    -> "%d"
        -1 < m <= 0   -> "%5.1f"
        -3 < m <= -1  -> "%6.3f"
        otherwise     -> "%e"

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    x0, x1 : float
        The two ends of the data range (order does not matter).
    axis : string (default : 'x')
        'x' or 'y': the axis that receives the formatter. Other values
        leave the axes untouched.
    """
    axl = _ax_tolist(axl)
    datarange = np.abs(x0 - x1)
    mdata = np.ceil(np.log10(datarange))
    if 0 < mdata <= 4:
        fmt = "%d"
    elif -1 < mdata <= 0:
        fmt = "%5.1f"
    elif -3 < mdata <= -1:
        # This branch used to assign to a misspelled name (leaving the
        # formatter unbound) and skipped m == -1 entirely.
        fmt = "%6.3f"
    else:
        fmt = "%e"
    majorFormatter = FormatStrFormatter(fmt)
    for ax in axl:
        if ax is None:
            continue
        if axis == "x":
            ax.xaxis.set_major_formatter(majorFormatter)
        elif axis == "y":
            ax.yaxis.set_major_formatter(majorFormatter)
|
|
|
|
|
def update_font(axl, size=9, font=stdFont):
    """Set the size and family of the tick labels on each axes.

    The tick labels are rewritten from the current tick locations using a
    sans-serif, normal-weight font of the requested size.

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    size : float (default : 9)
        Tick label size.
    font : string (default : stdFont)
        Not used by this function; the family is always "sans-serif".
    """
    axl = _ax_tolist(axl)
    fontProperties = {
        "family": "sans-serif",  # 'sans-serif': font,
        "weight": "normal",
        "size": size,
    }
    for ax in axl:
        if ax is None:
            continue
        for axis in (ax.xaxis, ax.yaxis):
            for tick in axis.get_major_ticks():
                tick.label1.set_size(size)
        ax.set_xticklabels(ax.get_xticks(), fontdict=fontProperties)
        ax.set_yticklabels(ax.get_yticks(), fontdict=fontProperties)
        for axis in (ax.xaxis, ax.yaxis):
            # "Smart bounds" were removed from newer Matplotlib releases;
            # only request them where the method still exists.
            if hasattr(axis, "set_smart_bounds"):
                axis.set_smart_bounds(True)
        ax.tick_params(axis="both", labelsize=size)
|
|
|
|
|
def lockPlot(axl, lims, ticks=None):
    """
    Pin the limits of one or more axes by plotting invisible data.

    An invisible line through the corners of the requested rectangle is
    added to each axes, and the axes limits are then set to lims.

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    lims : sequence
        The limits as [x0, x1, y0, y1].
    ticks : (default : None)
        Not used by this function.

    Returns
    -------
    list
        The invisible line objects, in case they need to be modified later.
    """
    x0, x1, y0, y1 = lims[0], lims[1], lims[2], lims[3]
    invisible = []
    for target in _ax_tolist(axl):
        if target is None:
            continue
        lines = target.plot(
            [x0, x0, x1, x1],
            [y0, y1, y0, y1],
            color="none",
            marker="",
            linestyle="None",
        )
        invisible.extend(lines)
        target.axis(lims)
    return invisible
|
|
|
|
|
def adjust_spines(
    axl, spines=["left", "bottom"], direction="outward", distance=5, smart=True
):
    """Move the selected spines and hide the rest.

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    spines : list of strings (default : ['left', 'bottom'])
        Names of the spines to keep. Ticks are removed from the y axis when
        'left' is absent and from the x axis when 'bottom' is absent. The
        default list is only read, never modified.
    direction : string (default : 'outward')
        Position type passed to ``spine.set_position((direction, distance))``.
    distance : float (default : 5)
        Amount passed to ``spine.set_position((direction, distance))``.
    smart : boolean (default : True)
        Requests "smart bounds" on the kept spines when exactly True. This
        is applied only on Matplotlib versions that still provide
        ``Spine.set_smart_bounds``.
    """
    axl = _ax_tolist(axl)
    for ax in axl:
        if ax is None:
            continue
        # turn off ticks where there is no spine
        if "left" in spines:
            ax.yaxis.set_ticks_position("left")
        else:
            ax.yaxis.set_ticks([])  # no yaxis ticks

        if "bottom" in spines:
            ax.xaxis.set_ticks_position("bottom")
        else:
            ax.xaxis.set_ticks([])  # no xaxis ticks

        for loc, spine in ax.spines.items():
            if loc not in spines:
                spine.set_color("none")  # don't draw spine
                continue
            spine.set_position((direction, distance))
            # set_smart_bounds no longer exists in newer Matplotlib releases.
            if hasattr(spine, "set_smart_bounds"):
                spine.set_smart_bounds(smart is True)
|
|
|
|
|
def getLayoutDimensions(n, pref="height"):
    """
    Compute a near-square grid shape with room for n axes.

    Both dimensions start at int(sqrt(n)) and are grown one step at a time
    until rows * columns >= n.

    Parameters
    ----------
    n : int (no default):
        Number of plots needed
    pref : string (default : 'height')
        'width' selects the width-first growth rule; any other value
        selects the height-first rule.

    Returns
    -------
    (h, w) : tuple
        height (rows) and width (columns)
    """
    rows = cols = int(np.sqrt(n))
    width_first = pref == "width"
    while rows * cols < n:
        if width_first:
            cols += 1
            if cols * rows > (n - rows):
                cols, rows = cols - 1, rows + 1
        else:
            rows += 1
            if cols * rows > (n - cols):
                rows, cols = rows - 1, cols + 1
    return (rows, cols)
|
|
|
|
|
def calbar(
    axl,
    calbar=None,
    axesoff=True,
    orient="left",
    unitNames=None,
    fontsize=11,
    weight="normal",
    font="Arial",
):
    """
    draw a calibration bar and label it. The calibration bar is defined as:
    [x0, y0, xlen, ylen]

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    calbar : sequence [x0, y0, xlen, ylen] (default : None)
        Corner position and lengths of the bar in data coordinates. When
        None, no bar is drawn (the axes are still turned off if requested).
    axesoff : boolean (default : True)
        If True, the axes are turned off with noaxes before drawing.
    orient : string (default : 'left')
        'left' puts the vertical part of the bar on the left, 'right' on
        the right. Any other value prints a message and is drawn as 'left'.
    unitNames : dict (default : None)
        Optional unit strings under the keys 'x' and 'y', appended to the
        horizontal and vertical labels.
    fontsize : float (default : 11)
        Label font size.
    weight : string (default : 'normal')
        Label font weight.
    font : string (default : 'Arial')
        Not used; labels are drawn with family "sans-serif".
    """
    axl = _ax_tolist(axl)
    if calbar is not None:
        x0, y0, xlen, ylen = calbar[0], calbar[1], calbar[2], calbar[3]
        # One decimal place for lengths below 1, none otherwise.
        Hfmt = r"{:.1f}" if xlen < 1.0 else r"{:.0f}"
        Vfmt = r" {:.1f}" if ylen < 1.0 else r" {:.0f}"
        if unitNames is not None:
            Vfmt = Vfmt + r" " + r"{:s}".format(unitNames["y"])
            Hfmt = Hfmt + r" " + r"{:s}".format(unitNames["x"])
        vlabel = Vfmt.format(ylen)
        hlabel = Hfmt.format(xlen)
    text_style = {"fontsize": fontsize, "weight": weight, "family": "sans-serif"}

    for ax in axl:
        if ax is None:
            continue
        if axesoff is True:
            noaxes(ax)
        if calbar is None:
            # Nothing to draw; previously this case failed while indexing None.
            continue
        side = orient
        if side not in ("left", "right"):
            print("PlotHelpers.py: I did not understand orientation: %s" % (orient))
            print("plotting as if set to left... ")
            side = "left"
        if side == "left":  # vertical part is on the left
            corner_x, far_x = x0, x0 + xlen
            label_x = x0 + 0.05 * xlen
        else:  # vertical part goes on the right
            corner_x, far_x = x0 + xlen, x0
            label_x = x0 + xlen - 0.05 * xlen
        ax.plot(
            [corner_x, corner_x, far_x],
            [y0 + ylen, y0, y0],
            color="k",
            linestyle="-",
            linewidth=1.5,
        )
        # label for the vertical part of the bar
        ax.text(
            label_x,
            y0 + 0.5 * ylen,
            vlabel,
            horizontalalignment=side,
            verticalalignment="center",
            **text_style
        )
        # label for the horizontal part of the bar
        ax.text(
            x0 + xlen * 0.5,
            y0 - 0.1 * ylen,
            hlabel,
            horizontalalignment="center",
            verticalalignment="top",
            **text_style
        )
|
|
|
|
|
def referenceline(
    axl,
    reference=None,
    limits=None,
    color="0.33",
    linestyle="--",
    linewidth=0.5,
    dashes=None,
):
    """
    draw a reference line at a particular level of the data on the y axis

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    reference : float (default : None)
        The y value of the line; None is treated as 0.0.
    limits : list of two floats (default : None)
        The x extent of the line. Anything other than a two-element list
        means the current x limits of each axes are used.
    color, linestyle, linewidth :
        Line properties passed to ``ax.plot``.
    dashes : sequence (default : None)
        If given, passed to ``set_dashes`` on the line.

    Returns
    -------
    The line object drawn on the last axes, or None if no line was drawn.
    """
    axl = _ax_tolist(axl)
    if reference is None:
        # A misspelled assignment here used to leave ``reference`` as None.
        reference = 0.0
    explicit_limits = isinstance(limits, list) and len(limits) == 2
    rl = None
    for ax in axl:
        if ax is None:
            continue
        xlims = limits if explicit_limits else ax.get_xlim()
        (rl,) = ax.plot(
            [xlims[0], xlims[1]],
            [reference, reference],
            color=color,
            linestyle=linestyle,
            linewidth=linewidth,
        )
        if dashes is not None:
            rl.set_dashes(dashes)
    return rl
|
|
|
|
|
def crossAxes(axl, xyzero=[0.0, 0.0], limits=[None, None, None, None]):
    """
    Make plot(s) with crossed axes at the data points set by xyzero, and optionally
    set axes limits

    Parameters
    ----------
    axl : axes object or collection of axes objects
        Normalized to a list with _ax_tolist; None entries are skipped.
    xyzero : sequence of two floats (default : [0.0, 0.0])
        Data coordinates of the crossing: the left spine is placed at
        x = xyzero[0] and the bottom spine at y = xyzero[1].
    limits : sequence of four values (default : [None, None, None, None])
        Ordered as [xmin, ymin, xmax, ymax]. The limits are applied only
        when limits[0] is not None.

    The default lists are only read, never modified.
    """
    axl = _ax_tolist(axl)
    for ax in axl:
        if ax is None:
            continue
        ax.spines["left"].set_position(("data", xyzero[0]))
        ax.spines["right"].set_color("none")
        ax.spines["bottom"].set_position(("data", xyzero[1]))
        ax.spines["top"].set_color("none")
        for name in ("left", "bottom"):
            spine = ax.spines[name]
            # set_smart_bounds no longer exists in newer Matplotlib releases.
            if hasattr(spine, "set_smart_bounds"):
                spine.set_smart_bounds(True)
        ax.xaxis.set_ticks_position("bottom")
        ax.yaxis.set_ticks_position("left")
        if limits[0] is not None:
            ax.set_xlim(left=limits[0], right=limits[2])
            ax.set_ylim(bottom=limits[1], top=limits[3])
|
|
|
|
|
def violin_plot(ax, data, pos, bp=False, median=False):
    """
    Draw one violin (mirrored kernel density estimate) per data set.

    Parameters
    ----------
    ax : axes object
        The axes to draw on.
    data : sequence of samples
        One sample per violin; each is handed to gaussian_kde.
    pos : sequence of numbers
        The position of each violin, paired with data in order.
    bp : boolean (default : False)
        If true, a box plot of all the data is overlaid at the same
        positions, with black boxes and solid black whiskers.
    median : boolean (default : False)
        If true, a line one unit wide is drawn at the median of each sample.
    """
    span = max(pos) - min(pos)
    half_width = min(0.15 * max(span, 1.0), 0.5)
    for sample, center in zip(data, pos):
        kde = gaussian_kde(sample)  # kernel density estimate of the sample
        lower = kde.dataset.min()  # lower bound of violin
        upper = kde.dataset.max()  # upper bound of violin
        support = np.arange(lower, upper, (upper - lower) / 100.0)
        profile = kde.evaluate(support)  # density curve along the support
        profile = profile / profile.max() * half_width  # fit the available space
        ax.fill_betweenx(support, center, profile + center, facecolor="y", alpha=0.3)
        ax.fill_betweenx(support, center, -profile + center, facecolor="y", alpha=0.3)
        if median:
            ax.plot(
                [center - 0.5, center + 0.5], [np.median(sample), np.median(sample)], "-"
            )
    if bp:
        boxes = ax.boxplot(data, notch=0, positions=pos, vert=1)
        mpl.setp(boxes["boxes"], color="black")
        mpl.setp(boxes["whiskers"], color="black", linestyle="-")
|
|
|
|
|
# # from somewhere on the web: |
|
|
|
|
|
class NiceScale:
    """Compute "nice" axis bounds and tick spacing for a data range.

    After construction (and after each setter call) the results are
    available as attributes:

    tickSpacing : spacing between ticks (1, 2, 5 or 10 times a power of ten)
    niceMin     : minPoint rounded down to a multiple of tickSpacing
    niceMax     : maxPoint rounded up to a multiple of tickSpacing
    lst         : the "nice" version of the data range (maxPoint - minPoint)

    NOTE(review): minv and maxv are expected to differ and maxTicks to be
    greater than 1; the degenerate cases are not handled here.
    """

    def __init__(self, minv, maxv):
        self.maxTicks = 6
        self.tickSpacing = 0
        self.lst = 10
        self.niceMin = 0
        self.niceMax = 0
        self.minPoint = minv
        self.maxPoint = maxv
        self.calculate()

    def calculate(self):
        """Recompute lst, tickSpacing, niceMin and niceMax from the current settings."""
        self.lst = self.niceNum(self.maxPoint - self.minPoint, False)
        self.tickSpacing = self.niceNum(self.lst / (self.maxTicks - 1), True)
        self.niceMin = np.floor(self.minPoint / self.tickSpacing) * self.tickSpacing
        self.niceMax = np.ceil(self.maxPoint / self.tickSpacing) * self.tickSpacing

    def niceNum(self, lst, rround):
        """Return a "nice" number approximately equal to lst.

        Parameters
        ----------
        lst : float
            The value to approximate.
        rround : boolean
            If true the value is rounded to the nearest nice number;
            otherwise the smallest nice number >= lst is taken.

        Returns
        -------
        1, 2, 5 or 10 times the power of ten at or below lst.
        """
        exponent = np.floor(np.log10(lst))  # exponent of range
        fraction = lst / np.power(10.0, exponent)  # fractional part of range

        # The branch is selected by the rround flag. It was previously
        # selected by the truth value of the range itself, so the flag was
        # ignored and the ceiling branch was effectively unreachable.
        if rround:
            if fraction < 1.5:
                niceFraction = 1
            elif fraction < 3:
                niceFraction = 2
            elif fraction < 7:
                niceFraction = 5
            else:
                niceFraction = 10
        else:
            if fraction <= 1:
                niceFraction = 1
            elif fraction <= 2:
                niceFraction = 2
            elif fraction <= 5:
                niceFraction = 5
            else:
                niceFraction = 10

        return niceFraction * np.power(10.0, exponent)

    def setMinMaxPoints(self, minPoint, maxPoint):
        """Set a new data range and recompute the scale."""
        self.minPoint = minPoint
        self.maxPoint = maxPoint
        self.calculate()

    def setMaxTicks(self, maxTicks):
        """Set the maximum number of ticks and recompute the scale."""
        self.maxTicks = maxTicks
        self.calculate()
|
|
|
|
|
def circles(x, y, s, c="b", ax=None, vmin=None, vmax=None, **kwargs):
    """
    Scatter plot drawn with circles whose radii are given in data units.

    Parameters
    ----------
    x,y : scalar or array_like, shape (n, )
        Circle centers
    s : scalar or array_like, shape (n, )
        Radius of circle in data scale (ie. in data unit)
    c : color or sequence of color, optional, default : 'b'
        A string is used directly as the color of every circle. Anything
        else is installed as the collection's data array, to be mapped to
        colors using the `cmap` and `norm` specified via kwargs.
    ax : Axes object, optional, default: None
        Parent axes of the plot. It uses gca() if not specified.
    vmin, vmax : scalar, optional, default: None
        When `c` is not a string and at least one of these is given, they
        are applied to the collection as its color limits.

    Returns
    -------
    collection : the PatchCollection that was added to the axes

    Other parameters
    ----------------
    kwargs : passed to PatchCollection
        eg. alpha, edgecolors, facecolors, linewidths, linestyles, norm, cmap

    Examples
    --------
    a = np.arange(11)
    circles(a, a, a*0.2, c=a, alpha=0.5, edgecolor='none')

    License
    --------
    This code is under [The BSD 3-Clause License]
    (http://opensource.org/licenses/BSD-3-Clause)
    """
    if ax is None:
        ax = mpl.gca()

    # A string is a literal color; otherwise colors come from the data array.
    color = c if isinstance(c, str) else None
    kwargs.update(color=color)

    if np.isscalar(x):
        patches = [Circle((x, y), s)]
    elif np.isscalar(s):
        patches = [Circle(center, s) for center in zip(x, y)]
    else:
        patches = [Circle((cx, cy), radius) for cx, cy, radius in zip(x, y, s)]
    collection = PatchCollection(patches, **kwargs)

    if color is None:
        collection.set_array(np.asarray(c))
        if not (vmin is None and vmax is None):
            collection.set_clim(vmin, vmax)

    ax.add_collection(collection)
    ax.autoscale_view()
    return collection
|
|
|
|
|
def rectangles(x, y, sw, sh=None, c="b", ax=None, vmin=None, vmax=None, **kwargs):
    """
    Make a scatter plot of rectangles centered on (x, y), with sizes given
    in data units.

    Parameters
    ----------
    x,y : scalar or array_like, shape (n, )
        Centers of the rectangles
    sw : scalar or array_like, shape (n, )
        Width of the rectangles in data scale (ie. in data unit)
    sh : scalar or array_like, shape (n, ), optional, default: None
        Height of the rectangles in data scale. If None, sw is used, giving
        squares.
    c : color or sequence of color, optional, default : 'b'
        A string is used directly as the color of every rectangle. Anything
        else is installed as the collection's data array, to be mapped to
        colors using the `cmap` and `norm` specified via kwargs.
    ax : Axes object, optional, default: None
        Parent axes of the plot. It uses gca() if not specified.
    vmin, vmax : scalar, optional, default: None
        When `c` is not a string and at least one of these is given, they
        are applied to the collection as its color limits.

    x, y, sw and sh may be any mix of scalars, lists and arrays; they are
    broadcast against each other.

    Returns
    -------
    collection : the PatchCollection that was added to the axes

    Other parameters
    ----------------
    kwargs : passed to PatchCollection
        eg. alpha, edgecolors, facecolors, linewidths, linestyles, norm, cmap

    Examples
    --------
    a = np.arange(11)
    rectangles(a, a, a*0.2, c=a, alpha=0.5, edgecolor='none')

    License
    --------
    This code is under [The BSD 3-Clause License]
    (http://opensource.org/licenses/BSD-3-Clause)
    """
    if ax is None:
        ax = mpl.gca()

    # A string is a literal color; otherwise colors come from the data array.
    color = c if isinstance(c, str) else None
    kwargs.update(color=color)
    if sh is None:
        sh = sw

    # Converting to arrays lets plain lists take part in the arithmetic
    # below, and broadcasting pairs scalar sizes with array positions.
    xc, yc, widths, heights = np.broadcast_arrays(
        *(np.atleast_1d(np.asarray(v, dtype=float)) for v in (x, y, sw, sh))
    )
    # offset as position specified is "lower left corner"
    lefts = xc - widths / 2.0
    bottoms = yc - heights / 2.0
    patches = [
        Rectangle((x_, y_), w_, h_)
        for x_, y_, w_, h_ in zip(
            lefts.ravel(), bottoms.ravel(), widths.ravel(), heights.ravel()
        )
    ]
    collection = PatchCollection(patches, **kwargs)

    if color is None:
        collection.set_array(np.asarray(c))
        if vmin is not None or vmax is not None:
            collection.set_clim(vmin, vmax)

    ax.add_collection(collection)
    ax.autoscale_view()
    return collection
|
|
|
|
|
def show_figure_grid(fig, figx=10.0, figy=10.0):
    """
    Create a background grid with major and minor lines like graph paper.

    A frameless axes covering the whole figure is added, with limits
    0..figx and 0..figy, major grid lines every 1.0 and minor grid lines
    every 0.1. If figx and figy are given in inches or cm, the grid will
    be on that scale.

    Figure grid is useful when building figures and placing labels
    at absolute locations on the figure.

    Parameters
    ----------
    fig : Matplotlib figure handle (no default):
        The figure to which the grid will be applied

    figx : float (default: 10.)
        Extent of the grid along the X dimension

    figy : float (default: 10.)
        Extent of the grid along the Y dimension

    Returns
    -------
    The axes object that holds the grid.
    """
    backGrid = fig.add_axes([0, 0, 1, 1], frameon=False)
    backGrid.set_ylim(0.0, figy)
    backGrid.set_xlim(0.0, figx)
    backGrid.grid(True)

    backGrid.set_yticks(np.arange(0.0, figy + 0.01, 1.0))
    backGrid.set_yticks(np.arange(0.0, figy + 0.01, 0.1), minor=True)
    backGrid.set_xticks(np.arange(0.0, figx + 0.01, 1.0))
    backGrid.set_xticks(np.arange(0.0, figx + 0.01, 0.1), minor=True)
    # The on/off flag is passed positionally: its keyword name was changed
    # between Matplotlib versions, the positional form works in all of them.
    backGrid.grid(True, which="major", color="g", alpha=0.6, linewidth=0.8)
    backGrid.grid(True, which="minor", color="g", alpha=0.4, linewidth=0.2)
    return backGrid
|
|
|
|
|
def hide_figure_grid(fig, grid):
    """Turn off the grid lines of a grid axes.

    Parameters
    ----------
    fig : not used by this function
    grid : axes object
        The axes whose grid is switched off with ``grid.grid(False)``
        (presumably the axes returned by show_figure_grid).
    """
    grid.grid(False)
|
|
|
|
|
def delete_figure_grid(fig, grid):
    """Remove a grid axes from its figure.

    Parameters
    ----------
    fig : Matplotlib figure handle
        The figure that holds the grid axes.
    grid : axes object
        The axes to remove (presumably the one returned by
        show_figure_grid).
    """
    # pyplot has no ``delete`` function, so the previous call always raised
    # AttributeError; removing an axes is done through the figure.
    fig.delaxes(grid)
|
|
|
|
|
class Plotter:
    """
    The Plotter class provides a simple convenience for plotting data in
    a row x column array.

    After construction the following attributes are available:

    figure_handle : the matplotlib figure
    axarr : 2-D numpy object array holding the axes
    axdict : OrderedDict mapping panel names to axes
    axlist : flat list of the axes, in row or column order (see roworder)
    axlabels : list of the panel label objects returned by labelPanels
    referenceLines : dict mapping each axes to its reference line (if any)
    nrows, ncolumns : shape of axarr
    """

    def __init__(
        self,
        rcshape=None,
        axmap=None,
        arrangement=None,
        title=None,
        label=False,
        roworder=True,
        refline=None,
        figsize=(11, 8.5),
        fontsize=10,
        position=0,
        labeloffset=[0.0, 0.0],
        labelsize=12,
    ):
        """
        Create an instance of the plotter. Generates a new matplotlib figure,
        and sets up an array of subplots as defined, initializes the counters

        Examples
        --------
        Ex. 1:
        One way to generate plots on a standard grid, using an axis map to name the panels:
            labels = ['A', 'B1', 'B2', 'C1', 'C2', 'D', 'E', 'F']
            gr = [(0, 4, 0, 1), (0, 3, 1, 2), (3, 4, 1, 2), (0, 3, 2, 3), (3, 4, 2, 3), (5, 8, 0, 1), (5, 8, 1, 2), (5, 8, 2, 3)]
            axmap = OrderedDict(zip(labels, gr))
            P = PH.Plotter((8, 1), axmap=axmap, label=True, figsize=(8., 6.))
            PH.show_figure_grid(P.figure_handle)

        Ex. 2:
        Place plots on defined locations on the page - no messing with gridspec or subplots.
        For this version, we just generate N subplots with labels (used to tag each plot)
        The "sizer" array then maps the tags to specific panel locations
            # define positions for each panel in Figure coordinates (0, 1, 0, 1)
            # you don't have to use an ordered dict for this, I just prefer it when debugging
            sizer = OrderedDict([('A', [0.08, 0.22, 0.55, 0.4]), ('B1', [0.40, 0.25, 0.65, 0.3]), ('B2', [0.40, 0.25, 0.5, 0.1]),
                    ('C1', [0.72, 0.25, 0.65, 0.3]), ('C2', [0.72, 0.25, 0.5, 0.1]),
                    ('D', [0.08, 0.25, 0.1, 0.3]), ('E', [0.40, 0.25, 0.1, 0.3]), ('F', [0.72, 0.25, 0.1, 0.3]),
            ])  # dict elements are [left, width, bottom, height] for the axes in the plot.
            gr = [(a, a+1, 0, 1) for a in range(0, 8)]  # just generate subplots - shape does not matter
            axmap = OrderedDict(zip(sizer.keys(), gr))
            P = PH.Plotter((8, 1), axmap=axmap, label=True, figsize=(8., 6.))
            PH.show_figure_grid(P.figure_handle)
            P.resize(sizer)  # perform positioning magic
            P.axdict['B1']  # access the plot associated with panel B1

        Parameters
        ----------
        rcshape : list, tuple or dict (required)
            list or tuple [row, col] : build a rectangular array of
                row x col plots.
            dict : one plot is created per key (stacked in a single column),
                the keys become the panel names, the panels are labelled with
                the keys, and the dict is then handed to resize() to position
                the panels (see resize() for the format of the values).
            Anything else (including None) raises ValueError.

        axmap : dict or list (default : None)
            dict : the keys are panel names. They are assigned, in order, to
                the axes of the grid built from rcshape (row-major order) and
                stored in axdict. The values are NOT used to place the axes;
                use resize() for positioning.
            list : list of gridspec slices [(r1t, r1b, c1l, c1r), ...] on the
                grid defined by a list/tuple rcshape, where r1t is the top row
                of the first panel, r1b is the bottom, etc. One axes is built
                per entry, allowing non-rectangular arrangements. In this mode
                axarr has shape (len(axmap), 1) and matches the order of axmap.
            Any other type (or a list combined with a dict rcshape) raises
            TypeError.

        arrangement : Ordered Dict (default: None)
            Arrangement allows the data to be plotted according to a logical arrangement
            The dict keys are the names ("groups") for each column, and the elements are
            string names for the entities in the groups

        title : string (default: None)
            Provide a title for the entire plot

        label : Boolean (default: False)
            If True, sets labels on panels

        roworder : Boolean (default: True)
            Define whether labels run in row order first or column order first

        refline : float (default: None)
            Define the position of a reference line to be used in all panels

        figsize : tuple (default : (11, 8.5))
            Figure size in inches. Default is for a landscape figure

        fontsize : number or dict (default : 10)
            Font size in points. A single number is used for everything; a
            dict must provide the keys 'tick', 'label' and 'panel'.

        position : position of spines (0 means close, 0.05 means break out)
            x, y spines; passed on to nice_plot

        labeloffset : number, list/tuple or dict (default : [0.0, 0.0])
            Offset added to the default panel label position. A number is
            used for both x and y, a list/tuple is [x, y], and a dict must
            have the keys 'left' and 'bottom'. The default list is never
            modified.

        labelsize : points (default : 12)
            Currently not used.

        Returns
        -------
        Nothing
        """
        self.arrangement = arrangement
        self.fontsize = fontsize
        self.referenceLines = {}
        self.figure_handle = mpl.figure(figsize=figsize)  # create the figure
        self.figure_handle.set_size_inches(figsize[0], figsize[1], forward=True)
        self.axlabels = []
        self.axlist = []
        # axis label dictionary for indirect access (better!)
        self.axdict = OrderedDict()

        fontsize = self._font_sizes(fontsize)
        offset = self._label_offset(labeloffset)

        # A list axmap only makes sense on a grid defined by a list/tuple rcshape
        slice_mode = isinstance(axmap, list) and isinstance(rcshape, (list, tuple))

        # build axes arrays
        if isinstance(rcshape, (list, tuple)):
            gs = gridspec.GridSpec(rcshape[0], rcshape[1])
            if slice_mode:
                # one axes per gridspec slice, returned as an N x 1 array
                self.axarr = np.empty(shape=(len(axmap), 1), dtype=object)
                for k, g in enumerate(axmap):
                    self.axarr[k, 0] = mpl.subplot(gs[g[0] : g[1], g[2] : g[3]])
            else:
                # regular n x m grid
                self._build_regular_grid(gs, rcshape[0], rcshape[1])
        elif isinstance(rcshape, dict):  # true for OrderedDict also
            # start with N x 1 subplots, then reposition according to the dict
            panel_keys = list(rcshape.keys())
            gs = gridspec.GridSpec(len(panel_keys), 1)
            self._build_regular_grid(gs, len(panel_keys), 1)
            for k, pk in enumerate(panel_keys):
                self.axdict[pk] = self.axarr[k, 0]
            # labels are created here (regardless of "label") because resize
            # may need them to apply 'labelpos'
            self.axlabels = labelPanels(
                self.axarr.ravel().tolist(),
                axlist=panel_keys,
                xy=(-0.095 + offset[0], 0.95 + offset[1]),
                fontsize=fontsize["panel"],
            )
            self.resize(rcshape)
        else:
            raise ValueError("Input rcshape must be list/tuple or dict")

        # attach panel names from the axis map
        if axmap is not None:
            if isinstance(axmap, dict):  # keys are panel labels
                flat_axes = self.axarr.ravel()
                for k, pk in enumerate(axmap.keys()):
                    self.axdict[pk] = flat_axes[k]
            elif not slice_mode:
                raise TypeError("Plotter in PlotHelpers: axmap must be a list or dict")

        self.nrows = self.axarr.shape[0]
        self.ncolumns = self.axarr.shape[1] if self.axarr.ndim > 1 else 1

        # no names supplied: generate default panel names
        if len(self.axdict) == 0:
            default_names = self._panel_names(self.nrows, self.ncolumns)
            for panel_name, ax in zip(default_names, self.axarr.flatten()):
                self.axdict[panel_name] = ax

        if title is not None:
            # canvas.set_window_title no longer exists in current Matplotlib;
            # the window title is set through the figure manager instead.
            manager = getattr(self.figure_handle.canvas, "manager", None)
            if manager is not None:
                manager.set_window_title(title)
            self.figure_handle.suptitle(title)

        self.row_counter = 0
        self.column_counter = 0

        # common styling for every axes
        for i in range(self.nrows):
            for j in range(self.ncolumns):
                ax = self.axarr[i, j]
                ax.spines["top"].set_visible(False)
                ax.get_xaxis().set_tick_params(direction="out", width=0.8, length=4.0)
                ax.get_yaxis().set_tick_params(direction="out", width=0.8, length=4.0)
                ax.tick_params(axis="both", which="major", labelsize=fontsize["tick"])
                nice_plot(ax, position=position)
                if refline is not None:
                    self.referenceLines[ax] = referenceline(ax, reference=refline)

        # flat list of axes in the requested traversal order
        if roworder:
            self.axlist = [
                self.axarr[i, j]
                for i in range(self.nrows)
                for j in range(self.ncolumns)
            ]
        else:
            self.axlist = [
                self.axarr[j, i]
                for i in range(self.ncolumns)
                for j in range(self.nrows)
            ]

        if not label:
            return

        if isinstance(axmap, dict):  # in case predefined...
            self.axlabels = labelPanels(
                self.axarr.ravel().tolist(),
                axlist=list(axmap.keys()),
                xy=(-0.095 + offset[0], 0.95 + offset[1]),
                fontsize=fontsize["panel"],
            )
            return

        if isinstance(rcshape, dict):
            # panels were already labelled with the rcshape keys above;
            # labelling again would draw a second set of labels on top
            return

        if self.nrows * self.ncolumns > 26:  # handle large plot using "A1..."
            # names follow the same traversal order as axlist so that each
            # axes gets the name of its own column/row
            grid_names = self._panel_names(
                self.nrows, self.ncolumns, roworder=bool(roworder)
            )
            self.axlabels = labelPanels(
                self.axlist, axlist=grid_names, xy=(-0.35 + offset[0], 0.75)
            )
        else:
            self.axlabels = labelPanels(
                self.axlist, xy=(-0.095 + offset[0], 0.95 + offset[1])
            )

    @staticmethod
    def _font_sizes(fontsize):
        """Return fontsize as a dict with 'tick', 'label' and 'panel' entries."""
        if isinstance(fontsize, dict):
            return fontsize
        return {"tick": fontsize, "label": fontsize, "panel": fontsize}

    @staticmethod
    def _label_offset(labeloffset):
        """Normalize the labeloffset argument to an [x, y] pair."""
        if isinstance(labeloffset, dict):
            return [labeloffset["left"], labeloffset["bottom"]]
        if isinstance(labeloffset, (list, tuple)):
            return labeloffset
        if isinstance(labeloffset, (int, float)) and not isinstance(labeloffset, bool):
            return [labeloffset, labeloffset]
        return [0.0, 0.0]

    @staticmethod
    def _column_name(index):
        """Return the spreadsheet-style column name (A..Z, AA, AB, ...) for a 0-based index."""
        name = ""
        index += 1
        while index > 0:
            index, remainder = divmod(index - 1, 26)
            name = string.ascii_uppercase[remainder] + name
        return name

    @staticmethod
    def _panel_names(nrows, ncolumns, roworder=True):
        """
        Return default panel names for an nrows x ncolumns array of axes.

        Up to 26 panels are named 'A', 'B', ... in sequence. Larger arrays
        use a column letter plus a row number starting at 1 ('A1', 'B1', ...),
        listed row by row (roworder True) or column by column (roworder False).
        """
        count = nrows * ncolumns
        if count <= 26:
            return list(string.ascii_uppercase[:count])
        if roworder:
            cells = [(r, c) for r in range(nrows) for c in range(ncolumns)]
        else:
            cells = [(r, c) for c in range(ncolumns) for r in range(nrows)]
        return [Plotter._column_name(c) + str(r + 1) for r, c in cells]

    def _build_regular_grid(self, gs, nrows, ncolumns):
        """Fill self.axarr with one subplot per cell of the gridspec gs (row-major)."""
        # use a numpy object array, for its indexing features
        self.axarr = np.empty(shape=(nrows, ncolumns), dtype=object)
        ix = 0
        for r in range(nrows):
            for c in range(ncolumns):
                self.axarr[r, c] = mpl.subplot(gs[ix])
                ix += 1

    def _exhausted_error(self):
        """Build the error raised when no more axes are available."""
        return ValueError(
            "Call to get next row exceeds the number of rows requested initially: %d"
            % self.nrows
        )

    def _next(self):
        """
        Private function
        _next moves the internal row, column counters to the next position
        Only sets internal variables

        Stepping from the last axes to the position just past the end is
        allowed (so that the last axes can be handed out); ValueError is
        raised only when the counters are already past the end.
        """
        if self.row_counter >= self.nrows:
            raise self._exhausted_error()
        self.column_counter += 1
        if self.column_counter >= self.ncolumns:
            self.row_counter += 1
            self.column_counter = 0

    def getaxis(self, group=None):
        """
        getaxis gets the current row, column counter, and calls _next to increment the counter
        (so that the next getaxis returns the next available axis pointer)

        Parameters
        ----------
        group : string (default: None)
            forces the current axis to be selected from text name of a "group"

        Returns
        -------
        the current axis or the axis associated with a group

        Raises
        ------
        ValueError
            if group is None and all axes have already been handed out, or
            if a group is given but no arrangement was defined
        """
        if group is not None:
            return self.getRC(group)

        # check before indexing, so that the last axes can still be returned
        if self.row_counter >= self.nrows:
            raise self._exhausted_error()
        currentaxis = self.axarr[self.row_counter, self.column_counter]
        self._next()  # prepare for next call
        return currentaxis

    def getRC(self, group):
        """
        Get the axis associated with a group

        Parameters
        ----------
        group : string (no default)
            text name of a "group" listed in the arrangement

        Returns
        -------
        The matplotlib axis associated with the group name, or None if no group by
        that name exists in the arrangement (a message is printed in that case).
        The column is the position of the arrangement key that lists the
        group, and the row is the position of the group within that list.

        Raises
        ------
        ValueError
            if the plotter was created without an arrangement
        """
        if self.arrangement is None:
            raise ValueError("specifying a group requires an arrangment dictionary")
        # look for the group label in the arrangement dicts
        for c, colname in enumerate(self.arrangement.keys()):
            members = self.arrangement[colname]
            if group in members:
                r = members.index(group)  # get the row position this way
                return self.axarr[r, c]
        print(("Group {:s} not in the arrangement".format(group)))
        return None

    # example of a sizer dictionary, in the format accepted by resize()
    sizer = {
        "A": {"pos": [0.08, 0.22, 0.50, 0.4]},
        "B1": {"pos": [0.40, 0.25, 0.60, 0.3]},
        "B2": {"pos": [0.40, 0.25, 0.5, 0.1]},
        "C1": {"pos": [0.72, 0.25, 0.60, 0.3]},
        "C2": {"pos": [0.72, 0.25, 0.5, 0.1]},
        "D": {"pos": [0.08, 0.25, 0.1, 0.3]},
        "E": {"pos": [0.40, 0.25, 0.1, 0.3]},
        "F": {"pos": [0.72, 0.25, 0.1, 0.3]},
    }

    def resize(self, sizer):
        """
        Resize the graphs in the array.

        Parameters
        ----------
        sizer : dict (no default)
            A dictionary with keys corresponding to the panel names in axdict.
            Each value is either a dict or a plain list/tuple.
            dict form:
                'pos' : [left, width, bottom, height] for the panel, in
                    figure fractions [0, 1, 0, 1] (required)
                'labelpos' : (x, y), new position for the panel label
                    (optional; ignored if the panel has no label)
                'noaxes' : if true, noaxes() is applied to the panel (optional)
            list/tuple form:
                [left, width, bottom, height], same as a dict with only 'pos'

            sizer = {'A': {'pos': [0.08, 0.22, 0.50, 0.4], 'labelpos': (x,y), 'noaxes': True},
                     'B1': {'pos': [0.40, 0.25, 0.60, 0.3], 'labelpos': (x,y)},
                     'B2': {'pos': [0.40, 0.25, 0.5, 0.1], 'labelpos': (x,y), 'noaxes': False},
                     'C1': {'pos': [0.72, 0.25, 0.60, 0.3], 'labelpos': (x,y)},
                     'C2': {'pos': [0.72, 0.25, 0.5, 0.1], 'labelpos': (x,y)},
                     'D': {'pos': [0.08, 0.25, 0.1, 0.3], 'labelpos': (x,y)},
                     'E': {'pos': [0.40, 0.25, 0.1, 0.3], 'labelpos': (x,y)},
                     'F': {'pos': [0.72, 0.25, 0.1, 0.3], 'labelpos': (x,y)}
                    }

        Returns
        -------
        Nothing

        Raises
        ------
        KeyError
            if a key of sizer is not a panel name in axdict
        """
        panel_order = list(self.axdict.keys())
        for panel, spec in sizer.items():
            if not isinstance(spec, dict):
                spec = {"pos": spec}
            ax = self.axdict[panel]
            pos = spec["pos"]
            bbox = ax.get_position()
            # offsets are in figure fractions; x1/y1 are origin + extent
            bbox.x0 = pos[0]
            bbox.x1 = pos[1] + pos[0]
            bbox.y0 = pos[2]
            bbox.y1 = pos[3] + pos[2]
            ax.set_position(bbox)
            if "labelpos" in spec and len(spec["labelpos"]) == 2:
                # find the label by panel, not by position in sizer, so that
                # a partial or reordered sizer moves the right label
                label_index = panel_order.index(panel)
                if label_index < len(self.axlabels):
                    x, y = spec["labelpos"]
                    self.axlabels[label_index].set_x(x)
                    self.axlabels[label_index].set_y(y)
            if spec.get("noaxes", False):
                noaxes(ax)
|
|
|
|
|
if __name__ == "__main__":
    # Demonstration: build a 6 x 2 figure with nine named panels, fill every
    # axes with random points, then exercise a few of the helper routines.
    #
    # Alternative constructions:
    # P = Plotter((3,3), axmap=[(0, 1, 0, 3), (1, 2, 0, 2), (2, 1, 2, 3), (2, 3, 0, 1), (2, 3, 1, 2)])
    # P = Plotter((2,3), label=True)  # create a figure with plots
    panel_names = ["A", "B", "C", "D", "E", "F", "G", "H", "I"]
    left_cells = [(top, top + 2, 0, 1) for top in range(0, 6, 2)]
    right_cells = [(top, top + 1, 1, 2) for top in range(0, 6)]
    axmap = OrderedDict(list(zip(panel_names, left_cells + right_cells)))
    P = Plotter((6, 2), axmap=axmap, figsize=(6.0, 6.0), label=True)

    # name every axes of the grid with consecutive capital letters
    axd = OrderedDict(
        (string.ascii_uppercase[index], axis)
        for index, axis in enumerate(P.axarr.flatten())
    )
    for axis in axd.values():
        axis.plot(np.random.random(10), np.random.random(10))

    nice_plot(list(axd.values()), position=-0.1)
    demo_pair = [axd["B"], axd["C"]]
    cleanAxes(demo_pair)
    calbar([axd["B"], axd["C"]], calbar=[0.5, 0.5, 0.2, 0.2])
    # labelPanels([axd[a] for a in axd], axd.keys())
    # mpl.tight_layout(pad=2, w_pad=0.5, h_pad=2.0)
    mpl.show()
|
|
|