model of DCN pyramidal neuron
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

1415 lines
48 KiB

#!/usr/bin/env python
# encoding: utf-8
"""
PlotHelpers.py
Routines to help use matplotlib and make cleaner plots
as well as get plots ready for publication.
Modified to allow us to use a list of axes, and operate on all of those,
or to use just one axis if that's all that is passed.
Therefore, the first argument to these calls can either be an axes object,
or a list of axes objects. 2/10/2012 pbm.
Plotter class: a simple class for managing figures with multiple plots.
Uses gridspec to build sets of axes.
Created by Paul Manis on 2010-03-09.
Copyright 2010-2016 Paul Manis
Distributed under MIT/X11 license. See license.txt for more information.
"""
import sys
import os
import string
from collections import OrderedDict
stdFont = "Arial"
from matplotlib.ticker import FormatStrFormatter
from matplotlib.font_manager import FontProperties
from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, DrawingArea, HPacker
from scipy.stats import gaussian_kde
import numpy as np
import matplotlib.pyplot as mpl
import matplotlib.gridspec as gridspec
from matplotlib.patches import Circle
from matplotlib.patches import Rectangle
from matplotlib.patches import Ellipse
from matplotlib.collections import PatchCollection
import matplotlib
rcParams = matplotlib.rcParams
rcParams["svg.fonttype"] = "none"  # No text as paths. Assume font installed.
rcParams["pdf.fonttype"] = 42  # embed TrueType (Type 42) fonts in PDF output
rcParams["ps.fonttype"] = 42  # embed TrueType (Type 42) fonts in PS output
from matplotlib import rc

rc("font", **{"family": "sans-serif", "sans-serif": ["Arial"]})

# Check for a LaTeX install.
# distutils (and distutils.spawn.find_executable) was removed in Python 3.12;
# shutil.which has the same contract (full path string, or None if not found),
# and the module-level name find_executable is kept for existing users.
from shutil import which as find_executable

latex_avail = False
if find_executable("latex"):
    latex_avail = True
rc("text", usetex=latex_avail)
# "text.latex.unicode" no longer exists in current Matplotlib releases, and
# assigning an unknown rc key raises KeyError, so only set it where it exists.
if "text.latex.unicode" in rcParams:
    rcParams["text.latex.unicode"] = latex_avail
def _ax_tolist(ax):
if isinstance(ax, list):
return ax
elif isinstance(ax, dict):
axlist = list(axl.keys())
return [ax for ax in axl[axlist]]
else:
return [ax]
def nice_plot(
    axl, spines=["left", "bottom"], position=10, direction="inward", axesoff=False
):
    """
    Adjust a plot so that it looks nicer than the default matplotlib plot.

    Shows only the requested spines (in black), hides the others, positions
    the shown spines, removes ticks along axes that have no spine, and sets
    the tick direction.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        A single axes is converted to a one-element list; a dict contributes
        its values. ``None`` entries are skipped.
    spines : list of strings (default : ['left', 'bottom'])
        Spines to display. Choices are 'left', 'right', 'bottom', and 'top'.
        Spines that are not listed are hidden (color 'none').
    position : int, float or dict (default : 10)
        Position of each displayed spine in *axes-fraction* coordinates, i.e.
        ``spine.set_position(("axes", position))``: 0.0 is the left/bottom
        edge of the axes and 1.0 the right/top edge. A dict maps each
        displayed spine name to its own position, e.g.
        ``{'left': -0.05, 'bottom': -0.05}``.
    direction : string (default : 'inward')
        Tick direction: 'inward' draws ticks inside the axes; any other value
        draws them outside.
    axesoff : boolean (default : False)
        If True, remove all ticks and turn the axes off via noaxes().

    Returns
    -------
    Nothing.

    Raises
    ------
    ValueError
        If *position* is neither a number nor a dict.
    """
    # numpy scalars (e.g. np.float32, np.int64) are accepted as well
    numeric_types = (int, float, np.integer, np.floating)
    for ax in _ax_tolist(axl):
        if ax is None:
            continue
        for loc, spine in ax.spines.items():
            if loc not in spines:
                spine.set_color("none")  # hide spines that were not requested
                continue
            spine.set_color("k")
            if isinstance(position, dict):
                spine.set_position(("axes", position[loc]))
            elif isinstance(position, numeric_types):
                spine.set_position(("axes", position))
            else:
                raise ValueError(
                    "position must be int, float or dict [ex: {'left': -0.05, 'bottom': -0.05}]"
                )
        if axesoff is True:
            noaxes(ax)
        # turn off ticks where there is no spine, if there are axes
        if "left" in spines and not axesoff:
            ax.yaxis.set_ticks_position("left")
            ax.yaxis.set_tick_params(color="k")
        else:
            ax.yaxis.set_ticks([])  # no yaxis ticks
        if "bottom" in spines and not axesoff:
            ax.xaxis.set_ticks_position("bottom")
            ax.xaxis.set_tick_params(color="k")
        else:
            ax.xaxis.set_ticks([])  # no xaxis ticks
        tick_direction = "in" if direction == "inward" else "out"
        ax.tick_params(axis="y", direction=tick_direction)
        ax.tick_params(axis="x", direction=tick_direction)
def noaxes(axl, whichaxes="xy"):
    """
    Strip tick marks (and optionally the whole axis) from one or more axes.

    Parameters
    ----------
    axl : axes object or list of axes
        Anything that is not a list is treated as a single axes. ``None``
        entries are skipped.
    whichaxes : string (default : 'xy')
        An 'x' in the string clears the x ticks and a 'y' clears the y ticks.
        Exactly 'xy' additionally turns the axes off completely.

    Returns
    -------
    Nothing
    """
    targets = axl if type(axl) is list else [axl]
    clear_x = "x" in whichaxes
    clear_y = "y" in whichaxes
    switch_off = whichaxes == "xy"
    for target in targets:
        if target is None:
            continue
        if clear_x:
            target.xaxis.set_ticks([])
        if clear_y:
            target.yaxis.set_ticks([])
        if switch_off:
            target.set_axis_off()
def setY(ax1, ax2):
    """
    Copy the y-axis limits of a source axes onto target axes.

    Parameters
    ----------
    ax1 : axes object
        Source axes whose y limits are read. A list is refused: a message is
        printed and nothing is changed.
    ax2 : axes object, list of axes, or dict of axes
        Target axes that receive the source limits.

    Returns
    -------
    Nothing
    """
    if type(ax1) is list:
        print("PlotHelpers: cannot use list as source to set Y axis")
        return
    targets = _ax_tolist(ax2)
    source_limits = ax1.get_ylim()
    for target in targets:
        target.set_ylim(source_limits)
def setX(ax1, ax2):
    """
    Copy the x-axis limits of a source axes onto target axes.

    Parameters
    ----------
    ax1 : axes object
        Source axes whose x limits are read. A list is refused: a message is
        printed and nothing is changed.
    ax2 : axes object, list of axes, or dict of axes
        Target axes that receive the source limits. ``None`` entries are
        skipped.

    Returns
    -------
    Nothing
    """
    if type(ax1) is list:
        # the message previously said "Y axis" (copy/paste from setY)
        print("PlotHelpers: cannot use list as source to set X axis")
        return
    ax2 = _ax_tolist(ax2)
    refx = ax1.get_xlim()
    for ax in ax2:
        if ax is None:
            continue
        ax.set_xlim(refx)
def labelPanels(
    axl,
    axlist=None,
    font="Arial",
    fontsize=18,
    weight="normal",
    xy=(-0.05, 1.05),
    horizontalalignment="right",
    verticalalignment="bottom",
    rotation=0.0,
):
    """
    Provide labeling of panels in a figure with multiple subplots (axes).

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        A single axes is converted to a list. When a dict is given and
        *axlist* is None, the dict keys are used as the labels. ``None``
        entries are skipped; a nested list contributes its first element.
    axlist : sequence of string labels (default : None)
        Labels, in the same order as *axl*. If None, the axes are lettered
        'A', 'B', ... in sequence.
    font : string (default : 'Arial')
        Currently unused; labels are drawn with the 'sans-serif' family.
    fontsize : float (default : 18, in points)
        Font size of the labels.
    weight : string (default : 'normal')
        Font weight of the labels (e.g. 'normal', 'bold').
    xy : tuple (default : (-0.05, 1.05))
        Label position relative to each axes frame, in axes-fraction units.
    horizontalalignment, verticalalignment : string
        Text alignment relative to *xy*.
    rotation : float (default : 0.0)
        Text rotation in degrees.

    Returns
    -------
    list of the annotation objects created.

    Raises
    ------
    ValueError
        If there are more labels than axes.
    """
    if isinstance(axl, dict) and axlist is None:
        axlist = list(axl.keys())
    axl = _ax_tolist(axl)
    if axlist is None:
        # assume we wish to go in sequence
        axlist = string.ascii_uppercase[0 : len(axl)]
    if len(axlist) > len(axl):
        raise ValueError(
            "axl must have at least as many entries as axlist: "
            "got axl=%d and axlist=%d for axlist: %s" % (len(axl), len(axlist), axlist)
        )
    labels = []
    # zip stops at the shorter sequence: extra axes get no label
    for ax, text in zip(axl, axlist):
        if ax is None:
            continue
        if isinstance(ax, list):
            ax = ax[0]
        ann = ax.annotate(
            text,
            xy=xy,
            xycoords="axes fraction",
            annotation_clip=False,
            color="k",
            verticalalignment=verticalalignment,
            weight=weight,
            horizontalalignment=horizontalalignment,
            fontsize=fontsize,
            family="sans-serif",
            rotation=rotation,
        )
        labels.append(ann)
    return labels
def listAxes(axd):
    """
    Make a list of the axes from a dictionary (or pass a list through).

    Parameters
    ----------
    axd : dict (including OrderedDict) or list
        A dict yields its values in key order; a list is returned unchanged.

    Returns
    -------
    list

    Raises
    ------
    TypeError
        If *axd* is neither a dict nor a list.
    """
    # isinstance (not ``type(...) is dict``) so that OrderedDict, e.g.
    # Plotter.axdict, is accepted
    if isinstance(axd, dict):
        return [axd[x] for x in axd]
    if isinstance(axd, list):
        return axd
    # the previous bare ``raise`` had no active exception and produced a
    # RuntimeError instead of a meaningful error
    raise TypeError(
        "listAxes expects dictionary or list; got %s" % type(axd).__name__
    )
def cleanAxes(axl):
    """
    Show only the left and bottom spines and put ticks on those sides.

    Each axes gets its left/bottom spines made visible and its right/top
    spines hidden, x ticks on the bottom, y ticks on the left, and then
    update_font() applied. ``None`` entries are skipped.

    Raises
    ------
    ValueError
        If an axes has a spine other than left/bottom/right/top.
    """
    shown = ("left", "bottom")
    hidden = ("right", "top")
    for ax in _ax_tolist(axl):
        if ax is None:
            continue
        for where, spine in ax.spines.items():
            if where in shown:
                spine.set_visible(True)
            elif where in hidden:
                spine.set_visible(False)  # do not draw the spine
            else:
                raise ValueError("Unknown spine location: %s" % where)
        ax.xaxis.set_ticks_position("bottom")
        ax.yaxis.set_ticks_position("left")
        update_font(ax)
def setTicks(axl, axis="x", ticks=np.arange(0, 1.1, 1.0)):
    """
    Set explicit major tick locations on one axis of each axes.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    axis : 'x' or 'y' (default : 'x')
        Which axis receives the ticks; any other value leaves the axes alone.
    ticks : array_like (default : [0., 1.])
        Tick locations in data units.
    """
    for target in _ax_tolist(axl):
        if target is None:
            continue
        if axis == "x":
            target.set_xticks(ticks)
        elif axis == "y":
            target.set_yticks(ticks)
def formatTicks(axl, axis="xy", fmt="%d", font="Arial"):
    """
    Apply a printf-style format string to the major tick labels.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    axis : string (default : 'xy')
        An 'x' formats the x axis and a 'y' the y axis.
    fmt : string (default : '%d')
        Format passed to matplotlib's FormatStrFormatter; the default renders
        the ticks as integers.
    font : string
        Unused.
    """
    targets = _ax_tolist(axl)
    tick_formatter = FormatStrFormatter(fmt)
    do_x = "x" in axis
    do_y = "y" in axis
    for target in targets:
        if target is None:
            continue
        if do_x:
            target.xaxis.set_major_formatter(tick_formatter)
        if do_y:
            target.yaxis.set_major_formatter(tick_formatter)
def autoFormatTicks(axl, axis="xy", font="Arial"):
    """
    Pick a tick-label format for each axis from its current data limits.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    axis : string (default : 'xy')
        An 'x' formats the x axis from its xlim; a 'y' formats the y axis
        from its ylim (see setFormatter for the format rules).
    font : string
        Unused.
    """
    axl = _ax_tolist(axl)
    for ax in axl:
        if ax is None:
            continue
        if "x" in axis:
            x0, x1 = ax.get_xlim()
            setFormatter(ax, x0, x1, axis="x")
        if "y" in axis:
            # previously ``ax.get_xlim`` (not called, and the wrong axis)
            y0, y1 = ax.get_ylim()
            setFormatter(ax, y0, y1, axis="y")
def setFormatter(axl, x0, x1, axis="x"):
axl = _ax_tolist(axl)
datarange = np.abs(x0 - x1)
mdata = np.ceil(np.log10(datarange))
if mdata > 0 and mdata <= 4:
majorFormatter = FormatStrFormatter("%d")
elif mdata > 4:
majorFormatter = FormatStrFormatter("%e")
elif mdata <= 0 and mdata > -1:
majorFormatter = FormatStrFormatter("%5.1f")
elif mdata < -1 and mdata > -3:
majorFormatatter = FormatStrFormatter("%6.3f")
else:
majorFormatter = FormatStrFormatter("%e")
for ax in axl:
if axis == "x":
ax.xaxis.set_major_formatter(majorFormatter)
elif axis == "y":
ax.yaxis.set_major_formatter(majorFormatter)
def update_font(axl, size=9, font=stdFont):
    """
    Set the tick-label font size on each axes.

    The current tick values are re-applied as fixed labels using a
    sans-serif, normal-weight font of the given size.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    size : float (default : 9)
        Font size in points.
    font : string
        Currently unused; the label family is 'sans-serif'.
    """
    axl = _ax_tolist(axl)
    fontProperties = {
        "family": "sans-serif",
        "weight": "normal",
        "size": size,
    }
    for ax in axl:
        if ax is None:
            continue
        for tick in ax.xaxis.get_major_ticks():
            tick.label1.set_size(size)
        for tick in ax.yaxis.get_major_ticks():
            tick.label1.set_size(size)
        # pass fontdict by keyword: it is keyword-only in newer Matplotlib
        ax.set_xticklabels(ax.get_xticks(), fontdict=fontProperties)
        ax.set_yticklabels(ax.get_yticks(), fontdict=fontProperties)
        # Axis.set_smart_bounds was removed in Matplotlib 3.4; call it only
        # where it still exists instead of raising AttributeError
        for axis_obj in (ax.xaxis, ax.yaxis):
            if hasattr(axis_obj, "set_smart_bounds"):
                axis_obj.set_smart_bounds(True)
        ax.tick_params(axis="both", labelsize=size)
def lockPlot(axl, lims, ticks=None):
    """
    Pin axes limits by plotting invisible points at the corners of *lims*.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    lims : sequence [x0, x1, y0, y1]
        Limits to impose (also applied with ax.axis(lims)).
    ticks : unused

    Returns
    -------
    list of the invisible line objects, so they can be modified later.
    """
    handles = []
    for target in _ax_tolist(axl):
        if target is None:
            continue
        corner_x = [lims[0], lims[0], lims[1], lims[1]]
        corner_y = [lims[2], lims[3], lims[2], lims[3]]
        new_lines = target.plot(
            corner_x, corner_y, color="none", marker="", linestyle="None"
        )
        handles.extend(new_lines)
        target.axis(lims)
    return handles
def adjust_spines(
    axl, spines=["left", "bottom"], direction="outward", distance=5, smart=True
):
    """
    Offset the requested spines and hide the rest.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    spines : list of strings (default : ['left', 'bottom'])
        Spines to keep. Ticks are removed from the y (x) axis when 'left'
        ('bottom') is not listed; unlisted spines are hidden.
    direction : string (default : 'outward')
        Position type passed to ``spine.set_position((direction, distance))``.
    distance : float (default : 5)
        Offset amount for the kept spines.
    smart : bool (default : True)
        Passed to ``spine.set_smart_bounds`` on Matplotlib versions that still
        provide it (it was removed in 3.4); ignored otherwise.
    """
    axl = _ax_tolist(axl)
    for ax in axl:
        if ax is None:
            continue
        # turn off ticks where there is no spine
        if "left" in spines:
            ax.yaxis.set_ticks_position("left")
        else:
            ax.yaxis.set_ticks([])  # no yaxis ticks
        if "bottom" in spines:
            ax.xaxis.set_ticks_position("bottom")
        else:
            ax.xaxis.set_ticks([])  # no xaxis ticks
        for loc, spine in ax.spines.items():
            if loc in spines:
                spine.set_position((direction, distance))  # e.g. outward by 5 points
                if hasattr(spine, "set_smart_bounds"):
                    spine.set_smart_bounds(smart is True)
            else:
                spine.set_color("none")  # don't draw spine
def getLayoutDimensions(n, pref="height"):
    """
    Choose a (rows, columns) grid with at least *n* cells.

    The search starts from floor(sqrt(n)) in both directions and grows one
    dimension per step until rows * columns >= n.

    Parameters
    ----------
    n : int
        Number of plots needed.
    pref : string (default : 'height')
        'width' grows the column count first on each step; any other value
        grows the row count first. Either way the step may be redirected to
        the other dimension.

    Returns
    -------
    (rows, columns) : tuple of int
    """
    rows = cols = int(np.sqrt(n))
    while rows * cols < n:
        if pref == "width":
            cols += 1
            if rows * cols > n - rows:
                cols, rows = cols - 1, rows + 1
        else:
            rows += 1
            if rows * cols > n - cols:
                rows, cols = rows - 1, cols + 1
    return (rows, cols)
def calbar(
    axl,
    calbar=None,
    axesoff=True,
    orient="left",
    unitNames=None,
    fontsize=11,
    weight="normal",
    font="Arial",
):
    """
    Draw a labeled L-shaped calibration (scale) bar on each axes.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    calbar : sequence [x0, y0, xlen, ylen] (default : None)
        Corner and arm lengths of the bar, in data coordinates. When None,
        only the optional axes removal (*axesoff*) is performed.
    axesoff : boolean (default : True)
        If True, call noaxes() on each axes first.
    orient : 'left' or 'right' (default : 'left')
        Side on which the vertical arm is drawn. Any other value prints a
        warning and is drawn as 'left'.
    unitNames : dict with string entries 'x' and 'y' (default : None)
        Units appended to the horizontal and vertical labels.
    fontsize : float (default : 11)
        Label font size.
    weight : string (default : 'normal')
        Label font weight.
    font : string (default : 'Arial')
        Currently unused; labels use the 'sans-serif' family.

    Returns
    -------
    Nothing.
    """
    for ax in _ax_tolist(axl):
        if ax is None:
            continue
        if axesoff is True:
            noaxes(ax)
        if calbar is None:
            # previously calbar was indexed before this check (TypeError)
            continue
        x0, y0, xlen, ylen = calbar[0], calbar[1], calbar[2], calbar[3]
        # one decimal place for lengths below 1, whole numbers otherwise;
        # the number is formatted first so unit strings are never parsed as
        # format specifications
        hlabel = ("{:.1f}" if xlen < 1.0 else "{:.0f}").format(xlen)
        vlabel = (" {:.1f}" if ylen < 1.0 else " {:.0f}").format(ylen)
        if unitNames is not None:
            vlabel += " " + "{:s}".format(unitNames["y"])
            hlabel += " " + "{:s}".format(unitNames["x"])
        if orient not in ("left", "right"):
            print("PlotHelpers.py: I did not understand orientation: %s" % (orient))
            print("plotting as if set to left... ")
        if orient == "right":  # vertical arm at the right end of the bar
            bar_x = [x0 + xlen, x0 + xlen, x0]
            vtext_x = x0 + xlen - 0.05 * xlen
            vtext_align = "right"
        else:  # vertical arm at the left end of the bar
            bar_x = [x0, x0, x0 + xlen]
            vtext_x = x0 + 0.05 * xlen
            vtext_align = "left"
        ax.plot(
            bar_x,
            [y0 + ylen, y0, y0],
            color="k",
            linestyle="-",
            linewidth=1.5,
        )
        # vertical-arm label, centered on the arm's height
        ax.text(
            vtext_x,
            y0 + 0.5 * ylen,
            vlabel,
            horizontalalignment=vtext_align,
            verticalalignment="center",
            fontsize=fontsize,
            weight=weight,
            family="sans-serif",
        )
        # horizontal-arm label, centered below the arm
        ax.text(
            x0 + xlen * 0.5,
            y0 - 0.1 * ylen,
            hlabel,
            horizontalalignment="center",
            verticalalignment="top",
            fontsize=fontsize,
            weight=weight,
            family="sans-serif",
        )
def referenceline(
    axl,
    reference=None,
    limits=None,
    color="0.33",
    linestyle="--",
    linewidth=0.5,
    dashes=None,
):
    """
    Draw a horizontal reference line at a given y value on each axes.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    reference : float (default : None -> 0.0)
        y value (data units) of the line.
    limits : list or tuple [x0, x1] (default : None)
        x extent of the line; otherwise the current x limits of each axes.
    color, linestyle, linewidth : line style passed to ax.plot.
    dashes : sequence (default : None)
        Optional dash pattern applied with Line2D.set_dashes.

    Returns
    -------
    The line object drawn on the last axes, or None if no line was drawn.
    """
    axl = _ax_tolist(axl)
    if reference is None:
        # previously misspelled, leaving reference as None
        reference = 0.0
    rl = None
    for ax in axl:
        if ax is None:
            continue
        if isinstance(limits, (list, tuple)) and len(limits) == 2:
            xlims = limits
        else:
            xlims = ax.get_xlim()
        (rl,) = ax.plot(
            [xlims[0], xlims[1]],
            [reference, reference],
            color=color,
            linestyle=linestyle,
            linewidth=linewidth,
        )
        if dashes is not None:
            rl.set_dashes(dashes)
    return rl
def crossAxes(axl, xyzero=[0.0, 0.0], limits=[None, None, None, None]):
    """
    Make plot(s) whose left/bottom spines cross at a data point.

    Parameters
    ----------
    axl : axes object, list of axes, or dict of axes
        ``None`` entries are skipped.
    xyzero : [x, y] (default : [0.0, 0.0])
        Data position of the left spine (x) and of the bottom spine (y).
        The right and top spines are hidden.
    limits : [xmin, ymin, xmax, ymax] (default : all None)
        If limits[0] is not None, x limits are set to (limits[0], limits[2])
        and y limits to (limits[1], limits[3]).
    """
    axl = _ax_tolist(axl)
    for ax in axl:
        if ax is None:
            continue
        ax.spines["left"].set_position(("data", xyzero[0]))
        ax.spines["right"].set_color("none")
        ax.spines["bottom"].set_position(("data", xyzero[1]))
        ax.spines["top"].set_color("none")
        for name in ("left", "bottom"):
            spine = ax.spines[name]
            # Spine.set_smart_bounds was removed in Matplotlib 3.4
            if hasattr(spine, "set_smart_bounds"):
                spine.set_smart_bounds(True)
        ax.xaxis.set_ticks_position("bottom")
        ax.yaxis.set_ticks_position("left")
        if limits[0] is not None:
            ax.set_xlim(left=limits[0], right=limits[2])
            ax.set_ylim(bottom=limits[1], top=limits[3])
def violin_plot(ax, data, pos, bp=False, median=False):
    """
    Draw violin plots (mirrored kernel-density profiles) on an axes.

    Parameters
    ----------
    ax : axes object
    data : sequence of 1-D samples, one per violin
    pos : sequence of x positions, one per sample
    bp : bool (default : False)
        Also overlay a black box plot of all samples.
    median : bool (default : False)
        Draw a horizontal line at each sample's median, +/- 0.5 around the
        violin position.
    """
    spread = max(pos) - min(pos)
    half_width = min(0.15 * max(spread, 1.0), 0.5)
    for sample, center in zip(data, pos):
        kde = gaussian_kde(sample)
        lowest = kde.dataset.min()
        highest = kde.dataset.max()
        support = np.arange(lowest, highest, (highest - lowest) / 100.0)
        profile = kde.evaluate(support)
        profile = profile / profile.max() * half_width  # peak equals half_width
        for edge in (profile + center, -profile + center):
            ax.fill_betweenx(support, center, edge, facecolor="y", alpha=0.3)
        if median:
            mid = np.median(sample)
            ax.plot([center - 0.5, center + 0.5], [mid, mid], "-")
    if bp:
        boxes = ax.boxplot(data, notch=0, positions=pos, vert=1)
        mpl.setp(boxes["boxes"], color="black")
        mpl.setp(boxes["whiskers"], color="black", linestyle="-")
# # from somewhere on the web:
class NiceScale:
    """
    Compute "nice" axis limits and tick spacing for a data range.

    Spans are rounded to 1, 2, 5 or 10 times a power of ten. After
    construction (and after setMinMaxPoints / setMaxTicks) the results are:

    lst : the nice span, rounded up from maxPoint - minPoint
    tickSpacing : nice distance between ticks
    niceMin, niceMax : multiples of tickSpacing enclosing [minPoint, maxPoint]

    The span maxPoint - minPoint must be non-zero (log10 of 0 is undefined).
    """

    def __init__(self, minv, maxv):
        self.maxTicks = 6
        self.tickSpacing = 0
        self.lst = 10
        self.niceMin = 0
        self.niceMax = 0
        self.minPoint = minv
        self.maxPoint = maxv
        self.calculate()

    def calculate(self):
        """Recompute lst, tickSpacing, niceMin and niceMax from the current settings."""
        self.lst = self.niceNum(self.maxPoint - self.minPoint, False)
        self.tickSpacing = self.niceNum(self.lst / (self.maxTicks - 1), True)
        self.niceMin = np.floor(self.minPoint / self.tickSpacing) * self.tickSpacing
        self.niceMax = np.ceil(self.maxPoint / self.tickSpacing) * self.tickSpacing

    def niceNum(self, lst, rround):
        """
        Return a nice number (1, 2, 5 or 10 times a power of ten) near *lst*.

        Parameters
        ----------
        lst : float
            Positive value to be rounded.
        rround : bool
            True rounds to the nearest nice number; False returns the smallest
            nice number that is >= *lst*.

        This method does not modify the instance (it used to overwrite
        self.lst with its argument).
        """
        exponent = np.floor(np.log10(lst))  # exponent of the value
        fraction = lst / np.power(10, exponent)  # mantissa in [1, 10)
        # previously tested ``self.lst`` here, so rround was ignored and the
        # value was always rounded to nearest
        if rround:
            if fraction < 1.5:
                niceFraction = 1
            elif fraction < 3:
                niceFraction = 2
            elif fraction < 7:
                niceFraction = 5
            else:
                niceFraction = 10
        else:
            if fraction <= 1:
                niceFraction = 1
            elif fraction <= 2:
                niceFraction = 2
            elif fraction <= 5:
                niceFraction = 5
            else:
                niceFraction = 10
        return niceFraction * np.power(10, exponent)

    def setMinMaxPoints(self, minPoint, maxPoint):
        """Set a new data range and recompute the scale."""
        self.minPoint = minPoint
        self.maxPoint = maxPoint
        self.calculate()

    def setMaxTicks(self, maxTicks):
        """Set the maximum number of ticks and recompute the scale."""
        self.maxTicks = maxTicks
        self.calculate()
def circles(x, y, s, c="b", ax=None, vmin=None, vmax=None, **kwargs):
    """
    Scatter circles whose radii are given in data units.

    Parameters
    ----------
    x, y : scalar or array_like, shape (n, )
        Circle centers.
    s : scalar or array_like, shape (n, )
        Circle radii in data units; a scalar is shared by all circles.
    c : color string, or sequence of values (default : 'b')
        A string is used as the single color for every circle. Anything
        else is treated as data to be color-mapped through the collection
        (using any ``cmap``/``norm`` given in kwargs).
    ax : axes object (default : None)
        Axes to draw on; mpl.gca() when None.
    vmin, vmax : scalar (default : None)
        Color limits, applied only when at least one of them is given.
    kwargs : PatchCollection properties
        e.g. alpha, edgecolors, facecolors, linewidths, linestyles, norm, cmap.

    Returns
    -------
    collection : matplotlib.collections.PatchCollection

    Examples
    --------
    a = np.arange(11)
    circles(a, a, a*0.2, c=a, alpha=0.5, edgecolor='none')

    License
    --------
    This code is under [The BSD 3-Clause License]
    (http://opensource.org/licenses/BSD-3-Clause)
    """
    target = mpl.gca() if ax is None else ax
    single_color = c if isinstance(c, str) else None
    kwargs.update(color=single_color)
    if np.isscalar(x):
        specs = [((x, y), s)]
    elif np.isscalar(s):
        specs = [((xc, yc), s) for xc, yc in zip(x, y)]
    else:
        specs = [((xc, yc), rc_) for xc, yc, rc_ in zip(x, y, s)]
    collection = PatchCollection([Circle(center, r) for center, r in specs], **kwargs)
    if single_color is None:
        # numeric c: map the values to colors after the collection exists
        collection.set_array(np.asarray(c))
        if vmin is not None or vmax is not None:
            collection.set_clim(vmin, vmax)
    target.add_collection(collection)
    target.autoscale_view()
    return collection
def rectangles(x, y, sw, sh=None, c="b", ax=None, vmin=None, vmax=None, **kwargs):
    """
    Scatter rectangles centered on (x, y) with sizes in data units.

    Parameters
    ----------
    x, y : scalar or array_like (lists accepted), shape (n, )
        Rectangle centers.
    sw : scalar or array_like, shape (n, )
        Rectangle widths in data units.
    sh : scalar or array_like, shape (n, ), optional
        Rectangle heights in data units; defaults to *sw* (squares).
        Scalars and arrays of x, y, sw, sh are broadcast against each other.
    c : color string, or sequence of values (default : 'b')
        A string is used as the single color for every rectangle. Anything
        else is treated as data to be color-mapped through the collection.
    ax : axes object (default : None)
        Axes to draw on; mpl.gca() when None.
    vmin, vmax : scalar (default : None)
        Color limits, applied only when at least one of them is given.
    kwargs : PatchCollection properties
        e.g. alpha, edgecolors, facecolors, linewidths, linestyles, norm, cmap.

    Returns
    -------
    collection : matplotlib.collections.PatchCollection

    Examples
    --------
    a = np.arange(11)
    rectangles(a, a, a*0.2, c=a, alpha=0.5, edgecolor='none')

    License
    --------
    This code is under [The BSD 3-Clause License]
    (http://opensource.org/licenses/BSD-3-Clause)
    """
    if ax is None:
        ax = mpl.gca()
    if isinstance(c, str):
        color = c
    else:
        color = None  # use cmap, norm after collection is created
    kwargs.update(color=color)
    if sh is None:
        sh = sw
    # Rectangle() takes the lower-left corner, so shift each center by half
    # of the width/height
    if np.isscalar(x) and np.isscalar(y) and np.isscalar(sw) and np.isscalar(sh):
        patches = [Rectangle((x - sw / 2.0, y - sh / 2.0), sw, sh)]
    else:
        # broadcast so lists work and widths/heights may be scalar or per-point
        xs, ys, ws, hs = (
            a.ravel()
            for a in np.broadcast_arrays(
                *[np.asarray(v, dtype=float) for v in (x, y, sw, sh)]
            )
        )
        patches = [
            Rectangle((x_ - w_ / 2.0, y_ - h_ / 2.0), w_, h_)
            for x_, y_, w_, h_ in zip(xs, ys, ws, hs)
        ]
    collection = PatchCollection(patches, **kwargs)
    if color is None:
        collection.set_array(np.asarray(c))
        if vmin is not None or vmax is not None:
            collection.set_clim(vmin, vmax)
    ax.add_collection(collection)
    ax.autoscale_view()
    return collection
def show_figure_grid(fig, figx=10.0, figy=10.0):
    """
    Create a background grid with major and minor lines, like graph paper.

    The grid axes covers the whole figure ([0, 1, 0, 1] in figure
    coordinates) with data limits 0..figx and 0..figy, so with figx/figy set
    to the figure size in inches (or cm) the grid is in those units. It is
    useful when placing labels at absolute locations on the figure.

    Parameters
    ----------
    fig : Matplotlib figure handle (no default)
        The figure to which the grid will be applied.
    figx : float (default: 10.)
        Extent of the grid along X; major lines every 1 unit, minor every 0.1.
    figy : float (default: 10.)
        Extent of the grid along Y; major lines every 1 unit, minor every 0.1.

    Returns
    -------
    The background-grid axes.
    """
    backGrid = fig.add_axes([0, 0, 1, 1], frameon=False)
    backGrid.set_ylim(0.0, figy)
    backGrid.set_xlim(0.0, figx)
    backGrid.grid(True)
    backGrid.set_yticks(np.arange(0.0, figy + 0.01, 1.0))
    backGrid.set_yticks(np.arange(0.0, figy + 0.01, 0.1), minor=True)
    backGrid.set_xticks(np.arange(0.0, figx + 0.01, 1.0))
    backGrid.set_xticks(np.arange(0.0, figx + 0.01, 0.1), minor=True)
    # The on/off flag is passed positionally: its keyword was renamed from
    # ``b`` to ``visible`` in Matplotlib 3.5 and the old spelling was dropped.
    backGrid.grid(True, which="major", color="g", alpha=0.6, linewidth=0.8)
    backGrid.grid(True, which="minor", color="g", alpha=0.4, linewidth=0.2)
    return backGrid
def hide_figure_grid(fig, grid):
    """
    Turn off the gridlines of a background-grid axes.

    Parameters
    ----------
    fig : figure
        Not used; accepted for symmetry with the other figure-grid helpers.
    grid : axes object
        The grid axes whose gridlines are switched off.
    """
    grid_axes = grid
    grid_axes.grid(False)
    return None
def delete_figure_grid(fig, grid):
    """
    Remove a background-grid axes from its figure.

    Parameters
    ----------
    fig : matplotlib figure
        Figure that owns the grid axes.
    grid : axes object
        The grid axes to remove.
    """
    # ``mpl.delete(grid)`` is not a pyplot axes-removal API; Figure.delaxes
    # removes the axes from the figure.
    fig.delaxes(grid)
class Plotter:
"""
The Plotter class provides a simple convenience for plotting data in
an row x column array.
"""
def __init__(
    self,
    rcshape=None,
    axmap=None,
    arrangement=None,
    title=None,
    label=False,
    roworder=True,
    refline=None,
    figsize=(11, 8.5),
    fontsize=10,
    position=0,
    labeloffset=[0.0, 0.0],
    labelsize=12,
):
    """
    Create an instance of the plotter. Generates a new matplotlib figure,
    sets up an array of subplots as defined, and initializes the counters.

    Examples
    --------
    Ex. 1:
    One way to generate plots on a standard grid, uses gridspec to specify an axis map:
    labels = ['A', 'B1', 'B2', 'C1', 'C2', 'D', 'E', 'F']
    gr = [(0, 4, 0, 1), (0, 3, 1, 2), (3, 4, 1, 2), (0, 3, 2, 3), (3, 4, 2, 3), (5, 8, 0, 1), (5, 8, 1, 2), (5, 8, 2, 3)]
    axmap = OrderedDict(zip(labels, gr))
    P = PH.Plotter((8, 1), axmap=axmap, label=True, figsize=(8., 6.))
    PH.show_figure_grid(P.figure_handle)

    Ex. 2:
    Place plots on defined locations on the page - no messing with gridspec or subplots.
    For this version, we just generate N subplots with labels (used to tag each plot).
    The "sizer" array then maps the tags to specific panel locations.
    # define positions for each panel in figure coordinates (0, 1, 0, 1)
    # you don't have to use an ordered dict for this, it is just easier when debugging
    sizer = OrderedDict([('A', [0.08, 0.22, 0.55, 0.4]), ('B1', [0.40, 0.25, 0.65, 0.3]), ('B2', [0.40, 0.25, 0.5, 0.1]),
            ('C1', [0.72, 0.25, 0.65, 0.3]), ('C2', [0.72, 0.25, 0.5, 0.1]),
            ('D', [0.08, 0.25, 0.1, 0.3]), ('E', [0.40, 0.25, 0.1, 0.3]), ('F', [0.72, 0.25, 0.1, 0.3]),
    ]) # dict elements are [left, width, bottom, height] for the axes in the plot.
    gr = [(a, a+1, 0, 1) for a in range(0, 8)] # just generate subplots - shape does not matter
    axmap = OrderedDict(zip(sizer.keys(), gr))
    P = PH.Plotter((8, 1), axmap=axmap, label=True, figsize=(8., 6.))
    PH.show_figure_grid(P.figure_handle)
    P.resize(sizer) # perform positioning magic
    P.axdict['B1'] access the plot associated with panel B1

    Parameters
    ----------
    rcshape : list/tuple [rows, cols], or dict (no default)
        list/tuple: build a rows x cols grid of subplots.
        dict (or OrderedDict): build one subplot per key (stacked N x 1),
        label each panel with its key, and pass the dict to self.resize().
        Any other value (including None) raises ValueError.
    axmap : dict or list (default : None)
        dict: its keys become the panel names in self.axdict, assigned to the
        already-built subplots in row-major order.
        NOTE(review): a grid is always built from rcshape before axmap is
        examined, so the gridspec slices in axmap are never used here and a
        list axmap falls through to the TypeError below -- confirm intended.
    arrangement : dict (default: None)
        Maps column names ("groups") to lists of entity names; used by
        getRC()/getaxis(group) to look up an axes by name.
    title : string (default: None)
        Window title and figure suptitle.
    label : Boolean (default: False)
        If True, sets labels on panels.
    roworder : Boolean (default: True)
        Whether self.axlist (and so the sequential panel labels) runs in row
        order first or column order first.
    refline : float (default: None)
        y position of a reference line drawn in every panel.
    figsize : tuple (default : (11, 8.5))
        Figure size in inches. Default is for a landscape figure.
    fontsize : number, or dict with keys 'tick', 'label', 'panel' (default : 10)
        A number is used for all three sizes.
    position : number or dict (default : 0)
        Spine position passed to nice_plot() for every panel.
    labeloffset : number, dict, or [x, y] (default : [0.0, 0.0])
        Offset added to the panel-label position (axes-fraction units). A
        number is used for x and y; a dict supplies 'left' (x) and 'bottom' (y).
    labelsize : int (default : 12)
        Not used by __init__.

    Returns
    -------
    Nothing
    """
    self.arrangement = arrangement
    self.fontsize = fontsize
    self.referenceLines = {}
    self.figure_handle = mpl.figure(figsize=figsize)  # create the figure
    self.figure_handle.set_size_inches(figsize[0], figsize[1], forward=True)
    self.axlabels = []
    # axis label dictionary for indirect access by panel name (better!)
    self.axdict = OrderedDict()
    if isinstance(fontsize, (int, float)):
        fontsize = {"tick": fontsize, "label": fontsize, "panel": fontsize}
    gridbuilt = False
    # compute label offsets [x, y]
    if isinstance(labeloffset, (int, float)):
        p = [labeloffset, labeloffset]
    elif isinstance(labeloffset, dict):
        # previously read the *position* argument here by mistake
        p = [labeloffset["left"], labeloffset["bottom"]]
    elif isinstance(labeloffset, (list, tuple)):
        p = labeloffset
    else:
        p = [0.0, 0.0]

    # build axes arrays
    if isinstance(rcshape, (list, tuple)):
        # 1. an nxm grid
        rc = rcshape
        gs = gridspec.GridSpec(rc[0], rc[1])  # define a grid using gridspec
        # a numpy object array gives 2-D indexing of the axes
        self.axarr = np.empty(shape=(rc[0], rc[1]), dtype=object)
        ix = 0
        for r in range(rc[0]):
            for c in range(rc[1]):
                self.axarr[r, c] = mpl.subplot(gs[ix])
                ix += 1
        gridbuilt = True
    elif isinstance(rcshape, dict):  # true for OrderedDict also
        # 2. Nx1 subplots, then reorganized by resize() according to shape boxes
        nplots = len(rcshape)
        gs = gridspec.GridSpec(nplots, 1)
        rc = (nplots, 1)
        self.axarr = np.empty(shape=(rc[0], rc[1]), dtype=object)
        for r in range(nplots):
            self.axarr[r, 0] = mpl.subplot(gs[r])
        gridbuilt = True
        for k, pk in enumerate(rcshape.keys()):
            self.axdict[pk] = self.axarr[k, 0]
        # use the normalized offsets so an int/dict labeloffset works here too
        self.axlabels = labelPanels(
            self.axarr.tolist(),
            axlist=list(rcshape.keys()),
            xy=(-0.095 + p[0], 0.95 + p[1]),
            fontsize=fontsize["panel"],
        )
        self.resize(rcshape)
    else:
        raise ValueError("Input rcshape must be list/tuple or dict")

    # create subplots
    if axmap is not None:
        if isinstance(axmap, list) and not gridbuilt:
            self.axarr = np.empty(shape=(len(axmap), 1), dtype=object)
            for k, g in enumerate(axmap):
                self.axarr[k,] = mpl.subplot(gs[g[0] : g[1], g[2] : g[3]])
        elif isinstance(axmap, dict):  # keys are panel labels
            if not gridbuilt:
                self.axarr = np.empty(shape=(len(axmap), 1), dtype=object)
            for k, pk in enumerate(axmap.keys()):
                g = axmap[pk]  # get the gridspec info
                if not gridbuilt:
                    self.axarr[k,] = mpl.subplot(gs[g[0] : g[1], g[2] : g[3]])
                self.axdict[pk] = self.axarr.ravel()[k]
        else:
            raise TypeError("Plotter in PlotHelpers: axmap must be a list or dict")

    if len(self.axdict) == 0:
        # Default panel names 'A', 'B', ...; beyond 26 axes use column
        # letter + row number ('A1', 'B1', ...), as the panel labels do.
        # (The loop variable used to be called ``label`` and clobbered the
        # argument, so panels were labeled even with label=False.)
        flat_axes = self.axarr.flatten()
        ncols_for_names = self.axarr.shape[1] if self.axarr.ndim > 1 else 1
        for i, a in enumerate(flat_axes):
            if len(flat_axes) <= 26:
                panel_name = string.ascii_uppercase[i]
            else:
                panel_name = string.ascii_uppercase[i % ncols_for_names] + str(
                    i // ncols_for_names + 1
                )
            self.axdict[panel_name] = a

    if title is not None:
        # FigureCanvasBase.set_window_title was removed in newer Matplotlib;
        # the figure manager provides the same method.
        canvas = self.figure_handle.canvas
        manager = getattr(canvas, "manager", None)
        if manager is not None:
            manager.set_window_title(title)
        elif hasattr(canvas, "set_window_title"):
            canvas.set_window_title(title)
        self.figure_handle.suptitle(title)

    self.nrows = self.axarr.shape[0]
    self.ncolumns = self.axarr.shape[1] if len(self.axarr.shape) > 1 else 1
    self.row_counter = 0
    self.column_counter = 0
    for i in range(self.nrows):
        for j in range(self.ncolumns):
            ax = self.axarr[i, j]
            ax.spines["top"].set_visible(False)
            ax.get_xaxis().set_tick_params(direction="out", width=0.8, length=4.0)
            ax.get_yaxis().set_tick_params(direction="out", width=0.8, length=4.0)
            ax.tick_params(axis="both", which="major", labelsize=fontsize["tick"])
            nice_plot(ax, position=position)
            if refline is not None:
                self.referenceLines[ax] = referenceline(ax, reference=refline)

    # ordered list of the axes, used for sequential panel labeling
    self.axlist = []
    if roworder:
        for i in range(self.nrows):
            for j in range(self.ncolumns):
                self.axlist.append(self.axarr[i, j])
    else:
        for i in range(self.ncolumns):
            for j in range(self.nrows):
                self.axlist.append(self.axarr[j, i])

    if not label:
        return
    if isinstance(axmap, dict):  # in case predefined...
        self.axlabels = labelPanels(
            self.axarr.ravel().tolist(),
            axlist=list(axmap.keys()),
            xy=(-0.095 + p[0], 0.95 + p[1]),
            fontsize=fontsize["panel"],
        )
        return
    if self.nrows * self.ncolumns > 26:  # handle large plot using "A1..."
        ctxt = string.ascii_uppercase[0 : self.ncolumns]  # columns are lettered
        rtxt = [str(x + 1) for x in range(self.nrows)]  # rows are numbered, starting at 1
        axl = []
        for i in range(self.nrows):
            for j in range(self.ncolumns):
                axl.append(ctxt[j] + rtxt[i])
        self.axlabels = labelPanels(self.axlist, axlist=axl, xy=(-0.35 + p[0], 0.75))
    else:
        self.axlabels = labelPanels(self.axlist, xy=(-0.095 + p[0], 0.95 + p[1]))
def _next(self):
"""
Private function
_next gets the axis pointer to the next row, column index that is available
Only sets internal variables
"""
self.column_counter += 1
if self.column_counter >= self.ncolumns:
self.row_counter += 1
self.column_counter = 0
if self.row_counter >= self.nrows:
raise ValueError(
"Call to get next row exceeds the number of rows requested initially: %d"
% self.nrows
)
def getaxis(self, group=None):
    """
    Return the next available axis, or the axis for a named group.

    Without ``group``, axes are handed out in row-major order starting from
    (``row_counter``, ``column_counter``). Each call advances that cursor by
    one position, wrapping to column 0 of the next row at the end of a row,
    so the next call returns the following axis.

    Parameters
    ----------
    group : string (default: None)
        if given, return the axis associated with this "group" name (looked
        up with ``getRC``) instead; the row/column cursor is not changed.

    Returns
    -------
    the current axis or the axis associated with a group

    Raises
    ------
    ValueError
        when called without ``group`` after every axis in the grid has
        already been handed out.
    """
    if group is not None:
        return self.getRC(group)
    # Check for exhaustion *before* fetching, so that the last axis in the
    # grid is still returned and only the call after it raises.
    if self.row_counter >= self.nrows:
        raise ValueError(
            "Call to get next row exceeds the number of rows requested initially: %d"
            % self.nrows
        )
    currentaxis = self.axarr[self.row_counter, self.column_counter]
    # advance the cursor for the next call, wrapping at the end of each row
    self.column_counter += 1
    if self.column_counter >= self.ncolumns:
        self.column_counter = 0
        self.row_counter += 1
    return currentaxis
def getRC(self, group):
    """
    Get the axis associated with a group.

    ``self.arrangement`` maps each column name to a sequence of group names.
    The column's position in the arrangement gives the column index, and the
    group's position within that column's sequence gives the row index.

    Parameters
    ----------
    group : string
        text name of a "group" in the arrangement

    Returns
    -------
    The matplotlib axis associated with the group name, or None if no group by
    that name exists in the arrangement (a message is printed in that case).

    Raises
    ------
    ValueError
        if no arrangement dictionary was supplied (``self.arrangement`` is None).
    """
    if self.arrangement is None:
        raise ValueError("specifying a group requires an arrangment dictionary")
    for col, members in enumerate(self.arrangement.values()):
        if group in members:
            # row = position of the group within this column's list
            return self.axarr[members.index(group), col]
    # "{!s}" instead of "{:s}": the latter raises for non-str values, which
    # would replace the documented "return None" with an exception.
    print("Group {!s} not in the arrangement".format(group))
    return None
# Example panel layout in the format accepted by ``resize``:
# panel label -> {"pos": [left, width, bottom, height]} in figure fractions.
sizer = {
    label: {"pos": pos}
    for label, pos in (
        ("A", [0.08, 0.22, 0.50, 0.4]),
        ("B1", [0.40, 0.25, 0.60, 0.3]),
        ("B2", [0.40, 0.25, 0.5, 0.1]),
        ("C1", [0.72, 0.25, 0.60, 0.3]),
        ("C2", [0.72, 0.25, 0.5, 0.1]),
        ("D", [0.08, 0.25, 0.1, 0.3]),
        ("E", [0.40, 0.25, 0.1, 0.3]),
        ("F", [0.72, 0.25, 0.1, 0.3]),
    )
}
def resize(self, sizer):
    """
    Reposition panels and, optionally, move their labels or apply ``noaxes``.

    Parameters
    ----------
    sizer : dict (no default)
        Maps a panel label (a key of ``self.axdict``) to a dict of options:

        "pos"      : [left, width, bottom, height] for the panel, in figure
                     fractions ([0, 1] in both directions). Required.
        "labelpos" : (x, y) passed to the panel label's ``set_x``/``set_y``;
                     ignored unless it has exactly two entries.
        "noaxes"   : when equal to True, the module's ``noaxes`` helper is
                     called on the panel's axis.

        Example::

            sizer = {'A': {'pos': [0.08, 0.22, 0.50, 0.4], 'labelpos': (x, y), 'noaxes': True},
                     'B1': {'pos': [0.40, 0.25, 0.60, 0.3], 'labelpos': (x, y)},
                     'B2': {'pos': [0.40, 0.25, 0.5, 0.1], 'noaxes': False},
                     'D': {'pos': [0.08, 0.25, 0.1, 0.3]}}

    Returns
    -------
    Nothing
    """
    for idx, name in enumerate(sizer):
        spec = sizer[name]
        ax = self.axdict[name]
        # get_position() hands back a bounding box; overwrite its corners
        # from (left, width, bottom, height) and push it back onto the axis.
        box = ax.get_position()
        pos = spec["pos"]
        left, width, bottom, height = pos[0], pos[1], pos[2], pos[3]
        box.x0 = left
        box.x1 = width + left
        box.y0 = bottom
        box.y1 = height + bottom
        ax.set_position(box)
        # NOTE(review): the label is chosen by the key's position in
        # ``sizer``, not by the panel itself. This only moves the right label
        # when ``sizer`` lists panels in the same order as ``self.axlabels`` --
        # TODO confirm against how axlabels is built.
        if "labelpos" in spec and len(spec["labelpos"]) == 2:
            label_x, label_y = spec["labelpos"]
            panel_label = self.axlabels[idx]
            panel_label.set_x(label_x)
            panel_label.set_y(label_y)
        # explicit "== True" kept on purpose: only values equal to True trigger it
        if "noaxes" in spec and spec["noaxes"] == True:
            noaxes(ax)
if __name__ == "__main__":
    # Demo figure on a 6 x 2 grid. axmap tuples are presumably
    # (row_start, row_stop, col_start, col_stop) -- TODO confirm in Plotter.
    panel_names = list("ABCDEFGHI")
    # three 2-row panels in column 0, then six 1-row panels in column 1
    left_column = [(row, row + 2, 0, 1) for row in range(0, 6, 2)]
    right_column = [(row, row + 1, 1, 2) for row in range(6)]
    axmap = OrderedDict(zip(panel_names, left_column + right_column))
    P = Plotter((6, 2), axmap=axmap, figsize=(6.0, 6.0), label=True)

    # name every axis "A", "B", ... in flattened order
    axd = OrderedDict(
        (string.ascii_uppercase[idx], ax)
        for idx, ax in enumerate(P.axarr.flatten())
    )
    for ax in axd.values():
        ax.plot(np.random.random(10), np.random.random(10))

    nice_plot(list(axd.values()), position=-0.1)
    cleanAxes([axd["B"], axd["C"]])
    calbar([axd["B"], axd["C"]], calbar=[0.5, 0.5, 0.2, 0.2])
    mpl.show()