copying to personal repo

This commit is contained in:
Alan
2022-06-19 13:45:53 -05:00
commit bf2ffa7315
287 changed files with 54032 additions and 0 deletions

69
cnmodel/util/Params.py Executable file
View File

@@ -0,0 +1,69 @@
# -*- encoding: utf-8 -*-
import sys
import os
import unittest
class Params(object):
def __init__(self, **kwds):
"""
Utility class to create parameter lists
create using:
p = Params(abc=2.0, defg = 3.0, lunch='sandwich')
reference using:
p.abc, p.defg, etc.
Supports getting the keys, finding whether a key exists, returning the strucure as a simple dictionary,
and printing (show) the parameter structure.
"""
self.__dict__.update(kwds)
def additem(self, key, value):
self.__dict__[key] = value
def getkeys(self):
"""
Get the keys in the current dictionary
"""
return self.__dict__.keys()
def haskey(self, key):
"""
Find out if the param list has a specific key in it
"""
if key in self.__dict__.keys():
return True
else:
return False
def todict(self):
"""
convert param list to standard dictionary
Useful when writing the data
"""
r = {}
for dictelement in self.__dict__:
if isinstance(self.__dict__[dictelement], Params):
# print 'nested: ', dictelement
r[dictelement] = self.__dict__[dictelement].todict()
else:
r[dictelement] = self.__dict__[dictelement]
return r
def show(self, printFlag=True):
"""
print the parameter block created in Parameter Init
"""
print("-------- Parameter Block ----------")
for key in self.__dict__.keys():
print("%15s = " % (key), eval("self.%s" % key))
print("-------- ---------------------- ----------")
class ParamTests(unittest.TestCase):
def setUp(self):
pass
if __name__ == "__main__":
unittest.main()

1415
cnmodel/util/PlotHelpers.py Executable file

File diff suppressed because it is too large Load Diff

9
cnmodel/util/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
from .stim import *
from .find_point import *
from .pynrnutilities import *
from .nrnutils import *
from .expfitting import *
from .user_tester import UserTester
from .get_anspikes import *
from .Params import *
from .talbotetalTicks import Extended

125
cnmodel/util/ccstim.py Normal file
View File

@@ -0,0 +1,125 @@
__author__ = "pbmanis"
"""
ccstim
Generate current-clamp (or voltage-clamp) stimulus waveforms from a dictionary
used for vectory play modes in current clamp
(prior version was called 'makestim')
Can generate several types of pulses
"""
import numpy as np
def ccstim(stim, dt, pulsetype="square"):
"""
Create stimulus pulse waveforms of different types.
Parameters
----------
stim : dict
a dictionary with keys [required]
delay (delay to start of pulse train, msec [all]
duration: duration of pulses in train, msec [all]
Sfreq: stimulus train frequency (Hz) [timedSpikes]
PT: post-train test delay [all]
NP: number of pulses in the train [timedSpikes, exp]
amp: amplitude of the pulses in the train [all]
hypamp: amplitude of prehyperpolarizing pulse [hyp]
hypdur: duration of prehyperpolarizing pulse [hyp]
spikeTimes" times of spikes [timedSpikes]
dt : time (microseconds) [required]
step time, in microseconds. Required parameter
pulsetype : string (default: 'square')
Type of pulse to generate: one of square, hyp, timedspikes or exp
square produces a train of "square" (retangular) pulses levels 0 and ampitude
hyp is like square, but precedes the pulse train with a single prepulse
of hypamp and hypdur
timedspikes is like square, excpet the pulses are generated at times specified
in the spikeTimes key in the stim dictionary
exp: pulses with an exponential decay.
TO DO:
add pulsetypes, including sine wave, rectified sine wave, etc.
Returns
-------
list containing [waveform (numpy array),
maxtime(float),
timebase (numpy array)]
"""
assert dt is not None
assert "delay" in stim.keys()
delay = int(np.floor(stim["delay"] / dt))
if pulsetype in ["square", "hyp", "exp"]:
ipi = int(np.floor((1000.0 / stim["Sfreq"]) / dt))
pdur = int(np.floor(stim["duration"] / dt))
posttest = int(np.floor(stim["PT"] / dt))
if pulsetype not in ["timedSpikes"]:
NP = int(stim["NP"])
else:
NP = len(stim["spikeTimes"])
tstims = [0] * NP
if pulsetype == "hyp":
assert "hypamp" in stim.keys()
assert "hypdur" in stim.keys()
hypdur = int(np.floor(stim["hypdur"] / dt))
delay0 = delay # save original delay
if pulsetype in ["square", "hyp"]:
maxt = dt * (stim["delay"] + (ipi * (NP + 2)) + posttest + pdur * 2)
if pulsetype == "hyp":
maxt = maxt + dt * stim["hypdur"]
delay = delay + hypdur
w = np.zeros(int(np.floor(maxt / dt)))
for j in range(0, NP):
t = (delay + j * ipi) * dt
w[delay + ipi * j : delay + (ipi * j) + pdur] = stim["amp"]
tstims[j] = delay + ipi * j
if stim["PT"] > 0.0:
send = delay + ipi * j
for i in range(send + posttest, send + posttest + pdur):
w[i] = stim["amp"]
if pulsetype == "hyp": # fill in the prepulse now
for i in range(delay0, delay0 + hypdur):
w[i] = stim["hypamp"]
if pulsetype == "timedSpikes":
maxt = np.max(stim["spikeTimes"]) + stim["PT"] + stim["duration"] * 2
w = np.zeros(int(np.floor(maxt / dt)))
for j in range(len(stim["spikeTimes"])):
st = delay + int(np.floor(stim["spikeTimes"][j] / dt))
t = st * dt
w[st : st + pdur] = stim["amp"]
tstims[j] = st
if stim["PT"] > 0.0:
for i in range(st + posttest, st + posttest + pdur):
w[i] = stim["amp"]
if pulsetype == "exp":
maxt = dt * (stim["delay"] + (ipi * (NP + 2)) + posttest + pdur * 2)
w = np.zeros(int(np.floor(maxt / dt)))
for j in range(0, NP):
for i in range(0, len(w)):
if delay + ipi * j + i < len(w):
w[delay + ipi * j + i] = w[delay + ipi * j + i] + stim["amp"] * (
(1.0 - np.exp(-i / (pdur / 3.0)))
* np.exp(-(i - (pdur / 3.0)) / pdur)
)
tstims[j] = delay + ipi * j
if stim["PT"] > 0.0:
send = delay + ipi * j
for i in range(send + posttest, len(w)):
w[i] += (
stim["amp"]
* (1.0 - np.exp(-i / (pdur / 3.0)))
* np.exp(-(i - (pdur / 3.0)) / pdur)
)
return (w, maxt, tstims)

View File

@@ -0,0 +1,658 @@
"""
Test synaptic connections between two different cell types.
Usage: python compare_simple_multisynapses.py <pre_celltype> <post_celltype>
This script:
1. Creates single pre- and postsynaptic cells
2. Creates a single synaptic terminal between the two cells, using the multisite synapse method.
3. Stimulates the presynaptic cell by current injection.
4. Records and analyzes the resulting post-synaptic events.
5. Repeats 3, 4 100 times to get an average postsynaptic event.
6. stores the resulting waveform in a pandas database
This is used mainly to check that the strength, kinetics, and dynamics of
each synapse type is working as expected.
Requires Python 3.6
"""
import sys
import argparse
from pathlib import Path
import numpy as np
# import pyqtgraph as pg
import matplotlib.pyplot as mpl
import neuron as h
import cnmodel.util.PlotHelpers as PH
from cnmodel.protocols import SynapseTest
from cnmodel import cells
from cnmodel.synapses import Synapse
import pickle
import lmfit
convergence = {
"sgc": {
"bushy": 1,
"tstellate": 1,
"dstellate": 1,
"octopus": 1,
"pyramidal": 1,
"tuberculoventral": 1,
},
"dstellate": {
"bushy": 1,
"tstellate": 1,
"dstellate": 1,
"pyramidal": 1,
"tuberculoventral": 1,
},
"tuberculoventral": {
"bushy": 1,
"tstellate": 1,
"pyramidal": 1,
"tuberculoventral": 1,
},
}
class Exp2SynFitting:
"""
Fit waveform against Exp2SYN function (used by Neuron)
THe function is (from the documentation):
i = G * (v - e) i(nanoamps), g(micromhos);
G = weight * factor * (exp(-t/tau2) - exp(-t/tau1))
WHere factor is evaluated when initializing(in neuron; see exp2syn.mod source) as:
tp = (tau1*tau2)/(tau2 - tau1) * log(tau2/tau1) (log10)
factor = -exp(-tp/tau1) + exp(-tp/tau2)
factor = 1/factor
Parameters
----------
initpars : dict
dict of initial parameters. For example: {'tau1': 0.1,
'tau2': 0.3, 'weight': 0.1, 'delay' : 0.0, 'erev': -80.} (erev in mV)
bounds : dict
dictionary of bounds for each parameter, with a list of lower and upper values.
"""
def __init__(self, initpars=None, bounds=None, functype="exp2syn"):
self.fitpars = lmfit.Parameters()
if functype == "exp2syn": # future handle other functions like alpha
# (Name, Value, Vary, Min, Max, Expr)
self.fitpars.add_many(
("tau1", initpars["tau1"], True, 0.05, 25.0, None),
# ('tau2', initpars['tau2'], True, 0.1, 50., None),
("tauratio", initpars["tauratio"], True, 1.0001, 100.0, None),
("weight", initpars["weight"], True, 1e-6, 1, None),
("erev", initpars["erev"], False), # do not adjust!
("v", initpars["v"], False),
("delay", initpars["delay"], True, 0.0, 5.0, None),
)
self.func = self.exp2syn_err
else:
raise ValueError
def fit(self, x, y, p, verbose=False):
"""
Perform the curve fit against the specified function
Parameters
----------
x : time base for waveform (np array or list)
y : waveform (1-d np array or list)
p : parameters in lmfit Parameters structure
verbose : boolean (default: False)
If true, print the parameters in a nice format
"""
kws = {"maxfev": 5000}
# print('p: ', p)
self.mim = lmfit.minimize(
self.func, p, method="least_squares", args=(x, y)
) # , kws=kws)
if verbose:
lmfit.printfuncs.report_fit(self.mim.params)
fitpars = self.mim.params
return fitpars
# @staticmethod
def exp2syn(self, x, tau1, tauratio, weight, erev, v, delay):
"""
Compute the exp2syn waveform current as it is done in Neuron
Note on units:
The units are assumed to be "self consistent"
Thus, if the time base is in msec, tau1, tau2 and the delay are
in msec. erev and v should be in matching units (e.g., mV)
Note thtat the weight is not equal to the conducdtance G, because
of the scaling of the waveform
Parameters
----------
x : time base
tau1 : rising tau for waveform
tau2 : falling tau for waveform
weight : amplitude of the waveform (conductance)
erev : reversal potential for ionic species (used to compute i)
v : holding voltage at which waveform is computed
delay : delay to start of function
Returns
-------
i, the calucated current trace for these parameters.
"""
# we handle the requirement that tau2 > tau1 by setting the
# expression in lmfit, and using tauratio rather than a direct
# tau2.
# if tau1/tau2 > 1.0: # make sure tau1 is less than tau2
# tau1 = 0.999*tau2
tau2 = tauratio * tau1
tp = (tau1 * tau2) / (tau2 - tau1) * np.log(tau2 / tau1)
factor = -np.exp(-tp / tau1) + np.exp(-tp / tau2)
factor = 1.0 / factor
G = (
weight
* factor
* (np.exp(-(x - delay) / tau2) - np.exp(-(x - delay) / tau1))
)
G[x - delay < 0] = 0.0
i = G * (v - erev) # i(nanoamps), g(micromhos);
return i
def exp2syn_err(self, p, x, y):
return np.fabs(
y - self.exp2syn(x, **dict([(k, p.value) for k, p in p.items()]))
)
def factor(self, tau1, tauratio, weight):
"""
calculate tau-scaled weight
"""
tau2 = tau1 * tauratio
tp = (tau1 * tau2) / (tau2 - tau1) * np.log(tau2 / tau1)
factor = -np.exp(-tp / tau1) + np.exp(-tp / tau2)
factor = 1.0 / factor
G = weight * factor
return G
def testexp():
"""
Test the exp2syn fitting function
"""
pars = {
"tau1": 0.1,
"tauratio": 2.0,
"weight": 0.1,
"erev": -70.0,
"v": -65.0,
"delay": 1,
}
F = Exp2SynFitting(
initpars={
"tau1": 0.2,
"tauratio": 2.0,
"weight": 0.1,
"erev": -70.0,
"v": -65.0,
"delay": 0,
}
)
t = np.arange(0, 10.0, 0.01)
p = F.fitpars
target = F.exp2syn(
t,
pars["tau1"],
pars["tauratio"],
pars["weight"],
pars["erev"],
pars["v"],
pars["delay"],
)
# print(F.fitpars)
pars_fit = F.fit(t, target, F.fitpars)
print("\nTest fit result: ")
lmfit.printfuncs.report_fit(F.mim.params)
print("( tau2 = ", pars_fit["tau1"].value * pars_fit["tauratio"].value, ")")
print("target parameters: ", pars)
print("\n")
def compute_psc(synapsetype="multisite", celltypes=["sgc", "tstellate"]):
"""
Compute the PSC between the specified two cell tpes
The type of PSC is set by synspase type and must be 'multisite' or 'simple'
"""
assert synapsetype in ["multisite", "simple"]
c = []
for cellType in celltypes:
if cellType == "sgc":
cell = cells.SGC.create()
elif cellType == "tstellate":
cell = cells.TStellate.create(debug=True, ttx=False)
elif (
cellType == "dstellate"
): # Type I-II Rothman model, similiar excitability (Xie/Manis, unpublished)
cell = cells.DStellate.create(model="RM03", debug=True, ttx=False)
# elif cellType == 'dstellate_eager': # From Eager et al.
# cell = cells.DStellate.create(model='Eager', debug=True, ttx=False)
elif cellType == "bushy":
cell = cells.Bushy.create(debug=True, ttx=True)
elif cellType == "octopus":
cell = cells.Octopus.create(debug=True, ttx=True)
elif cellType == "tuberculoventral":
cell = cells.Tuberculoventral.create(debug=True, ttx=True)
elif cellType == "pyramidal":
cell = cells.Pyramidal.create(debug=True, ttx=True)
elif cellType == "cartwheel":
cell = cells.Cartwheel.create(debug=True, ttx=True)
else:
raise ValueError("Unknown cell type '%s'" % cellType)
c.append(cell)
preCell, postCell = c
print(f"Computing psc for connection {celltypes[0]:s} -> {celltypes[1]:s}")
nTerminals = convergence.get(celltypes[0], {}).get(celltypes[1], None)
if nTerminals is None:
nTerminals = 1
print(
f"Warning: Unknown convergence for {celltypes[0]:s} -> {celltypes[1]:s}, ASSUMING {nTerminals:d} terminals"
)
if celltypes == ["sgc", "bushy"]:
niter = 50
else:
niter = 200
assert synapsetype in ["simple", "multisite"]
if synapsetype == "simple":
niter = 1
st = SynapseTest()
dt = 0.010
stim = {
"NP": 1,
"Sfreq": 100.0,
"delay": 0.0,
"dur": 0.5,
"amp": 10.0,
"PT": 0.0,
"dt": dt,
}
st.run(
preCell.soma,
postCell.soma,
nTerminals,
dt=dt,
vclamp=-65.0,
iterations=niter,
synapsetype=synapsetype,
tstop=50.0,
stim_params=stim,
)
# st.show_result() # pyqtgraph plotting -
# st.plots['VPre'].setYRange(-70., 10.)
# st.plots['EPSC'].setYRange(-2.0, 0.5)
# st.plots['latency2080'].setYRange(0., 1.0)
# st.plots['halfwidth'].setYRange(0., 1.0)
# st.plots['RT'].setYRange(0., 0.2)
# st.plots['latency'].setYRange(0., 1.0)
# st.plots['latency_distribution'].setYRange(0., 1.0)
return st # need to keep st alive in memory
def fit_one(st, stk):
"""
Fit one trace to the exp2syn function
Parameters
----------
st : dict (no default)
A dictionary containing (at least) the following keys:
't' : the time base for the trace to be fit
'i' : a list of arrays or a 2d numpy array of the data to be fit
Returns
-------
pars : The fitting parameters
The fitted parameters are returned as an lmfit Parameters object, so access
individual parameters a pars['parname'].value
fitted : The fitted waveform, of the same length as the source data waveform
"""
if stk[0] in ["sgc", "tstellate", "granule"]:
erev = (
0.0
) # set erev according to the source cell (excitatory: 0, inhibitory: -80)
else:
erev = -70.0 # value used in gly_psd
print("\nstk, erev: ", stk, erev)
F = Exp2SynFitting(
initpars={
"tau1": 1.0,
"tauratio": 5.0,
"weight": 0.0001,
"erev": erev,
"v": -65.0,
"delay": 0,
}
)
t = st["t"]
p = F.fitpars
target = np.mean(np.array(st["i"]), axis=0)
# print(F.fitpars)
pars = F.fit(t, target, F.fitpars)
gw = F.factor(pars["tau1"].value, pars["tauratio"].value, pars["weight"].value)
pars.add("GWeight", value=gw, vary=False)
fitted = F.exp2syn(
st["t"],
pars["tau1"],
pars["tauratio"],
pars["weight"],
pars["erev"],
pars["v"],
pars["delay"],
)
lmfit.printfuncs.report_fit(F.mim.params)
return (pars, fitted)
def fit_all():
"""
Fit exp2syn against the traces in stm
This fits all pre-post cell pairs, and returns the fit
"""
stm = read_pickle("multisite.pkl")
fits = {}
fitted = {}
# print('stm keys: ', stm.keys())
# exit()
for i, stk in enumerate(stm.keys()): # for each pre-post cell pair
fitp, fit = fit_one(stm[stk], stk)
fits[stk] = fitp
fitted[stk] = {"t": stm[stk]["t"], "i": [fit], "pars": fitp}
with (open("simple.pkl", "wb")) as fh:
pickle.dump(fitted, fh)
return fits
def plot_all(stm, sts=None):
P = PH.Plotter((3, 5), figsize=(11, 6))
ax = P.axarr.ravel()
keypairorder = []
for i, stk in enumerate(stm.keys()):
keypairorder.append(stk)
data = stm[stk]
idat = np.array(data["i"])
ax[i].plot(np.array(data["t"]), np.mean(idat, axis=0), "c-", linewidth=1.5)
sd = np.std(idat, axis=0)
ax[i].plot(
np.array(data["t"]), np.mean(idat, axis=0) + sd, "c--", linewidth=0.5
) # for j in range(idat.shape[0]):
ax[i].plot(
np.array(data["t"]), np.mean(idat, axis=0) - sd, "c--", linewidth=0.5
) # for j in range(idat.shape[0]):
rel = 0
for j in range(idat.shape[0]):
if np.min(idat[j, :]) < -1e-2 or np.max(idat[j, :]) > 1e-2:
rel += 1
print(f"{str(stk):s} {rel:d} of {idat.shape[0]:d} {str(idat.shape):s}")
# ax[i].plot(data['t'], idat[j], 'k-', linewidth=0.5, alpha=0.25)
ax[i].set_title("%s : %s" % (stk[0], stk[1]), fontsize=7)
if sts is not None: # plot the matching exp2syn on this
for i, stks in enumerate(keypairorder):
datas = sts[stks]
isdat = np.array(datas["i"])
ax[i].plot(datas["t"], np.mean(isdat, axis=0), "m-", linewidth=3, alpha=0.5)
mpl.show()
def print_exp2syn_fits(st):
stkeys = list(st.keys())
postcells = [
"bushy",
"tstellate",
"dstellate",
"octopus",
"pyramidal",
"tuberculoventral",
"cartwheel",
]
fmts = {
"weight": "{0:<18.6f}",
"GWeight": "{0:<18.6f}",
"tau1": "{0:<18.3f}",
"tau2": "{0:<18.3f}",
"delay": "{0:<18.3f}",
"erev": "{0:<18.1f}",
}
pars = list(fmts.keys())
for pre in ["sgc", "dstellate", "tuberculoventral"]:
firstline = False
vrow = dict.fromkeys(pars, False)
for v in pars:
for post in postcells:
if (pre, post) not in stkeys:
print("{0:<18s}".format("0"), end="")
continue
if not firstline:
print("\n%s" % pre.upper())
print("{0:<18s}".format(" "), end="")
for i in range(len(postcells)):
print("{0:<18s}".format(postcells[i]), end="")
print()
firstline = True
if firstline:
if not vrow[v]:
if v == "tauratio":
print("{0:<18s}".format("tau2"), end="")
else:
print("{0:<18s}".format(v), end="")
vrow[v] = True
# print(fits[(pre, post)].keys())
fits = st[(pre, post)]["pars"]
# print(v, fits)
if v in ["tauratio", "tau2"]:
print(
fmts[v].format(fits["tau1"].value * fits["tauratio"].value),
end="",
)
else:
print(fmts[v].format(fits[v].value), end="")
print()
def read_pickle(mode):
"""
Read the pickled file generated by runall
Parameters
----------
mode : str
either 'simple' or 'multisite'
Returns:
the resulting data, which is a dictionary
"""
with (open(Path(mode).with_suffix(".pkl"), "rb")) as fh:
st = pickle.load(fh)
return st
def run_all():
"""
Run all of the multisite synapse calculations for each cell pair
Save the results in the multisite.pkl file
"""
st = {}
pkst = {}
for pre in convergence.keys():
for post in convergence[pre]:
# if pre == 'sgc' and post in ['bushy', 'tstellate']:
sti = compute_psc(synapsetype="multisite", celltypes=[pre, post])
st[(pre, post)] = sti
# remove neuron objects before pickling
pkst[(pre, post)] = {
"t": sti["t"],
"i": sti.isoma,
"v": sti["v_soma"],
"pre": sti["v_pre"],
}
with (open(Path("multisite").with_suffix(".pkl"), "wb")) as fh:
pickle.dump(pkst, fh)
plot_all(pkst)
def run_one(pre, post):
"""
Run all of the multisite synapse calculations for each cell pair
Save the results in the multisite.pkl file
"""
assert pre in ["sgc", "tstellate", "dstellate", "tuberculoventral"]
assert post in ["bushy", "tstellate", "dstellate", "tuberculoventral", "pyramidal"]
st = {}
pkst = read_pickle("multisite")
# if pre == 'sgc' and post in ['bushy', 'tstellate']:
sti = compute_psc(synapsetype="multisite", celltypes=[pre, post])
st[(pre, post)] = sti
# remove neuron objects before pickling
pkst[(pre, post)] = {
"t": sti["t"],
"i": sti.isoma,
"v": sti["v_soma"],
"pre": sti["v_pre"],
}
with (open(Path("multisite").with_suffix(".pkl"), "wb")) as fh:
pickle.dump(pkst, fh)
plot_all(pkst)
def main():
parser = argparse.ArgumentParser(
description="Compare simple and multisite synapses"
)
parser.add_argument(
"-m",
"--mode",
type=str,
dest="mode",
default="None",
choices=["simple", "multisite", "both"],
help="Select mode [simple, multisite, compare]",
)
parser.add_argument(
"-r",
"--run",
action="store_true",
dest="run",
help="Run multisite models between all cell types",
)
parser.add_argument(
"-o",
"--runone",
action="store_true",
dest="runone",
help="Run multisite models between specified cell types",
)
parser.add_argument(
"--pre",
type=str,
dest="pre",
default="None",
help="Select presynaptic cell for runone",
)
parser.add_argument(
"--post",
type=str,
dest="post",
default="None",
help="Select postsynaptic cell for runone",
)
parser.add_argument(
"-f",
"--fit",
action="store_true",
dest="fit",
help="Fit exp2syn waveforms to multisite data",
)
parser.add_argument(
"-t",
"--test",
action="store_true",
dest="test",
help="Run the test on exp2syn fitting",
)
parser.add_argument(
"-p",
"--plot",
action="store_true",
dest="plot",
help="Plot the current comparison between simple and multisite",
)
parser.add_argument(
"-l",
"--list",
action="store_true",
dest="list",
help="List the simple fit parameters",
)
args = parser.parse_args()
if args.test:
testexp()
exit()
if args.run:
run_all()
exit()
if args.runone:
run_one(args.pre, args.post)
exit()
if args.fit:
fit_all() # self contained - always fits exp2syn against the current multistie data
ds = read_pickle("simple")
dm = read_pickle("multisite")
plot_all(dm, ds)
exit()
if args.plot:
if args.mode in ["simple", "multisite"]:
d = read_pickle(args.mode)
plot_all(d)
exit()
elif args.mode in ["both"]:
ds = read_pickle("simple")
dm = read_pickle("multisite")
plot_all(dm, ds)
exit()
else:
print(f"Mode {args.mode:s} is not valid")
exit()
if args.list:
ds = read_pickle("simple")
print_exp2syn_fits(ds)
# if sys.flags.interactive == 0:
# pg.QtGui.QApplication.exec_()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
from pyqtgraph.Qt import QtGui, QtCore
from pyqtgraph.pgcollections import OrderedDict
from .TableWidget import TableWidget
from pyqtgraph.python2_3 import asUnicode
import types, traceback
import numpy as np
try:
import metaarray
HAVE_METAARRAY = True
except:
HAVE_METAARRAY = False
__all__ = ["DataTreeWidget"]
class DataTreeWidget(QtGui.QTreeWidget):
"""
Widget for displaying hierarchical python data structures
(eg, nested dicts, lists, and arrays)
"""
def __init__(self, parent=None, data=None):
QtGui.QTreeWidget.__init__(self, parent)
self.setVerticalScrollMode(self.ScrollPerPixel)
self.setData(data)
self.setColumnCount(3)
self.setHeaderLabels(["key / index", "type", "value"])
self.setAlternatingRowColors(True)
def setData(self, data, hideRoot=False):
"""data should be a dictionary."""
self.clear()
self.widgets = []
self.nodes = {}
self.buildTree(data, self.invisibleRootItem(), hideRoot=hideRoot)
self.expandToDepth(3)
self.resizeColumnToContents(0)
def buildTree(self, data, parent, name="", hideRoot=False, path=()):
if hideRoot:
node = parent
else:
node = QtGui.QTreeWidgetItem([name, "", ""])
parent.addChild(node)
# record the path to the node so it can be retrieved later
# (this is used by DiffTreeWidget)
self.nodes[path] = node
typeStr, desc, childs, widget = self.parse(data)
node.setText(1, typeStr)
node.setText(2, desc)
# Truncate description and add text box if needed
if len(desc) > 100:
desc = desc[:97] + "..."
if widget is None:
widget = QtGui.QPlainTextEdit(asUnicode(data))
widget.setMaximumHeight(200)
widget.setReadOnly(True)
# Add widget to new subnode
if widget is not None:
self.widgets.append(widget)
subnode = QtGui.QTreeWidgetItem(["", "", ""])
node.addChild(subnode)
self.setItemWidget(subnode, 0, widget)
self.setFirstItemColumnSpanned(subnode, True)
# recurse to children
for key, data in childs.items():
self.buildTree(data, node, asUnicode(key), path=path + (key,))
def parse(self, data):
"""
Given any python object, return:
* type
* a short string representation
* a dict of sub-objects to be parsed
* optional widget to display as sub-node
"""
# defaults for all objects
typeStr = type(data).__name__
if typeStr == "instance":
typeStr += ": " + data.__class__.__name__
widget = None
desc = ""
childs = {}
# type-specific changes
if isinstance(data, dict):
desc = "length=%d" % len(data)
if isinstance(data, OrderedDict):
childs = data
else:
childs = OrderedDict(sorted(data.items()))
elif isinstance(data, (list, tuple)):
desc = "length=%d" % len(data)
childs = OrderedDict(enumerate(data))
elif HAVE_METAARRAY and (
hasattr(data, "implements") and data.implements("MetaArray")
):
childs = OrderedDict(
[("data", data.view(np.ndarray)), ("meta", data.infoCopy())]
)
elif isinstance(data, np.ndarray):
desc = "shape=%s dtype=%s" % (data.shape, data.dtype)
table = TableWidget()
table.setData(data)
table.setMaximumHeight(200)
widget = table
elif isinstance(
data, types.TracebackType
): ## convert traceback to a list of strings
frames = list(
map(str.strip, traceback.format_list(traceback.extract_tb(data)))
)
# childs = OrderedDict([
# (i, {'file': child[0], 'line': child[1], 'function': child[2], 'code': child[3]})
# for i, child in enumerate(frames)])
# childs = OrderedDict([(i, ch) for i,ch in enumerate(frames)])
widget = QtGui.QPlainTextEdit(asUnicode("\n".join(frames)))
widget.setMaximumHeight(200)
widget.setReadOnly(True)
else:
desc = asUnicode(data)
return typeStr, desc, childs, widget

View File

@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
from pyqtgraph.Qt import QtGui, QtCore
from pyqtgraph.pgcollections import OrderedDict
from .DataTreeWidget import DataTreeWidget
import pyqtgraph.functions as fn
import types, traceback
import numpy as np
__all__ = ["DiffTreeWidget"]
class DiffTreeWidget(QtGui.QWidget):
"""
Widget for displaying differences between hierarchical python data structures
(eg, nested dicts, lists, and arrays)
"""
def __init__(self, parent=None, a=None, b=None):
QtGui.QWidget.__init__(self, parent)
self.layout = QtGui.QHBoxLayout()
self.setLayout(self.layout)
self.trees = [DataTreeWidget(self), DataTreeWidget(self)]
for t in self.trees:
self.layout.addWidget(t)
if a is not None:
self.setData(a, b)
def setData(self, a, b):
"""
Set the data to be compared in this widget.
"""
self.data = (a, b)
self.trees[0].setData(a)
self.trees[1].setData(b)
return self.compare(a, b)
def compare(self, a, b, path=()):
"""
Compare data structure *a* to structure *b*.
Return True if the objects match completely.
Otherwise, return a structure that describes the differences:
{ 'type': bool
'len': bool,
'str': bool,
'shape': bool,
'dtype': bool,
'mask': array,
}
"""
bad = (255, 200, 200)
diff = []
# generate typestr, desc, childs for each object
typeA, descA, childsA, _ = self.trees[0].parse(a)
typeB, descB, childsB, _ = self.trees[1].parse(b)
if typeA != typeB:
self.setColor(path, 1, bad)
if descA != descB:
self.setColor(path, 2, bad)
if isinstance(a, dict) and isinstance(b, dict):
keysA = set(a.keys())
keysB = set(b.keys())
for key in keysA - keysB:
self.setColor(path + (key,), 0, bad, tree=0)
for key in keysB - keysA:
self.setColor(path + (key,), 0, bad, tree=1)
for key in keysA & keysB:
self.compare(a[key], b[key], path + (key,))
elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
for i in range(max(len(a), len(b))):
if len(a) <= i:
self.setColor(path + (i,), 0, bad, tree=1)
elif len(b) <= i:
self.setColor(path + (i,), 0, bad, tree=0)
else:
self.compare(a[i], b[i], path + (i,))
elif (
isinstance(a, np.ndarray)
and isinstance(b, np.ndarray)
and a.shape == b.shape
):
tableNodes = [tree.nodes[path].child(0) for tree in self.trees]
if a.dtype.fields is None and b.dtype.fields is None:
eq = self.compareArrays(a, b)
if not np.all(eq):
for n in tableNodes:
n.setBackground(0, fn.mkBrush(bad))
# for i in np.argwhere(~eq):
else:
if a.dtype == b.dtype:
for i, k in enumerate(a.dtype.fields.keys()):
eq = self.compareArrays(a[k], b[k])
if not np.all(eq):
for n in tableNodes:
n.setBackground(0, fn.mkBrush(bad))
# for j in np.argwhere(~eq):
# dict: compare keys, then values where keys match
# list:
# array: compare elementwise for same shape
def compareArrays(self, a, b):
intnan = -9223372036854775808 # happens when np.nan is cast to int
anans = np.isnan(a) | (a == intnan)
bnans = np.isnan(b) | (b == intnan)
eq = anans == bnans
mask = ~anans
eq[mask] = np.allclose(a[mask], b[mask])
return eq
def setColor(self, path, column, color, tree=None):
brush = fn.mkBrush(color)
# Color only one tree if specified.
if tree is None:
trees = self.trees
else:
trees = [self.trees[tree]]
for tree in trees:
item = tree.nodes[path]
item.setBackground(column, brush)
def _compare(self, a, b):
"""
Compare data structure *a* to structure *b*.
"""
# Check test structures are the same
assert type(info) is type(expect)
if hasattr(info, "__len__"):
assert len(info) == len(expect)
if isinstance(info, dict):
for k in info:
assert k in expect
for k in expect:
assert k in info
self.compare_results(info[k], expect[k])
elif isinstance(info, list):
for i in range(len(info)):
self.compare_results(info[i], expect[i])
elif isinstance(info, np.ndarray):
assert info.shape == expect.shape
assert info.dtype == expect.dtype
if info.dtype.fields is None:
intnan = -9223372036854775808 # happens when np.nan is cast to int
inans = np.isnan(info) | (info == intnan)
enans = np.isnan(expect) | (expect == intnan)
assert np.all(inans == enans)
mask = ~inans
assert np.allclose(info[mask], expect[mask])
else:
for k in info.dtype.fields.keys():
self.compare_results(info[k], expect[k])
else:
try:
assert info == expect
except Exception:
raise NotImplementedError(
"Cannot compare objects of type %s" % type(info)
)

View File

@@ -0,0 +1,2 @@
Copied from github.com/campagnola/pyqtgraph datatree-arrays branch; this gives us DiffTreeWidget.
This file can be removed if DiffTreeWidget has been merged into pyqtgraph.

View File

@@ -0,0 +1,515 @@
# -*- coding: utf-8 -*-
import numpy as np
from pyqtgraph.Qt import QtGui, QtCore
from pyqtgraph.python2_3 import asUnicode, basestring
import pyqtgraph.metaarray as metaarray
__all__ = ["TableWidget"]
def _defersort(fn):
def defersort(self, *args, **kwds):
# may be called recursively; only the first call needs to block sorting
setSorting = False
if self._sorting is None:
self._sorting = self.isSortingEnabled()
setSorting = True
self.setSortingEnabled(False)
try:
return fn(self, *args, **kwds)
finally:
if setSorting:
self.setSortingEnabled(self._sorting)
self._sorting = None
return defersort
class TableWidget(QtGui.QTableWidget):
"""Extends QTableWidget with some useful functions for automatic data handling
and copy / export context menu. Can automatically format and display a variety
of data types (see :func:`setData() <pyqtgraph.TableWidget.setData>` for more
information.
"""
def __init__(self, *args, **kwds):
"""
All positional arguments are passed to QTableWidget.__init__().
===================== =================================================
**Keyword Arguments**
editable (bool) If True, cells in the table can be edited
by the user. Default is False.
sortable (bool) If True, the table may be soted by
clicking on column headers. Note that this also
causes rows to appear initially shuffled until
a sort column is selected. Default is True.
*(added in version 0.9.9)*
===================== =================================================
"""
QtGui.QTableWidget.__init__(self, *args)
self.itemClass = TableWidgetItem
self.setVerticalScrollMode(self.ScrollPerPixel)
self.setSelectionMode(QtGui.QAbstractItemView.ContiguousSelection)
self.setSizePolicy(QtGui.QSizePolicy.Preferred, QtGui.QSizePolicy.Preferred)
self.clear()
kwds.setdefault("sortable", True)
kwds.setdefault("editable", False)
self.setEditable(kwds.pop("editable"))
self.setSortingEnabled(kwds.pop("sortable"))
if len(kwds) > 0:
raise TypeError("Invalid keyword arguments '%s'" % kwds.keys())
self._sorting = None # used when temporarily disabling sorting
self._formats = {
None: None
} # stores per-column formats and entire table format
self.sortModes = {} # stores per-column sort mode
self.itemChanged.connect(self.handleItemChanged)
self.contextMenu = QtGui.QMenu()
self.contextMenu.addAction("Copy Selection").triggered.connect(self.copySel)
self.contextMenu.addAction("Copy All").triggered.connect(self.copyAll)
self.contextMenu.addAction("Save Selection").triggered.connect(self.saveSel)
self.contextMenu.addAction("Save All").triggered.connect(self.saveAll)
def clear(self):
"""Clear all contents from the table."""
QtGui.QTableWidget.clear(self)
self.verticalHeadersSet = False
self.horizontalHeadersSet = False
self.items = []
self.setRowCount(0)
self.setColumnCount(0)
self.sortModes = {}
def setData(self, data):
"""Set the data displayed in the table.
Allowed formats are:
* numpy arrays
* numpy record arrays
* metaarrays
* list-of-lists [[1,2,3], [4,5,6]]
* dict-of-lists {'x': [1,2,3], 'y': [4,5,6]}
* list-of-dicts [{'x': 1, 'y': 4}, {'x': 2, 'y': 5}, ...]
"""
self.clear()
self.appendData(data)
self.resizeColumnsToContents()
@_defersort
def appendData(self, data):
"""
Add new rows to the table.
See :func:`setData() <pyqtgraph.TableWidget.setData>` for accepted
data types.
"""
startRow = self.rowCount()
fn0, header0 = self.iteratorFn(data)
if fn0 is None:
self.clear()
return
it0 = fn0(data)
try:
first = next(it0)
except StopIteration:
return
fn1, header1 = self.iteratorFn(first)
if fn1 is None:
self.clear()
return
firstVals = [x for x in fn1(first)]
self.setColumnCount(len(firstVals))
if not self.verticalHeadersSet and header0 is not None:
labels = [self.verticalHeaderItem(i).text() for i in range(self.rowCount())]
self.setRowCount(startRow + len(header0))
self.setVerticalHeaderLabels(labels + header0)
self.verticalHeadersSet = True
if not self.horizontalHeadersSet and header1 is not None:
self.setHorizontalHeaderLabels(header1)
self.horizontalHeadersSet = True
i = startRow
self.setRow(i, firstVals)
for row in it0:
i += 1
self.setRow(i, [x for x in fn1(row)])
if (
self._sorting
and self.horizontalHeadersSet
and self.horizontalHeader().sortIndicatorSection() >= self.columnCount()
):
self.sortByColumn(0, QtCore.Qt.AscendingOrder)
def setEditable(self, editable=True):
self.editable = editable
for item in self.items:
item.setEditable(editable)
def setFormat(self, format, column=None):
"""
Specify the default text formatting for the entire table, or for a
single column if *column* is specified.
If a string is specified, it is used as a format string for converting
float values (and all other types are converted using str). If a
function is specified, it will be called with the item as its only
argument and must return a string. Setting format = None causes the
default formatter to be used instead.
Added in version 0.9.9.
"""
if (
format is not None
and not isinstance(format, basestring)
and not callable(format)
):
raise ValueError(
"Format argument must string, callable, or None. (got %s)" % format
)
self._formats[column] = format
if column is None:
# update format of all items that do not have a column format
# specified
for c in range(self.columnCount()):
if self._formats.get(c, None) is None:
for r in range(self.rowCount()):
item = self.item(r, c)
if item is None:
continue
item.setFormat(format)
else:
# set all items in the column to use this format, or the default
# table format if None was specified.
if format is None:
format = self._formats[None]
for r in range(self.rowCount()):
item = self.item(r, column)
if item is None:
continue
item.setFormat(format)
def iteratorFn(self, data):
## Return 1) a function that will provide an iterator for data and 2) a list of header strings
if isinstance(data, list) or isinstance(data, tuple):
return lambda d: d.__iter__(), None
elif isinstance(data, dict):
return lambda d: iter(d.values()), list(map(asUnicode, data.keys()))
elif hasattr(data, "implements") and data.implements("MetaArray"):
if data.axisHasColumns(0):
header = [
asUnicode(data.columnName(0, i)) for i in range(data.shape[0])
]
elif data.axisHasValues(0):
header = list(map(asUnicode, data.xvals(0)))
else:
header = None
return self.iterFirstAxis, header
elif isinstance(data, np.ndarray):
return self.iterFirstAxis, None
elif isinstance(data, np.void):
return self.iterate, list(map(asUnicode, data.dtype.names))
elif data is None:
return (None, None)
elif np.isscalar(data):
return self.iterateScalar, None
else:
msg = "Don't know how to iterate over data type: {!s}".format(type(data))
raise TypeError(msg)
def iterFirstAxis(self, data):
for i in range(data.shape[0]):
yield data[i]
def iterate(self, data):
# for numpy.void, which can be iterated but mysteriously
# has no __iter__ (??)
for x in data:
yield x
def iterateScalar(self, data):
yield data
def appendRow(self, data):
self.appendData([data])
@_defersort
def addRow(self, vals):
row = self.rowCount()
self.setRowCount(row + 1)
self.setRow(row, vals)
@_defersort
def setRow(self, row, vals):
if row > self.rowCount() - 1:
self.setRowCount(row + 1)
for col in range(len(vals)):
val = vals[col]
item = self.itemClass(val, row)
item.setEditable(self.editable)
sortMode = self.sortModes.get(col, None)
if sortMode is not None:
item.setSortMode(sortMode)
format = self._formats.get(col, self._formats[None])
item.setFormat(format)
self.items.append(item)
self.setItem(row, col, item)
item.setValue(val) # Required--the text-change callback is invoked
# when we call setItem.
def setSortMode(self, column, mode):
"""
Set the mode used to sort *column*.
============== ========================================================
**Sort Modes**
value Compares item.value if available; falls back to text
comparison.
text Compares item.text()
index Compares by the order in which items were inserted.
============== ========================================================
Added in version 0.9.9
"""
for r in range(self.rowCount()):
item = self.item(r, column)
if hasattr(item, "setSortMode"):
item.setSortMode(mode)
self.sortModes[column] = mode
def sizeHint(self):
# based on http://stackoverflow.com/a/7195443/54056
width = sum(self.columnWidth(i) for i in range(self.columnCount()))
width += self.verticalHeader().sizeHint().width()
width += self.verticalScrollBar().sizeHint().width()
width += self.frameWidth() * 2
height = sum(self.rowHeight(i) for i in range(self.rowCount()))
height += self.verticalHeader().sizeHint().height()
height += self.horizontalScrollBar().sizeHint().height()
return QtCore.QSize(width, height)
def serialize(self, useSelection=False):
"""Convert entire table (or just selected area) into tab-separated text values"""
if useSelection:
selection = self.selectedRanges()[0]
rows = list(range(selection.topRow(), selection.bottomRow() + 1))
columns = list(range(selection.leftColumn(), selection.rightColumn() + 1))
else:
rows = list(range(self.rowCount()))
columns = list(range(self.columnCount()))
data = []
if self.horizontalHeadersSet:
row = []
if self.verticalHeadersSet:
row.append(asUnicode(""))
for c in columns:
row.append(asUnicode(self.horizontalHeaderItem(c).text()))
data.append(row)
for r in rows:
row = []
if self.verticalHeadersSet:
row.append(asUnicode(self.verticalHeaderItem(r).text()))
for c in columns:
item = self.item(r, c)
if item is not None:
row.append(asUnicode(item.value))
else:
row.append(asUnicode(""))
data.append(row)
s = ""
for row in data:
s += "\t".join(row) + "\n"
return s
def copySel(self):
"""Copy selected data to clipboard."""
QtGui.QApplication.clipboard().setText(self.serialize(useSelection=True))
def copyAll(self):
"""Copy all data to clipboard."""
QtGui.QApplication.clipboard().setText(self.serialize(useSelection=False))
def saveSel(self):
"""Save selected data to file."""
self.save(self.serialize(useSelection=True))
def saveAll(self):
"""Save all data to file."""
self.save(self.serialize(useSelection=False))
def save(self, data):
fileName = QtGui.QFileDialog.getSaveFileName(
self, "Save As..", "", "Tab-separated values (*.tsv)"
)
if fileName == "":
return
open(fileName, "w").write(data)
def contextMenuEvent(self, ev):
self.contextMenu.popup(ev.globalPos())
def keyPressEvent(self, ev):
if ev.key() == QtCore.Qt.Key_C and ev.modifiers() == QtCore.Qt.ControlModifier:
ev.accept()
self.copySel()
else:
QtGui.QTableWidget.keyPressEvent(self, ev)
def handleItemChanged(self, item):
item.itemChanged()
class TableWidgetItem(QtGui.QTableWidgetItem):
def __init__(self, val, index, format=None):
QtGui.QTableWidgetItem.__init__(self, "")
self._blockValueChange = False
self._format = None
self._defaultFormat = "%0.3g"
self.sortMode = "value"
self.index = index
flags = QtCore.Qt.ItemIsSelectable | QtCore.Qt.ItemIsEnabled
self.setFlags(flags)
self.setValue(val)
self.setFormat(format)
def setEditable(self, editable):
"""
Set whether this item is user-editable.
"""
if editable:
self.setFlags(self.flags() | QtCore.Qt.ItemIsEditable)
else:
self.setFlags(self.flags() & ~QtCore.Qt.ItemIsEditable)
def setSortMode(self, mode):
"""
Set the mode used to sort this item against others in its column.
============== ========================================================
**Sort Modes**
value Compares item.value if available; falls back to text
comparison.
text Compares item.text()
index Compares by the order in which items were inserted.
============== ========================================================
"""
modes = ("value", "text", "index", None)
if mode not in modes:
raise ValueError("Sort mode must be one of %s" % str(modes))
self.sortMode = mode
def setFormat(self, fmt):
"""Define the conversion from item value to displayed text.
If a string is specified, it is used as a format string for converting
float values (and all other types are converted using str). If a
function is specified, it will be called with the item as its only
argument and must return a string.
Added in version 0.9.9.
"""
if fmt is not None and not isinstance(fmt, basestring) and not callable(fmt):
raise ValueError(
"Format argument must string, callable, or None. (got %s)" % fmt
)
self._format = fmt
self._updateText()
def _updateText(self):
self._blockValueChange = True
try:
self._text = self.format()
self.setText(self._text)
finally:
self._blockValueChange = False
def setValue(self, value):
self.value = value
self._updateText()
def itemChanged(self):
"""Called when the data of this item has changed."""
if self.text() != self._text:
self.textChanged()
def textChanged(self):
"""Called when this item's text has changed for any reason."""
self._text = self.text()
if self._blockValueChange:
# text change was result of value or format change; do not
# propagate.
return
try:
self.value = type(self.value)(self.text())
except ValueError:
self.value = str(self.text())
def format(self):
if callable(self._format):
return self._format(self)
if isinstance(self.value, (float, np.floating)):
if self._format is None:
return self._defaultFormat % self.value
else:
return self._format % self.value
else:
return asUnicode(self.value)
def __lt__(self, other):
if self.sortMode == "index" and hasattr(other, "index"):
return self.index < other.index
if self.sortMode == "value" and hasattr(other, "value"):
return self.value < other.value
else:
return self.text() < other.text()
if __name__ == "__main__":
app = QtGui.QApplication([])
win = QtGui.QMainWindow()
t = TableWidget()
win.setCentralWidget(t)
win.resize(800, 600)
win.show()
ll = [[1, 2, 3, 4, 5]] * 20
ld = [{"x": 1, "y": 2, "z": 3}] * 20
dl = {"x": list(range(20)), "y": list(range(20)), "z": list(range(20))}
a = np.ones((20, 5))
ra = np.ones((20,), dtype=[("x", int), ("y", int), ("z", int)])
t.setData(ll)
ma = metaarray.MetaArray(
np.ones((20, 3)),
info=[
{"values": np.linspace(1, 5, 20)},
{"cols": [{"name": "x"}, {"name": "y"}, {"name": "z"}]},
],
)
t.setData(ma)

View File

@@ -0,0 +1 @@
from .DiffTreeWidget import DiffTreeWidget

83
cnmodel/util/expfitting.py Executable file
View File

@@ -0,0 +1,83 @@
#!/usr/bin/env python
# encoding: utf-8
"""
expfitting.py
Provide single or double exponential fits to data.
"""
import lmfit
import numpy as np
import scipy.optimize
class ExpFitting:
"""
Parameters
----------
nexp : int
1 or 2 for single or double exponential fit
initpars : dict
dict of initial parameters. For example: {'dc': 0.,
'a1': 1., 't1': 3, 'a2' : 0.5, 'delta': 3.}, where
delta determines the ratio between the time constants.
bounds : dict
dictionary of bounds for each parameter, with a list of lower and upper values.
"""
def __init__(self, nexp=1, initpars=None, bounds=None):
self.fitpars = lmfit.Parameters()
if nexp == 1:
# (Name, Value, Vary, Min, Max, Expr)
self.fitpars.add_many(
("dc", 0, True, -100.0, 0.0, None),
("a1", 1.0, True, -25.0, 25.0, None),
("t1", 10.0, True, 0.1, 50, None),
)
self.efunc = self.exp1_err
elif nexp == 2:
self.fitpars.add_many(
("dc", 0, True, -100.0, 0.0, None),
("a1", 1.0, True, 0.0, 25.0, None),
("t1", 10.0, True, 0.1, 50, None),
("a2", 1.0, True, 0.0, 25.0, None),
("delta", 3.0, True, 3.0, 100.0, None),
)
if initpars is not None:
assert len(initpars) == 5
for k, v in initpars.iteritems():
self.fitpars[k].value = v
if bounds is not None:
assert len(bounds) == 5
for k, v in bounds.iteritems():
self.fitpars[k].min = v[0]
self.fitpars[k].max = v[1]
self.efunc = self.exp2_err
else:
raise ValueError
def fit(self, x, y, p, verbose=False):
kws = {"maxfev": 5000}
mim = lmfit.minimize(
self.efunc, p, method="least_squares", args=(x, y)
) # , kws=kws)
if verbose:
lmfit.printfuncs.report_fit(mim.params)
fitpars = mim.params
return fitpars
@staticmethod
def exp1(x, dc, t1, a1):
return dc + a1 * np.exp(-x / t1)
def exp1_err(self, p, x, y):
return np.fabs(y - self.exp1(x, **dict([(k, p.value) for k, p in p.items()])))
@staticmethod
def exp2(x, dc, t1, a1, a2, delta):
return dc + a1 * np.exp(-x / t1) + a2 * np.exp(-x / (t1 * delta))
def exp2_err(self, p, x, y):
return np.fabs(y - self.exp2(x, **dict([(k, p.value) for k, p in p.items()])))

133
cnmodel/util/filelock.py Normal file
View File

@@ -0,0 +1,133 @@
# Copyright (c) 2009, Evan Fosmark
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are those
# of the authors and should not be interpreted as representing official policies,
# either expressed or implied, of the FreeBSD Project.
import os
import time
import errno
class FileLockException(Exception):
pass
class FileLock(object):
""" A file locking mechanism that has context-manager support so
you can use it in a with statement. This should be relatively cross
compatible as it doesn't rely on msvcrt or fcntl for the locking.
"""
lock_count = {} # lock count for each locked file
lock_handles = {} # file handle for each locked file
def __init__(self, file_name, timeout=10, delay=0.05):
""" Prepare the file locker. Specify the file to lock and optionally
the maximum timeout and the delay between each attempt to lock.
"""
self.fd = None
self.is_locked = False
self.lockfile = os.path.join(os.getcwd(), "%s.lock" % file_name)
self.file_name = file_name
self.timeout = timeout
self.delay = delay
def acquire(self):
""" Acquire the lock, if possible. If the lock is in use, it check again
every `wait` seconds. It does this until it either gets the lock or
exceeds `timeout` number of seconds, in which case it throws
an exception.
"""
# Don't try to lock the same file more than once
if FileLock.lock_count.setdefault(self.lockfile, 0) > 0:
self.is_locked = True
self.fd = FileLock.lock_handles[self.lockfile]
FileLock.lock_count[self.lockfile] += 1
return
start_time = time.time()
while True:
try:
# create the cache directory if it does not exist (new installations will not have a cache directory)
stimdir = os.path.dirname(self.lockfile)
cachedir = os.path.dirname(stimdir)
if not os.path.isdir(cachedir): # make sure the cache exists
os.mkdir(cachedir)
if not os.path.isdir(stimdir): # and the specific stimulus dir
os.mkdir(stimdir)
self.fd = os.open(self.lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
open(self.lockfile, "w").write(str(os.getpid()))
break
except OSError as e:
if e.errno != errno.EEXIST:
raise
if (time.time() - start_time) >= self.timeout:
try:
pid = open(self.lockfile).read()
except Exception:
pid = "[error reading lockfile: %s]" % sys.exc_info()[0]
raise FileLockException(
"Timeout occured. (%s is locked by pid %s)"
% (self.lockfile, pid)
)
time.sleep(self.delay)
self.is_locked = True
FileLock.lock_count[self.lockfile] += 1
FileLock.lock_handles[self.lockfile] = self.fd
def release(self):
""" Get rid of the lock by deleting the lockfile.
When working in a `with` statement, this gets automatically
called at the end.
"""
if self.is_locked:
self.is_locked = False
FileLock.lock_count[self.lockfile] -= 1
if FileLock.lock_count[self.lockfile] == 0:
os.close(self.fd)
os.unlink(self.lockfile)
del FileLock.lock_handles[self.lockfile]
def __enter__(self):
""" Activated when used in the with statement.
Should automatically acquire a lock to be used in the with block.
"""
if not self.is_locked:
self.acquire()
return self
def __exit__(self, type, value, traceback):
""" Activated at the end of the with statement.
It automatically releases the lock if it isn't locked.
"""
if self.is_locked:
self.release()
def __del__(self):
""" Make sure that the FileLock instance doesn't leave a lockfile
lying around.
"""
self.release()

View File

@@ -0,0 +1,95 @@
from __future__ import print_function
from scipy import interpolate
import numpy as np
def find_point(x, y, peakindex, val, direction="left", limits=None):
"""
Given a waveform defined by *x* and *y* arrays, return the first time
at which the waveform crosses (y[peakindex] * val). The search begins at
*peakindex* and proceeds in *direction*.
Optionally, *limits* may specify a smaller search region in the form of
(t0, t1, dt).
"""
# F = interpolate.UnivariateSpline(x, y, s=0) # declare function
# To find x at y then do:
istart = 0
iend = len(y)
if limits is not None:
istart = int(limits[0] / limits[2])
iend = int(limits[1] / limits[2])
yToFind = y[peakindex] * val
if direction == "left":
yreduced = np.array(y[istart:peakindex]) - yToFind
try:
Fr = interpolate.UnivariateSpline(x[istart:peakindex], yreduced, s=0)
except:
print("find_point: insufficient time points for analysis")
print("arg lengths:", len(x[istart:peakindex]), len(yreduced))
print("istart, peakindex: ", istart, peakindex)
print("ytofine: ", yToFind)
raise
res = float("nan")
return res
res = Fr.roots()
if len(res) > 1:
res = res[-1]
else:
yreduced = np.array(y[peakindex:iend]) - yToFind
try:
Fr = interpolate.UnivariateSpline(x[peakindex:iend], yreduced, s=0)
except:
print("find_point: insufficient time points for analysis?")
print("arg lengths:", len(x[peakindex:iend]), len(yreduced))
raise
res = float("nan")
return res
res = Fr.roots()
if len(res) > 1:
res = res[0]
# pdb.set_trace()
try:
res.pop()
except:
pass
if not res: # tricky - an empty list is False, but does not evaluate to False
res = float("nan") # replace with a NaN
else:
res = float(res) # make sure is just a simple number (no arrays)
return res
def find_crossing(data, start=0, direction=1, threshold=0):
"""Return the index at which *data* crosses *threshold*, starting
at index *start* and proceeding in *direction* (+/-1).
The value returned is a float indicating the interpolated index
position where the data crosses threshold, or NaN if the threshold was
never crossed.
"""
# Note: this function is very similar in purpose to find_point, but was
# added due to issues with interpolate.UnivariateSpline
assert direction in (1, -1)
cross_rising = data[start] < threshold
def test(x):
if cross_rising:
return x > threshold
else:
return x < threshold
while True:
next_ind = start + direction
if next_ind < 0 or next_ind >= len(data):
return np.nan
if test(data[next_ind]):
# crossed; return interpolated value
s1 = data[next_ind] - threshold
s2 = threshold = data[start]
return (next_ind * s2 + start * s1) / (s2 + s1)
start = next_ind

219
cnmodel/util/fitting.py Normal file
View File

@@ -0,0 +1,219 @@
import numpy as np
import lmfit
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore, QtGui
class FitModel(lmfit.Model):
""" Simple extension of lmfit.Model that allows one-line fitting.
Example uses:
# single exponential fit::
fit = expfitting.Exp1.fit(data,
x=time_vals,
xoffset=(0, 'fixed'),
yoffset=(yoff_guess, -120, 0),
amp=(amp_guess, 0, 50),
tau=(tau_guess, 0.1, 50))
# plot the fit::
fit_curve = fit.eval()
plot(time_vals, fit_curve)
# double exponential fit with tau ratio constraint
# note that 'tau_ratio' does not appear in the exp2 model;
# we can define new parameters here.::
fit = expfitting.Exp2.fit(data,
x=time_vals,
xoffset=(0, 'fixed'),
yoffset=(yoff_guess, -120, 0),
amp1=(amp_guess, 0, 50),
tau1=(tau_guess, 0.1, 50),
amp2=(-0.5, -50, 0),
tau_ratio=(10, 3, 50),
tau2='tau1 * tau_ratio'
)
"""
def fit(self, data, interactive=False, **params):
""" Return a fit of data to this model.
Parameters
----------
data : array
dependent data to fit
interactive : bool
If True, show a GUI used for interactively exploring fit parameters
Extra keyword arguments are passed to make_params() if they are model
parameter names, or passed directly to Model.fit() for independent
variable names.
Returns
-------
fit of data to model as an lmfit object
"""
fit_params = {}
model_params = {}
for k, v in params.items():
if k in self.independent_vars or k in [
"weights",
"method",
"scale_covar",
"iter_cb",
]:
fit_params[k] = v
else:
model_params[k] = v
p = self.make_params(**model_params)
# print ('params: ', p)
# print ('fitparams: ', fit_params)
# import matplotlib.pyplot as mpl
# mpl.plot(data)
# mpl.show()
fit = lmfit.Model.fit(self, data, params=p, **fit_params)
if interactive:
self.show_interactive(fit)
return fit
def make_params(self, **params):
"""
Make parameters used for fitting with this model.
Keyword arguments are used to generate parameters for the fit. Each
parameter may be specified by the following formats:
param=value :
The initial value of the parameter
param=(value, 'fixed') :
Fixed value for the parameter
param=(value, min, max) :
Initial value and min, max values, which may be float or None
param='expression' :
Expression used to compute parameter value. See:
http://lmfit.github.io/lmfit-py/constraints.html#constraints-chapter
"""
p = lmfit.Parameters()
for k in self.param_names:
p.add(k)
for param, val in params.items():
if param not in p:
p.add(param)
if isinstance(val, str):
p[param].expr = val
elif np.isscalar(val):
p[param].value = val
elif isinstance(val, tuple):
if len(val) == 2:
assert val[1] == "fixed"
p[param].value = val[0]
p[param].vary = False
elif len(val) == 3:
p[param].value = val[0]
p[param].min = val[1]
p[param].max = val[2]
else:
raise TypeError(
"Tuple parameter specifications must be (val, 'fixed')"
" or (val, min, max)."
)
else:
raise TypeError("Invalid parameter specification: %r" % val)
# set initial values for parameters with mathematical constraints
# this is to allow fit.eval(**fit.init_params)
global_ns = np.__dict__
for param, val in params.items():
if isinstance(val, str):
p[param].value = eval(val, global_ns, p.valuesdict())
return p
def show_interactive(self, fit=None):
""" Show an interactive GUI for exploring fit parameters.
"""
if not hasattr(self, "_interactive_win"):
self._interactive_win = FitExplorer(model=self, fit=fit)
self._interactive_win.show()
def exp1(x, xoffset, yoffset, tau, amp):
return yoffset + amp * np.exp(-(x - xoffset) / tau)
class Exp1(FitModel):
""" Single exponential fitting model.
Parameters are xoffset, yoffset, amp, and tau.
"""
def __init__(self):
FitModel.__init__(self, exp1, independent_vars=["x"])
def exp2(x, xoffset, yoffset, tau1, amp1, tau2, amp2):
xoff = x - xoffset
return yoffset + amp1 * np.exp(-xoff / tau1) + amp2 * np.exp(-xoff / tau2)
class Exp2(FitModel):
""" Double exponential fitting model.
Parameters are xoffset, yoffset, amp1, tau1, amp2, and tau2.
"""
def __init__(self):
FitModel.__init__(self, exp2, independent_vars=["x"])
class FitExplorer(QtGui.QWidget):
def __init__(self, model, fit):
QtGui.QWidget.__init__(self)
self.model = model
self.fit = fit
self.layout = QtGui.QGridLayout()
self.setLayout(self.layout)
self.splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
self.layout.addWidget(self.splitter)
self.ptree = pg.parametertree.ParameterTree()
self.splitter.addWidget(self.ptree)
self.plot = pg.PlotWidget()
self.splitter.addWidget(self.plot)
self.params = pg.parametertree.Parameter.create(
name="param_root",
type="group",
children=[
dict(name="fit", type="action"),
dict(name="parameters", type="group"),
],
)
for k in fit.params:
p = pg.parametertree.Parameter.create(
name=k, type="float", value=fit.params[k].value
)
self.params.param("parameters").addChild(p)
self.ptree.setParameters(self.params)
self.update_plots()
self.params.param("parameters").sigTreeStateChanged.connect(self.update_plots)
def update_plots(self):
for k in self.fit.params:
self.fit.params[k].value = self.params["parameters", k]
self.plot.clear()
self.plot.plot(self.fit.data)
self.plot.plot(self.fit.eval(), pen="y")

View File

@@ -0,0 +1,477 @@
from __future__ import print_function
__author__ = "pbmanis"
"""
ManageANSpikes is a class to read the output of the Zilany et al. 2009 AN model into
python, and provides services to access that data.
Basic usage is to create an instance of the class, and specify the data directory
if necessary.
You may then get the data in the format of a list using one of the "get" routines.
The data is pulled from the nearest CF or the specified CF.
"""
import os
import re
import numpy as np
from collections import OrderedDict
import scipy.io
import matplotlib.pyplot as MP
class ManageANSpikes:
"""
ManageANSpikes is a class to read the output of the Zilany et al. 2009 AN model into
python, and provides services to access that data.
Basic usage is to create an instance of the class, and specify the data directory
if necessary.
You may then get the data in the format of a list using one of the "get" routines.
The data is pulled from the nearest CF or the specified CF.
"""
def __init__(self):
# self.datadir = environment['HOME'] + '/Desktop/Matlab/ZilanyCarney-JASAcode-2009/'
self.data_dir = os.environ["HOME"] + "/Desktop/Matlab/ANData/"
self.data_read_flag = False
self.dataType = "RI"
self.set_CF_map(4000, 38000, 25) # the default list
self.all_AN = None
def get_data_dir(self):
return self.data_dir
def set_data_dir(self, directory):
if os.path.isdir(directory):
self.data_dir = directory
else:
raise ValueError("ManageANSpikes.set_data_dir: Path %d is not a directory")
def get_data_read(self):
return self.data_read_flag
def get_CF_map(self):
return self.CF_map
def set_CF_map(self, low, high, nfreq):
self.CF_map = np.round(np.logspace(np.log10(low), np.log10(high), nfreq))
def get_dataType(self):
return self.dataType
def set_dataType(self, datatype):
if datatype in ["RI", "PL"]: # only for recognized data types
self.dataType = datatype
else:
raise ValueError(
"get_anspikes.set_dataType: unrecognized type %s " % datatype
)
def plot_RI_vs_F(self, freq=10000.0, spontclass="HS", display=False):
cfd = {}
if display:
MP.figure(10)
for i, fr in enumerate(self.CF_map):
self.read_AN_data(freq=freq, CF=fr, spontclass=spontclass, ignoreflag=True)
nsp = np.zeros(len(self.SPLs)) # this is the same regardless
for i, db in enumerate(self.SPLs):
spkl = self.combine_reps(self.spikelist[i])
nsp[i] = len(spkl) # /self.SPLs[i]
if display:
MP.plot(self.SPLs, nsp)
def plot_Rate_vs_F(self, freq=10000.0, spontclass="HS", display=False, SPL=0.0):
"""
assumes we have done a "read all"
:param freq:
:param spontclass:
:param display:
:param SPL:
:return:
"""
cfd = {}
if display:
MP.figure(10)
for i, db in enumerate(self.SPLs): # plot will have lines at each SPL
# retrieve_from_all_AN(self, cf, SPL) # self.read_AN_data(freq=freq, CF=fr, spontclass=spontclass)
nsp = np.zeros(len(self.CF_map)) # this is the same regardless
for j, fr in enumerate(self.CF_map):
spikelist = self.retrieve_from_all_AN(fr, db)
spkl = self.combine_reps(spikelist)
nsp[j] = len(spkl) / len(spikelist)
if display:
MP.plot(self.CF_map, nsp)
def getANatFandSPL(self, spontclass="MS", freq=10000.0, CF=None, SPL=0):
"""
Get the AN data at a particular frequency and SPL. Note that the tone freq is specified,
but the CF of the fiber might be different. If CF is None, we try to get the closest
fiber data to the tone Freq.
Other arguments are the spontaneous rate and the SPL level.
The data must exist, or we epically fail.
:param spontclass: 'HS', 'MS', or 'LS' for high, middle or low spont groups (1,2,3 in Zilany et al model)
:param freq: The stimulus frequency, in Hz
:param CF: The 'characteristic frequency' of the desired AN fiber, in Hz
:param SPL: The sound pressure level, in dB SPL for the stimulus
:return: an array of nReps of spike times.
"""
if CF is None:
closest = np.argmin(
np.abs(self.CF_map - freq)
) # find closest to the stim freq
else:
closest = np.argmin(
np.abs(self.CF_map - CF)
) # find closest to the stim freq
CF = self.CF_map[closest]
print("closest f: ", CF)
self.read_AN_data(spontclass=spontclass, freq=10000.0, CF=CF)
return self.get_AN_at_SPL(SPL)
def get_AN_at_SPL(self, spl=None):
"""
grabs the AN data for the requested SPL for the data set currently loaded
The data must exist, or we epically fail.
:param spl: sound pressure level, in dB SPL
:return: spike trains (times) as a list of nReps numpy arrays.
"""
if not self.data_read_flag:
print("getANatSPL: No data read yet")
return None
try:
k = int(np.where(self.SPLs == spl)[0])
except (ValueError, TypeError) as e:
print(
"get_anspikes::getANatSPL: spl=%6.1f not in list of spls:" % (spl),
self.SPLs,
)
exit() # no match
# now clean up array to keep it iterable upon return
spkl = [[]] * len(self.spikelist[k])
for i in xrange(len(self.spikelist[k])):
try:
len(self.spikelist[k][i])
spkl[i] = self.spikelist[k][i]
except:
spkl[i] = [np.array(self.spikelist[k][i])]
return spkl # returns all the trials in the list
def read_all_ANdata(
self, freq=10000.0, CFList=None, spontclass="HS", stim="BFTone"
):
"""
Reads a bank of AN data, across frequency and intensity, for a given stimulus frequency
Assumptions: the bank of data is consistent in terms of nReps and SPLs
:param freq: Tone stimulus frequency (if stim is Tone or BFTone)
:param CFList: A list of the CFs to read
:param spontclass: 'HS', 'MS', or 'LS', corresponding to spont rate groups
:param stim: Stimulus type - Tone/BFTone, or Noise
:return: Nothing. The data are stored in an array accessible directly or through a selector function
"""
self.all_AN = OrderedDict(
[(f, None) for f in CFList]
) # access through dictionary with keys as freq
for cf in CFList:
self.read_AN_data(
freq=freq, CF=cf, spontclass=spontclass, stim=stim, setflag=False
)
self.all_AN[cf] = self.spikelist
self.data_read_flag = True # inform and block
def retrieve_from_all_AN(self, cf, SPL):
"""
:param cf:
:param SPL:
:return: spike list (all nreps) at the cf and SPL requested, for the loaded stimulus set
"""
if not self.data_read_flag or self.all_AN is None:
print("get_anspikes::retrieve_from_all_AN: No data read yet: ")
exit()
ispl = int(np.where(self.SPLs == float(SPL))[0])
icf = self.all_AN.keys().index(cf)
spikelist = self.all_AN[cf][ispl]
spkl = [[]] * len(spikelist)
# print len(spikelist)
for i in xrange(len(spikelist)): # across trials
try:
len(spikelist[i])
spkl[i] = spikelist[i]
except:
spkl[i] = [np.array(spikelist[i])]
# print 'cf: %f spl: %d nspk: %f' % (cf, SPL, len(spkl[i]))
return spikelist
def read_AN_data(
self,
freq=10000,
CF=5300,
spontclass="HS",
display=False,
stim="BFTone",
setflag=True,
ignoreflag=True,
):
"""
read responses of auditory nerve model of Zilany et al. (2009).
display = True plots the psth's for all ANF channels in the dataset
Request response to stimulus at Freq, for fiber with CF
and specified spont rate
This version is for rate-intensity runs March 2014.
Returns: Nothing
"""
if not ignoreflag:
assert self.data_read_flag == False
# print 'Each instance of ManageANSPikes is allowed ONE data set to manage'
# print 'This is a design decision to avoid confusion about which data is in the instance'
# exit()
if stim in ["BFTone", "Tone"]:
fname = "%s_F%06.3f_CF%06.3f_%2s.mat" % (
self.dataType,
freq / 1000.0,
CF / 1000.0,
spontclass,
)
elif stim == "Noise":
fname = "%s_Noise_CF%06.3f_%2s.mat" % (
self.dataType,
CF / 1000.0,
spontclass,
)
# print 'Reading: %s' % (fname)
try:
mfile = scipy.io.loadmat(
os.path.join(self.data_dir, fname), squeeze_me=True
)
except IOError:
print("get_anspikes::read_AN_data: Failed to find data file %s" % (fname))
print(
"Corresponding to Freq: %f CF: %f spontaneous rate class: %s"
% (freq, CF, spontclass)
)
exit()
n_spl = len(mfile[self.dataType])
if display:
n = int(np.sqrt(n_spl)) + 1
mg = np.meshgrid(range(n), range(n))
mg = [mg[0].flatten(), mg[1].flatten()]
spkl = [[]] * n_spl
spl = np.zeros(n_spl)
for k in range(n_spl):
spkl[k] = mfile[self.dataType]["data"][k] # get data for one SPL
spl[k] = mfile[self.dataType]["SPL"][k] # get SPL for this set of runs
# print spkl[k].shape
if display:
self.display(spkl[k], k, n, mg)
if display:
MP.show()
self.spikelist = spkl # save these so we can have a single point to parse them
self.SPLs = spl
self.n_reps = mfile[self.dataType]["nrep"]
if setflag:
self.data_read_flag = True
# return spkl, spl, mfile['RI']['nrep']
def getANatSPL(self, spl=None):
if not self.dataRead:
print("getANatSPL: No data read yet")
return None
try:
k = int(np.where(self.SPLs == spl)[0])
except (ValueError, TypeError) as e:
print(
"get_anspikes::getANatSPL: spl=%6.1f not in list of spls:" % (spl),
self.SPLs,
)
exit() # no match
# now clean up array to keep it iterable upon return
spkl = [[]] * len(self.spikelist[k])
for i in xrange(len(self.spikelist[k])):
try:
len(self.spikelist[k][i])
spkl[i] = self.spikelist[k][i]
except:
spkl[i] = [np.array(self.spikelist[k][i])]
return spkl # returns all the trials in the list
def get_AN_info(self):
"""
:return: Dictionary of nReps and SPLs that are in the current data set (instance).
"""
if not self.data_read_flag:
print("getANatSPL: No data read yet")
return None
return {"nReps": self.n_reps, "SPLs": self.SPLs}
def read_cmrr_data(self, P, signal="S0", display=False):
"""
read responses of auditory nerve model of Zilany et al. (2009).
The required parameters are passed via the class P.
This includes:
the SN (s2m) is the signal-to-masker ratio to select
the mode is 'CMR', 'CMD' (deviant) or 'REF' (reference, no flanking bands),
'Tone', or 'Noise'
the Spont rate group
signal is either 'S0' (signal present) or 'NS' (no signal)
display = True plots the psth's for all ANF channels in the dataset
Returns:
tuple of (spikelist and frequency list for modulated tones)
"""
if P.modF == 10:
datablock = "%s_F4000.0" % (P.mode)
else:
datablock = "%s_F4000.0_M%06.1f" % (
P.mode,
P.modF,
) # 100 Hz modulation directory
# print P.mode
fs = os.listdir(self.datadir + datablock)
s_sn = "SN%03d" % (P.s2m)
fntemplate = "(\S+)_%s_%s_%s_%s.mat" % (s_sn, P.mode, signal, P.SR)
p = re.compile(fntemplate)
fl = [re.match(p, file) for file in fs]
fl = [f.group(0) for f in fl if f != None] # returns list of files matching...
fr = [float(f[1:7]) for f in fl] # frequency list
if display:
self.makeFig()
i = 0
j = 0
spkl = [[]] * len(fl)
n = int(np.sqrt(len(fl)))
mg = np.meshgrid(range(n), range(n))
mg = [mg[0].flatten(), mg[1].flatten()]
for k, fi in enumerate(fl):
mfile = scipy.io.loadmat(
self.datadir + datablock + "/" + fi, squeeze_me=True
)
spkl[k] = mfile["Cpsth"]
if display:
self.display(spkl[k], k, n, mg)
if display:
MP.show()
return (spkl, fr)
def make_fig(self):
self.fig = MP.figure(1)
def combine_reps(self, spkl):
"""
Just turns the spike list into one big linear sequence.
:param spkl: a spike train (nreps of numpy arrays)
:return: all the spikes in one long array.
"""
# print spkl
allsp = np.array(self.flatten(spkl))
# print allsp.shape
if allsp.shape == (): # nothing to show
return
# print allsp
spks = []
for i in range(len(allsp)):
if allsp[i] == ():
continue
if isinstance(allsp[i], float):
spks.append(allsp[i])
else:
spks.extend(np.array(allsp[i]))
return spks
def display(self, spkl, k, n, mg):
# print spkl
spks = self.combine_reps(spkl)
if spks != []:
spks = np.sort(spks)
MP.hist(spks, 100)
return
def flatten(self, x):
"""
flatten(sequence) -> list
Returns a single, flat list which contains all elements retrieved
from the sequence and all recursively contained sub-sequences
(iterables).
"""
result = []
if isinstance(x, float) or len([x]) <= 1:
return x
for el in x:
if hasattr(el, "__iter__") and not isinstance(el, basestring):
result.extend(self.flatten(el))
else:
result.append(el)
return result
if __name__ == "__main__":
from cnmodel.util import Params
modes = ["CMR", "CMD", "REF"]
sr_types = ["H", "M", "L"]
sr_type = sr_types[0]
sr_names = {
"H": "HS",
"M": "MS",
"L": "LS",
} # spontaneous rate groups (AN fiber input selection)
# define params for reading/testing. This is a subset of params...
P = Params(
mode=modes[0],
s2m=0,
n_rep=0,
start_rep=0,
sr=sr_names["L"],
modulation_freq=10,
DS_phase=0,
dataset="test",
sr_types=sr_types,
sr_names=sr_names,
fileVersion=1.5,
)
manager = ManageANSpikes() # create instance of the manager
print(dir(manager))
test = "RI"
if test == "all":
manager.read_all_ANdata(
freq=10000.0, CFList=manager.CF_map, spontclass="HS", stim="BFTone"
)
# spkl = manager.retrieve_from_all_AN(manager.CF_map[0], 40.)
# spks = manager.combine_reps(spkl)
manager.plot_Rate_vs_F(freq=manager.CF_map[0], display=True, spontclass="HS")
MP.show()
if test == manager.dataType:
spikes = manager.getANatFandSPL(spontclass="MS", freq=10000.0, CF=None, SPL=50)
spks = manager.combine_reps(spikes)
if spks != []:
print("spikes.")
manager.plot_RI_vs_F(freq=10000.0, display=True, spontclass="MS")
MP.figure(11)
spks = np.sort(spks)
MP.hist(spks, 100)
MP.show()

402
cnmodel/util/matlab_proc.py Normal file
View File

@@ -0,0 +1,402 @@
from __future__ import print_function
"""
Simple system for interfacing with a MATLAB process using stdin/stdout pipes.
"""
from .process import Process
from io import StringIO
import scipy.io
import numpy as np
import tempfile
import os, sys, glob, signal
import weakref
import atexit
class MatlabProcess(object):
""" This class starts a new matlab process, allowing remote control of
the interpreter and transfer of data between python and matlab.
"""
# Implements a MATLAB CLI that is a bit easier for us to parse
_bootstrap = r"""
while true
fprintf('\n::ready\n');
% Accumulate command lines until we see ::cmd_done
cmd = '';
empty_count = 0;
while true
line = input('', 's');
if length(line) == 0
% If we encounter many empty lines, assume the host process
% has ended. Is there a better way to detect this?
empty_count = empty_count + 1;
if empty_count > 100
fprintf('Shutting down!\n');
exit;
end
end
if strcmp(line, '::cmd_done')
break
end
cmd = [cmd, line, sprintf('\n')];
end
% Evaluate command
try
%fprintf('EVAL: %s\n', cmd);
eval(cmd);
fprintf('\n::ok\n');
catch err
fprintf('\n::err\n');
fprintf(['::message:', err.message, '\n']);
fprintf(['::identifier:', err.identifier, '\n']);
for i = 1:length(err.stack)
frame = err.stack(i,1);
fprintf(['::stack:', frame.name, ' in ', frame.file, ' line ', frame.line]);
end
end
end
"""
def __init__(self, executable=None, **kwds):
self.__closed = False
self.__refs = weakref.WeakValueDictionary()
# Decide which executables to try
if executable is not None:
execs = [executable]
else:
execs = ["matlab"] # always pick the matlab in the path, if available
if sys.platform == "darwin":
installed = glob.glob("/Applications/MATLAB_R*")
installed.sort(reverse=True)
execs.extend([os.path.join(p, "bin", "matlab") for p in installed])
# try starting each in order until one works
self.__proc = None
for exe in execs:
try:
self.__proc = Process([exe, "-nodesktop", "-nosplash"], **kwds)
break
except Exception as e:
last_exception = e
pass
# bail out if we couldn't start any
if self.__proc is None:
raise RuntimeError(
"Could not start MATLAB process.\nPaths attempted: %s "
"\nLast error: %s" % (str(execs), str(e))
)
# Wait a moment for MATLAB to start up,
# read the version string
while True:
line = self.__proc.stdout.readline()
if "Copyright" in line:
# next line is version info
self.__version_str = self.__proc.stdout.readline().strip()
break
# start input loop
self.__proc.stdin.write(self._bootstrap)
# wait for input loop to be ready
while True:
line = self.__proc.stdout.readline()
if line == "::ready\n":
break
atexit.register(self.close)
def __call__(self, cmd, parse_result=True):
"""
Execute the specified statement(s) on the MATLAB interpreter and return
the output string or raise an exception if there was an error.
"""
assert not self.__closed, "MATLAB process has already closed."
if cmd[-1] != "\n":
cmd += "\n"
cmd += "::cmd_done\n"
self.__proc.stdout.read()
self.__proc.stdin.write(cmd)
if parse_result:
return self._parse_result()
def _parse_result(self):
assert not self.__closed, "MATLAB process has already closed."
output = []
while True:
line = self.__proc.stdout.readline()
if line == "::ready\n":
break
output.append(line)
for i in reversed(range(len(output))):
line = output[i]
if line == "::ok\n":
return "".join(output[:i])
elif line == "::err\n":
raise MatlabError(output[i + 1 :], output[:i])
raise RuntimeError("No success/failure code found in output (printed above).")
def _get(self, name):
"""
Transfer an object from MATLAB to Python.
"""
assert isinstance(name, str)
tmp = tempfile.mktemp(suffix=".mat")
out = self("save('%s', '%s', '-v7')" % (tmp, name))
objs = scipy.io.loadmat(tmp)
os.remove(tmp)
return objs[name]
def _set(self, **kwds):
"""
Transfer an object from Python to MATLAB and assign it to the given
variable name.
"""
tmp = tempfile.mktemp(suffix=".mat")
scipy.io.savemat(tmp, kwds)
self("load('%s')" % tmp)
os.remove(tmp)
def _get_via_pipe(self, name):
"""
Transfer an object from MATLAB to Python.
This method sends data over the pipe, but is less reliable than get().
"""
assert isinstance(name, str)
out = self("save('stdio', '%s', '-v7')" % name)
start = stop = None
for i, line in enumerate(out):
if line.startswith("start_binary"):
start = i
elif line.startswith("::ok"):
stop = i
data = "".join(out[start + 1 : stop])
io = StringIO(data[:-1])
objs = scipy.io.loadmat(io)
return objs[name]
def _set_via_pipe(self, **kwds):
"""
Transfer an object from Python to MATLAB and assign it to the given
variable name.
This method sends data over the pipe, but is less reliable than set().
"""
assert not self.__closed, "MATLAB process has already closed."
io = StringIO()
scipy.io.savemat(io, kwds)
io.seek(0)
strn = io.read()
self.__proc.stdout.read()
self.__proc.stdin.write("load('stdio')\n::cmd_done\n")
while True:
line = self.__proc.stdout.readline()
if line == "ack load stdio\n":
# now it is safe to send data
break
self.__proc.stdin.write(strn)
self.__proc.stdin.write("\n")
while True:
line = self.__proc.stdout.readline()
if line == "ack load finished\n":
break
self._parse_result()
def exist(self, name):
if name == "exist":
return 5
else:
for line in self("exist %s" % name).split("\n"):
try:
return int(line.strip())
except ValueError:
pass
def __getattr__(self, name):
if name not in self.__refs:
ex = self.exist(name)
if ex == 0:
raise AttributeError("No object named '%s' in matlab workspace." % name)
elif ex in (2, 3, 5):
r = MatlabFunction(self, name)
elif ex == 1:
r = self._mkref(name)
else:
typ = {4: "library file", 6: "P-file", 7: "folder", 8: "class"}[ex]
raise TypeError("Variable '%s' has unsupported type '%s'" % (name, typ))
self.__refs[name] = r
return self.__refs[name]
def _mkref(self, name):
assert name not in self.__refs
ref = MatlabReference(self, name)
self.__refs[name] = ref
return ref
def __setattr__(self, name, value):
if name.startswith("_MatlabProcess__"):
object.__setattr__(self, name, value)
else:
self._set(**{name: value})
def close(self):
if self.__closed:
return
self("exit;\n", parse_result=False)
self.__closed = True
class MatlabReference(object):
""" Reference to a variable in the matlab workspace.
"""
def __init__(self, proc, name):
self._proc = proc
self._name = name
@property
def name(self):
return self._name
def get(self):
return self._proc._get(self._name)
def clear(self):
self._proc("clear %s;" % self._name)
def __del__(self):
self.clear()
class MatlabFunction(object):
"""
Proxy to a MATLAB function.
Calling this object transfers the arguments to MATLAB, invokes the function
remotely, and then transfers the return value back to Python.
"""
def __init__(self, proc, name):
self._proc = proc
self._name = name
self._nargout = None
@property
def nargout(self):
""" Number of output arguments for this function.
For some functions, requesting nargout() will fail. In these cases,
the nargout property must be set manually before calling the function.
"""
if self._nargout is None:
cmd = "fprintf('%%d\\n', nargout('%s'));" % (self._name)
ret = self._proc(cmd)
self._nargout = int(ret.strip())
return self._nargout
@nargout.setter
def nargout(self, n):
self._nargout = n
def __call__(self, *args, **kwds):
"""
Call this function with the given arguments.
If _transfer is False, then the return values are left in MATLAB and
references to these values are returned instead.
"""
import pyqtgraph as pg
_transfer = kwds.pop("_transfer", True)
assert len(kwds) == 0
# for generating unique variable names
rand = np.random.randint(1e12)
# store args to temporary variables, excluding those already present
# in the workspace
argnames = []
upload = {}
for i, arg in enumerate(args):
if isinstance(arg, MatlabReference):
argnames.append(arg.name)
elif np.isscalar(arg):
argnames.append(repr(arg))
else:
argname = "%s_%d_%d" % (self._name, i, rand)
argnames.append(argname)
upload[argname] = arg
if len(upload) > 0:
self._proc._set(**upload)
try:
# get number of output args
nargs = self.nargout
# invoke function, fetch return value(s)
retvars = ["%s_rval_%d_%d" % (self._name, i, rand) for i in range(nargs)]
cmd = "[%s] = %s(%s);" % (",".join(retvars), self._name, ",".join(argnames))
self._proc(cmd)
if _transfer:
ret = [self._proc._get(var) for var in retvars]
self._proc("clear %s;" % (" ".join(retvars)))
else:
ret = [self._proc._mkref(name) for name in retvars]
if len(ret) == 1:
ret = ret[0]
else:
ret = tuple(ret)
return ret
finally:
# clear all temp variables
clear = list(upload.keys())
# if _transfer:
# clear += retvars
if len(clear) > 0:
cmd = "clear %s;" % (" ".join(clear))
self._proc(cmd)
return ret
class MatlabError(Exception):
def __init__(self, error, output):
self.output = "".join(output)
for line in error:
self.stack = []
if line.startswith("::message:"):
self.message = line[10:].strip()
elif line.startswith("::identifier:"):
self.identifier = line[13:].strip()
elif line.startswith("::stack:"):
self.stack.append(line[8:].strip())
def __repr__(self):
return "MatlabError(message=%s, identifier=%s)" % (
repr(self.message),
repr(self.identifier),
)
def __str__(self):
if len(self.stack) > 0:
stack = "\nMATLAB Stack:\n%s\nMATLAB Error: " % "\n".join(self.stack)
return stack + self.message
else:
return self.message
if __name__ == "__main__":
p = MatlabProcess()
io = StringIO()
scipy.io.savemat(io, {"x": 1})
io.seek(0)
strn = io.read()

218
cnmodel/util/nrnutils.py Normal file
View File

@@ -0,0 +1,218 @@
from __future__ import print_function
"""
Wrapper classes to make working with NEURON easier.
Author: Andrew P. Davison, UNIC, CNRS
"""
__version__ = "0.3.0"
from neuron import nrn, h, hclass
h.load_file("stdrun.hoc")
PROXIMAL = 0
DISTAL = 1
class Mechanism(object):
"""
Examples:
>>> leak = Mechanism('pas', {'e': -65, 'g': 0.0002})
>>> hh = Mechanism('hh')
added set_parameters to allow post-instantiation parameter modification
"""
def __init__(self, name, **parameters):
self.name = name
self.parameters = parameters
def set_parameters(self, parameters):
self.parameters = parameters
def insert_into(self, section):
section.insert(self.name)
for name, value in self.parameters.items():
for segment in section:
mech = getattr(segment, self.name)
setattr(mech, name, value)
class Section(nrn.Section):
"""
Examples:
>>> soma = Section(L=30, diam=30, mechanisms=[hh, leak])
>>> apical = Section(L=600, diam=2, nseg=5, mechanisms=[leak],
... parent=soma, connection_point=DISTAL)
"""
def __init__(
self,
L,
diam,
nseg=1,
Ra=100,
cm=1,
mechanisms=[],
parent=None,
connection_point=DISTAL,
):
nrn.Section.__init__(self)
# set geometry
self.L = L
self.diam = diam
self.nseg = nseg
# set cable properties
self.Ra = Ra
self.cm = cm
# connect to parent section
if parent:
self.connect(parent, connection_point, PROXIMAL)
# add ion channels
for mechanism in mechanisms:
mechanism.insert_into(self)
def add_synapses(self, label, type, locations=[0.5], **parameters):
if hasattr(self, label):
raise Exception("Can't overwrite synapse labels (to keep things simple)")
synapse_group = []
for location in locations:
synapse = getattr(h, type)(location, sec=self)
for name, value in parameters.items():
setattr(synapse, name, value)
synapse_group.append(synapse)
if len(synapse_group) == 1:
synapse_group = synapse_group[0]
setattr(self, label, synapse_group)
add_synapse = add_synapses # for backwards compatibility
def plot(self, variable, location=0.5, tmin=0, tmax=5, xmin=-80, xmax=40):
import neuron.gui
self.graph = h.Graph()
h.graphList[0].append(self.graph)
self.graph.size(tmin, tmax, xmin, xmax)
self.graph.addvar("%s(%g)" % (variable, location), sec=self)
def record_spikes(self, threshold=-30):
self.spiketimes = h.Vector()
self.spikecount = h.APCount(0.5, sec=self)
self.spikecount.thresh = threshold
self.spikecount.record(self.spiketimes)
def alias(attribute_path):
"""
Returns a new property, mapping an attribute nested in an object hierarchy
to a simpler name
For example, suppose that an object of class A has an attribute b which
itself has an attribute c which itself has an attribute d. Then placing
e = alias('b.c.d')
in the class definition of A makes A.e an alias for A.b.c.d
"""
parts = attribute_path.split(".")
attr_name = parts[-1]
attr_path = parts[:-1]
def set(self, value):
obj = reduce(getattr, [self] + attr_path)
setattr(obj, attr_name, value)
def get(self):
obj = reduce(getattr, [self] + attr_path)
return getattr(obj, attr_name)
return property(fset=set, fget=get)
def uniform_property(section_list, attribute_path):
"""
Define a property that will have a uniform value across a list of sections.
For example, suppose we define a neuron model as a class A, which contains
three compartments: soma, dendrite and axon. Then placing
gnabar = uniform_property(["soma", "axon"], "hh.gnabar")
in the class definition of A means that setting a.gnabar (where a is an
instance of A) will set the value of hh.gnabar in both the soma and axon, i.e.
a.gnabar = 0.01
is equivalent to:
for sec in [a.soma, a.axon]:
for seg in sec:
seg.hh.gnabar = 0.01
"""
parts = attribute_path.split(".")
attr_name = parts[-1]
attr_path = parts[:-1]
def set(self, value):
for sec_name in section_list:
sec = getattr(self, sec_name)
for seg in sec:
obj = reduce(getattr, [seg] + attr_path)
setattr(obj, attr_name, value)
def get(self):
sec = getattr(self, section_list[0])
obj = reduce(getattr, [sec(0.5)] + attr_path)
return getattr(obj, attr_name)
return property(fset=set, fget=get)
if __name__ == "__main__":
class SimpleNeuron(object):
def __init__(self):
# define ion channel parameters
leak = Mechanism("pas", e=-65, g=0.0002)
hh = Mechanism("hh")
# create cable sections
self.soma = Section(L=30, diam=30, mechanisms=[hh])
self.apical = Section(
L=600,
diam=2,
nseg=5,
mechanisms=[leak],
parent=self.soma,
connection_point=DISTAL,
)
self.basilar = Section(
L=600,
diam=2,
nseg=5,
mechanisms=[leak],
parent=self.soma,
connection_point=0.5,
)
self.axon = Section(
L=1000, diam=1, nseg=37, mechanisms=[hh], connection_point=0
)
# synaptic input
self.soma.add_synapses("syn", "AlphaSynapse", onset=0.5, gmax=0.05, e=0)
gnabar = uniform_property(["soma", "axon"], "hh.gnabar")
gkbar = uniform_property(["soma", "axon"], "hh.gkbar")
neuron = SimpleNeuron()
neuron.soma.plot("v")
neuron.apical.plot("v")
print("gNa_bar: ", neuron.gnabar)
neuron.gnabar = 0.15
assert neuron.soma(0.5).hh.gnabar == 0.15
h.dt = 0.025
v_init = -65
tstop = 5
h.finitialize(v_init)
h.run()

111
cnmodel/util/process.py Normal file
View File

@@ -0,0 +1,111 @@
from __future__ import print_function
"""
Utility class for spawning and controlling CLI processes.
See: http://stackoverflow.com/questions/375427/non-blocking-read-on-a-subprocess-pipe-in-python
"""
import sys, time
from subprocess import PIPE, Popen
from threading import Thread
try:
from Queue import Queue, Empty
except ImportError:
from queue import Queue, Empty # python 3.x
ON_POSIX = "posix" in sys.builtin_module_names
class Process(object):
"""
Process encapsulates a subprocess with queued stderr/stdout pipes.
Wraps most methods from subprocess.Popen.
For non-blocking reads, use proc.stdout.get_nowait() or
proc.stdout.get(timeout=0.1).
"""
def __init__(self, exec_args, cwd=None):
self.proc = Popen(
exec_args,
stdout=PIPE,
stdin=PIPE,
stderr=PIPE,
bufsize=1,
close_fds=ON_POSIX,
universal_newlines=True,
cwd=cwd,
)
self.stdin = self.proc.stdin
# self.stdin = PipePrinter(self.proc.stdin)
self.stdout = PipeQueue(self.proc.stdout)
self.stderr = PipeQueue(self.proc.stderr)
for method in ["poll", "wait", "send_signal", "kill", "terminate"]:
setattr(self, method, getattr(self.proc, method))
class PipePrinter(object):
""" For debugging writes to a pipe.
"""
def __init__(self, pipe):
self._pipe = pipe
def __getattr__(self, attr):
return getattr(self._pipe, attr)
def write(self, strn):
print("WRITE:" + repr(strn))
return self._pipe.write(strn)
class PipeQueue(Queue):
"""
Queue that starts a second process to monitor a PIPE for new data.
This is needed to allow non-blocking pipe reads.
"""
def __init__(self, pipe):
Queue.__init__(self)
self.thread = Thread(target=self.enqueue_output, args=(pipe, self))
self.thread.daemon = True # thread dies with the program
self.thread.start()
@staticmethod
def enqueue_output(out, queue):
for line in iter(out.readline, b""):
queue.put(line)
# print "READ: " + repr(line)
out.close()
def read(self):
"""
Read all available lines from the queue, concatenated into a single string.
"""
out = ""
while True:
try:
out += self.get_nowait()
except Empty:
break
return out
def readline(self, timeout=None):
""" Read a single line from the queue.
"""
# we break this up into multiple short reads to allow keyboard
# interrupts
start = time.time()
ret = ""
while True:
if timeout is not None:
remaining = start + timeout - time.time()
if remaining <= 0:
return ""
else:
remaining = 1
try:
return self.get(timeout=min(0.1, remaining))
except Empty:
pass

734
cnmodel/util/pynrnutilities.py Executable file
View File

@@ -0,0 +1,734 @@
from __future__ import print_function
#!/usr/bin/python
#
# utilities for NEURON, in Python
# Module neuron for cnmodel
import numpy as np
import numpy.ma as ma # masked array
import re, sys, gc, collections
import neuron
_mechtype_cache = None
def all_mechanism_types():
"""Return a dictionary of all available mechanism types.
Each dictionary key is the name of a mechanism and each value is
another dictionary containing information about the mechanism::
mechanism_types = {
'mech_name1': {
'point_process': bool,
'artificial_cell': bool,
'netcon_target': bool,
'has_netevent': bool,
'internal_type': int,
'globals': {name:size, ...},
'parameters': {name:size, ...},
'assigned': {name:size, ...},
'state': {name:size, ...},
},
'mech_name2': {...},
'mech_name3': {...},
...
}
* point_process: False for distributed mechanisms, True for point
processes and artificial cells.
* artificial_cell: True for artificial cells, False otherwise
* netcon_target: True if the mechanism can receive NetCon events
* has_netevent: True if the mechanism can emit NetCon events
* internal_type: Integer specifying the NEURON internal type index of
the mechanism
* globals: dict of the name and vector size of the mechanism's global
variables
* parameters: dict of the name and vector size of the mechanism's
parameter variables
* assigned: dict of the name and vector size of the mechanism's
assigned variables
* state: dict of the name and vector size of the mechanism's state
variables
Note: The returned data structure is cached; do not modify it.
For more information on global, parameter, assigned, and state
variables see:
http://www.neuron.yale.edu/neuron/static/docs/help/neuron/nmodl/nmodl.html
"""
global _mechtype_cache
if _mechtype_cache is None:
_mechtype_cache = collections.OrderedDict()
mname = neuron.h.ref("")
# Iterate over two mechanism types (distributed, point/artificial)
for i in [0, 1]:
mt = neuron.h.MechanismType(i)
nmech = int(mt.count())
# Iterate over all mechanisms of this type
for j in range(nmech):
mt.select(j)
mt.selected(mname)
# General mechanism properties
name = mname[0] # convert hoc string ptr to python str
desc = {
"point_process": bool(i),
"netcon_target": bool(mt.is_netcon_target(j)),
"has_netevent": bool(mt.has_net_event(j)),
"artificial_cell": bool(mt.is_artificial(j)),
"internal_type": int(mt.internal_type()),
}
# Collect information about 4 different types of variables
for k, ptype in [
(-1, "globals"),
(1, "parameters"),
(2, "assigned"),
(3, "state"),
]:
desc[ptype] = {} # collections.OrderedDict()
ms = neuron.h.MechanismStandard(name, k)
for l in range(int(ms.count())):
psize = ms.name(mname, l)
pname = mname[0] # parameter name
desc[ptype][pname] = int(psize)
# Assemble everything in one place
_mechtype_cache[name] = desc
return _mechtype_cache
def reset(raiseError=True):
"""Introspect the NEURON kernel to verify that no objects are left over
from previous simulation runs.
"""
# Release objects held by an internal buffer
# See https://www.neuron.yale.edu/phpBB/viewtopic.php?f=2&t=3221
neuron.h.Vector().size()
# Make sure nothing is hanging around in an old exception or because of
# reference cycles
# sys.exc_clear()
gc.collect(2)
neuron.h.Vector().size()
numsec = 0
remaining = []
n = len(list(neuron.h.allsec()))
if n > 0:
remaining.append((n, "Section"))
n = len(neuron.h.List("NetCon"))
if n > 0:
remaining.append((n, "NetCon"))
# No point processes or artificial cells left
for name, typ in all_mechanism_types().items():
if typ["artificial_cell"] or typ["point_process"]:
n = len(neuron.h.List(name))
if n > 0:
remaining.append((n, name))
if (
len(remaining) > 0 and raiseError
): # note that not raising the error leads to memory leak
msg = "Cannot reset--old objects have not been cleared: %s" % ", ".join(
["%d %s" % rem for rem in remaining]
)
raise RuntimeError(msg)
def custom_init(v_init=-60.0):
"""
Perform a custom initialization of the current model/section.
This initialization follows the scheme outlined in the
NEURON book, 8.4.2, p 197 for initializing to steady state.
N.B.: For a complex model with dendrites/axons and different channels,
this initialization will not find the steady-state the whole cell,
leading to initial transient currents. In that case, this initialization
should be followed with a 0.1-5 second run (depends on the rates in the
various channel mechanisms) with no current injection or active point
processes to allow the system to settle to a steady- state. Use either
h.svstate or h.batch_save to save states and restore them. Batch is
preferred
Parameters
----------
v_init : float (default: -60 mV)
Voltage to start the initialization process. This should
be close to the expected resting state.
"""
inittime = -1e10
tdt = neuron.h.dt # save current step size
dtstep = 1e9
neuron.h.finitialize(v_init)
neuron.h.t = inittime # set time to large negative value (avoid activating
# point processes, we hope)
tmp = neuron.h.cvode.active() # check state of variable step integrator
if tmp != 0: # turn off CVode variable step integrator if it was active
neuron.h.cvode.active(0) # now just use backward Euler with large step
neuron.h.dt = dtstep
n = 0
while neuron.h.t < -1e9: # Step forward
neuron.h.fadvance()
n += 1
# print('advances: ', n)
if tmp != 0:
neuron.h.cvode.active(1) # restore integrator
neuron.h.t = 0
if neuron.h.cvode.active():
neuron.h.cvode.re_init() # update d(state)/dt and currents
else:
neuron.h.fcurrent() # recalculate currents
neuron.h.frecord_init() # save new state variables
neuron.h.dt = tdt # restore original time step
# routine to convert conductances from nS as given elsewhere
# to mho/cm2 as required by NEURON 1/28/99 P. Manis
# units: nano siemens, soma area in um^2
#
def nstomho(ns, somaarea, refarea=None):
if refarea == None:
return 1e-9 * float(ns) / float(somaarea)
else:
return 1e9 * float(ns) / float(refarea)
def mho2ns(mho, somaarea):
return float(mho) * somaarea / 1e-9
def spherearea(dia):
"""
given diameter in microns, return sphere area in cm2
"""
r = dia * 1e-4 # convert to cm
return 4 * np.pi * r ** 2
def get_sections(h):
"""
go through all the sections and find the names of the sections and all of their
parts (ids). Returns a dict, of sec: [id0, id1...]
"""
secnames = {}
resec = re.compile("(\w+)\[(\d*)\]")
for sec in h.allsec():
g = resec.match(sec.name())
if g.group(1) not in secnames.keys():
secnames[g.group(1)] = [int(g.group(2))]
else:
secnames[g.group(1)].append(int(g.group(2)))
return secnames
def all_objects():
""" Return a dict of all objects known to NEURON.
Keys are 'Section', 'Segment', 'Mechanism', 'Vector', 'PointProcess',
'NetCon', ...
"""
objs = {}
objs["Section"] = list(h.all_sec())
objs["Segment"] = []
for sec in objs["Section"]:
objs["Segment"].extend(list(sec.allseg()))
objs["PointProcess"] = []
for seg in objs["Segment"]:
objs["PointProcess"].extend(list(seg.point_processes()))
return objs
def alpha(alpha=0.1, delay=1, amp=1.0, tdur=50.0, dt=0.010):
tvec = np.arange(0, tdur, dt)
aw = np.zeros(tvec.shape)
i = 0
for t in tvec:
if t > delay:
aw[i] = (
amp * (t - delay) * (1.0 / alpha) * np.exp(-(t - delay) / alpha)
) # alpha waveform time course
else:
aw[i] = 0.0
i += 1
return (aw, tvec)
def syns(
alpha=0.1,
rate=10,
delay=0,
dur=50,
amp=1.0,
dt=0.020,
N=1,
mindur=120,
makewave=True,
):
""" Calculate a poisson train of alpha waves
with mean rate rate, with a delay and duration (in mseco) dt in msec.
N specifies the number of such independent waveforms to sum """
deadtime = 0.7
if dur + delay < mindur:
tvec = np.arange(0.0, mindur, dt)
else:
tvec = np.arange(0.0, dur + delay, dt)
npts = len(tvec)
ta = np.arange(0.0, 20.0, dt)
aw = ta * alpha * np.exp(-ta / alpha) / alpha # alpha waveform time course
spt = [[]] * N # list of spike times
wave = np.array([]) # waveform
sptime = []
for j in range(0, N):
done = False
t = 0.0
nsp = 0
while not done:
a = np.random.sample(1)
if t < delay:
t = delay
continue
if t >= delay and t <= (delay + dur):
ti = -np.log(a) / (
rate / 1000.0
) # convert to exponential distribution with rate
if ti < deadtime:
continue
t = t + ti # running time
if t > delay + dur:
done = True
continue
if nsp is 0:
sptime = t
nsp = nsp + 1
else:
sptime = np.append(sptime, t)
nsp = nsp + 1
if j is 0:
wavej = np.zeros(len(tvec))
for i in range(0, len(sptime)):
st = int(sptime[i] / dt)
wavej[st] = wavej[st] + 1
spt[j] = sptime
if makewave:
w = np.convolve(wavej, aw / max(aw)) * amp
if len(w) < npts:
w = np.append(w, np.zeros(npts - len(w)))
if len(w) > npts:
w = w[0:npts]
if j is 0:
wave = w
else:
wave = wave + w
return (spt, wave, tvec, N)
def an_syn(
alpha=0.1,
spont=10,
driven=100,
delay=50,
dur=100,
post=20,
amp=0.1,
dt=0.020,
N=1,
makewave=True,
):
# constants for AN:
deadtime = 0.7 # min time between spikes, msec
trise = 0.2 # rise rate, ms
tfall = 0.5 # fall rate, ms
rss = driven / 1000.0 # spikes/millisecond
rr = 3 * rss # transient driven rate
rst = rss
taur = 3 # rapid decay, msec
taust = 10
ton = delay # msec
stim_end = ton + dur
trace_end = stim_end + post
tvec = np.arange(0.0, trace_end, dt) # dt is in msec, so tvec is in milliseconds
ta = np.arange(0.0, 20.0, dt)
aw = ta * alpha * np.exp(-ta / alpha) / alpha # alpha waveform time course
spt = [[]] * N # list of spike times
wave = [[]] * N # waveform
for j in range(0, N): # for each
done = False
sptime = []
qe = 0
nsp = 0
i = int(0)
if spont <= 0:
q = 1e6
else:
q = 1000.0 / spont # q is in msec (spont in spikes/second)
t = 0.0
while not done:
a = np.random.sample(1)
if t < ton:
if spont <= 0:
t = ton
continue
ti = -(np.log(a) / (spont / 1000.0)) # convert to msec
if ti < deadtime: # delete intervals less than deadtime
continue
t = t + ti
if t > ton: # if the interval would step us to the stimulus onset
t = ton # then set to to the stimulus onset
continue
if t >= ton and t < stim_end:
if t > ton:
rise = 1.0 - np.exp(-(t - ton) / trise)
else:
rise = 1.0
ra = rr * np.exp(-(t - ton) / taur)
rs = rst * np.exp(-(t - ton) / taust)
q = rise * (ra + rs + rss)
ti = -np.log(a) / (q + spont / 1000) # random.negexp(1000/q)
if ti < deadtime:
continue
t = t + ti
if t > stim_end: # only include interval if it falls inside window
t = stim_end
continue
if t >= stim_end and t <= trace_end:
if spont <= 0.0:
t = trace_end
continue
if qe is 0: # have not calculated the new qe at end of stimulus
rise = 1.0 - np.exp(-(stim_end - ton) / trise)
ra = rr * np.exp(-(stim_end - ton) / taur)
rs = rst * np.exp(-(stim_end - ton) / taust)
qe = rise * (ra + rs + rss) # calculate the rate at the end
fall = np.exp(-(t - stim_end) / tfall)
q = qe * fall
ti = -np.log(a) / (
q + spont / 1000.0
) # keeps rate from falling below spont rate
if ti < deadtime:
continue
t = t + ti
if t >= trace_end:
done = True
continue
# now add the spike time to the list
if nsp is 0:
sptime = t
nsp = nsp + 1
else:
sptime = np.append(sptime, t)
nsp = nsp + 1
# end of for loop on i
if j is 0:
wavej = np.zeros(len(tvec))
for i in range(0, len(sptime)):
st = int(sptime[i] / dt)
wavej[st] = wavej[st] + 1
spt[j] = sptime
npts = len(tvec)
if makewave:
w = np.convolve(wavej, aw / max(aw)) * amp
wave[j] = w[0:npts]
return (spt, wave, tvec, N)
def findspikes(t, v, thresh):
""" findspikes identifies the times of action potential in the trace v, with the
times in t. An action potential is simply timed at the first point that exceeds
the threshold.
"""
tm = np.array(t)
s0 = (
np.array(v) > thresh
) # np.where(v > thresh) # np.array(v) > thresh # find points above threshold
# print ('v: ', v)
dsp = tm[s0]
if dsp.shape[0] == 1:
dsp = np.array(dsp)
sd = np.append(True, np.diff(dsp) > 1.0) # find first points of spikes
if len(dsp) > 0:
sp = dsp[sd]
else:
sp = []
return sp # list of spike times.
def measure(mode, x, y, x0, x1):
""" return the mean and standard deviation of y in the window x0 to x1
"""
xm = ma.masked_outside(x, x0, x1)
ym = ma.array(y, mask=ma.getmask(xm))
if mode == "mean":
r1 = ma.mean(ym)
r2 = ma.std(ym)
if mode == "max":
r1 = ma.max(ym)
r2 = 0
if mode == "min":
r1 = ma.min(ym)
r2 = 0
if mode == "median":
r1 = ma.median(ym)
r2 = 0
if mode == "p2p": # peak to peak
r1 = ma.ptp(ym)
r2 = 0
return (r1, r2)
def mask(x, xm, x0, x1):
xmask = ma.masked_outside(xm, x0, x1)
xnew = ma.array(x, mask=ma.getmask(xmask))
return xnew.compressed()
def vector_strength(spikes, freq):
"""
Calculate vector strength and related parameters from a spike train, for the specified frequency
:param spikes: Spike train, in msec.
:param freq: Stimulus frequency in Hz
:return: a dictionary containing:
r: vector strength
n: number of spikes
R: Rayleigh coefficient
p: p value (is distribution not flat?)
ph: the circularized spike train over period of the stimulus freq, freq, in radians
d: the "dispersion" computed according to Ashida et al., 2010, etc.
"""
per = 1e3 / freq # convert from Hz to period in msec
ph = 2 * np.pi * np.fmod(spikes, per) / (per) # convert to radians within a cycle
c = np.sum(np.cos(ph)) ** 2
s = np.sum(np.sin(ph)) ** 2
vs = (1.0 / len(ph)) * np.sqrt(c + s) # standard vector strength computation
n = len(spikes)
R = n * vs # Raleigh coefficient
Rp = np.exp(-n * vs * vs) # p value for n > 50 (see Ashida et al. 2010).
d = np.sqrt(2.0 * (1 - vs)) / (2 * np.pi * freq)
return {"r": vs, "n": n, "R": R, "p": Rp, "ph": ph, "d": d}
def isi_cv2(splist, binwidth=1, t0=0, t1=300, tgrace=25):
""" compute the cv and regularity according to Young et al., J. Neurophys, 60: 1, 1988.
Analysis is limited to isi's starting at or after t0 but before t1, and ending completely
before t1 + tgrace(to avoid end effects). t1 should correspond to the
the end of the stimulus
VERSION using dictionary for cvisi
"""
cvisit = np.arange(0, t1, binwidth) # build time bins
cvisi = {} # isi is dictionary, since each bin may have different length
for i in range(0, len(splist)): # for all the traces
isit = splist[i] # get the spike times for this trial [1:-1]
if len(isit) <= 1: # need 2 spikes to get an interval
continue
isib = np.floor(isit[0:-2] / binwidth) # discreetize
isii = np.diff(splist[i]) # isis.
for j in range(0, len(isib)): # loop over possible start time bins
if (
isit[j] < t0 or isit[j] > t1 or isit[j + 1] > t1 + tgrace
): # start time and interval in the window
continue
if isib[j] in cvisi:
print("spike in bin: %d" % (isib[j]))
cvisi[isib[j]] = np.append(
cvisi[isib[j]], isii[j]
) # and add the isi in that bin
else:
cvisi[isib[j]] = isii[j] # create it
cvm = np.array([]) # set up numpy arrays for mean, std and time for cv analysis
cvs = np.array([])
cvt = np.array([])
for i in cvisi.keys(): # for each entry (possible bin)
c = [cvisi[i]]
s = c.shape
# print c
if len(s) > 1 and s[1] >= 3: # require 3 spikes in a bin for statistics
cvm = np.append(cvm, np.mean(c))
cvs = np.append(cvs, np.std(c))
cvt = np.append(cvt, i * binwidth)
return (cvisit, cvisi, cvt, cvm, cvs)
def isi_cv(splist, binwidth=1, t0=0, t1=300, tgrace=25):
""" compute the cv and regularity according to Young et al., J. Neurophys, 60: 1, 1988.
Analysis is limited to isi's starting at or after t0 but before t1, and ending completely
before t1 + tgrace(to avoid end effects). t1 should correspond to the
the end of the stimulus
Version using a list of numpy arrays for cvisi
"""
cvisit = np.arange(0, t1, binwidth) # build time bins
cvisi = [[]] * len(cvisit)
for i in range(0, len(splist)): # for all the traces
if len(splist[i]) < 2: # need at least 2 spikes
continue
isib = np.floor(
splist[i][0:-2] / binwidth
) # begining spike times for each interval
isii = np.diff(splist[i]) # associated intervals
for j in range(0, len(isib)): # loop over spikes
if (
splist[i][j] < t0 or splist[i][j] > t1 or splist[i][j + 1] > t1 + tgrace
): # start time and interval in the window
continue
cvisi[int(isib[j])] = np.append(
cvisi[int(isib[j])], isii[j]
) # and add the isi in that bin
cvm = np.array([]) # set up numpy arrays for mean, std and time for cv analysis
cvs = np.array([])
cvt = np.array([])
for i in range(0, len(cvisi)): # for each entry (possible bin)
c = cvisi[i]
if len(c) >= 3: # require 3 spikes in a bin for statistics
cvm = np.append(cvm, np.mean(c))
cvs = np.append(cvs, np.std(c))
cvt = np.append(cvt, i * binwidth)
return (cvisit, cvisi, cvt, cvm, cvs)
if __name__ == "__main__":
test = "isicv"
if test == "isicv":
""" this test is not perfect. Given an ISI, we calculate spike times
by drawing from a normal distribution whose standard deviation varies
with time, from 0 (regular) to 1 (irregular). As coded, the standard
deviation never reaches the target value because spikes fall before or
after previous spikes (thus reducing the stdev). Nonetheless, this shows
that the CV calculation works correctly. """
nreps = 500
# cv will be 0 for first 50 msec, 0.5 for next 50 msec, and 1 for next 50 msec
d = [[]] * nreps
isi = 5.0 # mean isi
# we create 100 msec of data where the CV goes from 0 to 1
maxt = 100.0
for i in range(nreps):
for j in range(int(maxt / isi) + 1):
t = float(j) * isi
sd = float(j) / isi
if sd == 0.0:
d[i] = np.append(d[i], t)
else:
d[i] = np.append(d[i], np.random.normal(t, sd, 1))
for j in range(1, 10): # add more intervals at the end
te = t + float(j) * isi
d[i] = np.append(d[i], np.random.normal(te, sd, 1))
d[i] = np.sort(d[i])
# print d[i]
# print diff(d[i])
sh = np.array([])
for i in range(len(d)):
sh = np.append(sh, np.array(d[i]))
(hist, bins) = np.histogram(sh, bins=250, range=(0, 250), new=True)
if len(bins) > len(hist):
bins = bins[0 : len(hist)]
pl.figure(1)
pl.subplot(2, 2, 1)
pl.plot(bins, hist)
(cvisit, cvisi, cvt, cvm, cvs) = isi_cv(
d, binwidth=0.5, t0=0, t1=100, tgrace=25
)
order = np.argsort(cvt)
cvt = cvt[order]
cvs = cvs[order]
cvm = cvm[order]
pl.subplot(2, 2, 2)
pl.plot(cvt, cvm)
pl.hold(True)
pl.plot(cvt, cvs)
pl.subplot(2, 2, 4)
pl.plot(cvt, cvs / cvm)
pl.show()
if test == "measure":
x = np.arange(0, 100, 0.1)
s = np.shape(x)
y = np.random.randn(s[0])
for i in range(0, 4):
print("\ni is : %d" % (i))
x0 = i * 20
x1 = x0 + 20
(r0, r1) = measure("mean", x, y, x0, x1)
print("mean: %f std: %f [0, 20]" % (r0, r1))
(r0, r1) = measure("max", x, y, x0, x1)
print("max: %f std: %f [0, 20]" % (r0, r1))
(r0, r1) = measure("min", x, y, x0, x1)
print("min: %f std: %f [0, 20]" % (r0, r1))
(r0, r1) = measure("median", x, y, x0, x1)
print("median: %f std: %f [0, 20]" % (r0, r1))
(r0, r1) = measure("p2p", x, y, x0, x1)
print("peak to peak: %f std: %f [0, 20]" % (r0, r1))
if test == "an_syn":
(s, w, t, n) = an_syn(N=50, spont=50, driven=150, post=100, makewave=True)
sh = np.array([])
for i in range(len(s)):
sh = np.append(sh, np.array(s[i]))
(hist, bins) = np.histogram(sh, bins=250, range=(0, 250), new=True)
if len(bins) > len(hist):
bins = bins[0 : len(hist)]
import pylab as pl
pl.figure(1)
pl.subplot(2, 2, 1)
pl.plot(bins, hist)
pl.subplot(2, 2, 3)
for i in range(len(w)):
pl.plot(t, w[i])
pl.hold = True
(cvisit, cvisi, cvt, cvm, cvs) = isi_cv(s)
order = np.argsort(cvt)
cvt = cvt[order]
cvs = cvs[order]
cvm = cvm[order]
pl.subplot(2, 2, 2)
pl.plot(cvt, cvs / cvm)
pl.show()
if test == "syns":
(s, w, t, n) = syns(rate=20, delay=0, dur=100.0, N=5, makewave=True)
sh = np.array([])
for i in range(len(s)):
sh = np.append(sh, np.array(s[i]))
(hist, bins) = np.histogram(sh, bins=250, range=(0, 250), new=True)
if len(bins) > len(hist):
bins = bins[0 : len(hist)]
import pylab as pl
pl.figure(1)
pl.subplot(2, 2, 1)
pl.plot(bins, hist)
pl.subplot(2, 2, 3)
pl.plot(t, w)
pl.hold = True
(cvisit, cvisi, cvt, cvm, cvs) = isi_cv(s)
order = np.argsort(cvt)
cvt = cvt[order]
cvs = cvs[order]
cvm = cvm[order]
pl.subplot(2, 2, 2)
pl.plot(cvt, cvs / cvm)
pl.show()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,28 @@
import numpy as np
import hashlib, struct
_current_seed = 0
def set_seed(seed):
"""
Set the random seed to be used globally. If a string is supplied, it
will be converted to int using hash().
This immediately seeds the numpy RNG. Any other RNGs must be seeded using
current_seed()
"""
if isinstance(seed, str):
seed = struct.unpack("=I", hashlib.md5(seed.encode("utf-8")).digest()[:4])[0]
np.random.seed(seed)
assert seed < 2 ** 64 # neuron RNG fails if seed is too large
global _current_seed
_current_seed = seed
return seed
def current_seed():
"""
Return the currently-set global random seed.
"""
return _current_seed

1415
cnmodel/util/sound.py Normal file

File diff suppressed because it is too large Load Diff

87
cnmodel/util/stim.py Normal file
View File

@@ -0,0 +1,87 @@
import numpy as np
def make_pulse(stim):
"""
Generate a pulse train for current / voltage command. Returns a tuple.
Parameters
----------
stim : dict
Holds parameters that determine stimulus shape:
* delay : time before first pulse
* Sfreq : frequency of pulses
* dur : duration of one pulse or main pulse
* predur : duration of prepulse (default should be 0 for no prepulse)
* amp : pulse amplitude
* preamp : amplitude of prepulse
* PT : delay between end of train and test pulse (0 for no test)
* NP : number of pulses
* hold : holding level (optional)
* dt : timestep
Returns
-------
w : stimulus waveform
maxt : duration of waveform
tstims : index of each pulse in the train
"""
defaults = {
"delay": 10,
"Sfreq": 50,
"dur": 100,
"predur": 0.0,
"post": 50.0,
"amp": None,
"preamp": 0.0,
"PT": 0,
"NP": 1,
"hold": 0.0,
"dt": None,
}
for k in stim:
if k not in defaults:
raise Exception("Stim parameter '%s' not accepted." % k)
defaults.update(stim)
stim = defaults
for k, v in stim.items():
if v is None:
raise Exception("Must specify stim parameter '%s'." % k)
dt = stim["dt"]
delay = int(np.floor(stim["delay"] / dt))
ipi = int(np.floor((1000.0 / stim["Sfreq"]) / dt))
pdur = int(np.floor(stim["dur"] / dt))
posttest = int(np.floor(stim["PT"] / dt))
ndur = 5
if stim["predur"] > 0.0:
predur = int(np.floor(stim["predur"] / dt))
else:
predur = 0.0
if stim["PT"] == 0:
ndur = 1
maxt = dt * (delay + predur + (ipi * (stim["NP"] + 3)) + posttest + pdur * ndur)
hold = stim.get("hold", None)
w = np.zeros(int(np.floor(maxt / dt)))
if hold is not None:
w += hold
# make pulse
tstims = [0] * int(stim["NP"])
for j in range(0, int(stim["NP"])):
prestart = delay
start = int(prestart + predur + j * ipi)
if predur > 0.0:
w[prestart : prestart + predur] = stim["preamp"]
w[start : start + pdur] = stim["amp"]
tstims[j] = start
if stim["PT"] > 0.0:
for i in range(start + posttest, start + posttest + pdur):
w[i] = stim["amp"]
w = np.append(w, 0.0)
maxt = maxt + dt
return (w, maxt, tstims)

View File

@@ -0,0 +1,208 @@
import math
import numpy as np
# import matplotlib
# import matplotlib.pyplot as plt
# import matplotlib.ticker as tckr
# import matplotlib.transforms as mtransforms
# import matplotlib.mlab as mlab
# An alpha version of the Talbot, Lin, Hanrahan tick mark generator for matplotlib.
# Described in "An Extension of Wilkinson's Algorithm for Positioning Tick Labels on Axes"
# by Justin Talbot, Sharon Lin, and Pat Hanrahan, InfoVis 2010.
# Implementation by Justin Talbot
# This implementation is in the public domain.
# Report bugs to jtalbot@stanford.edu
# A shortcoming:
# The weights used in the paper were designed for static plots where the extent of
# the tick marks unioned with the extent of the data defines the extent of the plot.
# In a plot where the extent of the plot is defined by the user (e.g. an interactive
# plot supporting panning and zooming), the weights don't work as well. In particular,
# you would want to retune them assuming that the tick labels must be inside
# the provided view range. You probably want higher weighting on simplicity and lower
# on coverage and possibly density. But I haven't experimented in any detail with this.
#
# If you do intend on using this for static plots in matplotlib, you should set
# only_inside to False in the call to Extended.extended. And then you should
# manually set your view extent to include the min and max ticks if they are outside
# the data range. This should produce the same results as the paper.
# class Extended(tckr.Locator):
class Extended:
# density is labels per inch
def __init__(self, density=1, steps=None, figure=None, range=(0, 1), axis="x"):
"""
Keyword args:
"""
self._density = density
self._figure = figure
self._axis = axis
self.range = range
if steps is None:
self._steps = [1, 5, 2, 2.5, 4, 3]
else:
self._steps = steps
def coverage(self, dmin, dmax, lmin, lmax):
range = dmax - dmin
return 1 - 0.5 * (
math.pow(dmax - lmax, 2) + math.pow(dmin - lmin, 2)
) / math.pow(0.1 * range, 2)
def coverage_max(self, dmin, dmax, span):
range = dmax - dmin
if span > range:
half = (span - range) / 2.0
return 1 - math.pow(half, 2) / math.pow(0.1 * range, 2)
else:
return 1
def density(self, k, m, dmin, dmax, lmin, lmax):
r = (k - 1.0) / (lmax - lmin)
rt = (m - 1.0) / (max(lmax, dmax) - min(lmin, dmin))
return 2 - max(r / rt, rt / r)
def density_max(self, k, m):
if k >= m:
return 2 - (k - 1.0) / (m - 1.0)
else:
return 1
def simplicity(self, q, Q, j, lmin, lmax, lstep):
eps = 1e-10
n = len(Q)
i = Q.index(q) + 1
v = (
1
if (
(lmin % lstep < eps or (lstep - lmin % lstep) < eps)
and lmin <= 0
and lmax >= 0
)
else 0
)
return (n - i) / (n - 1.0) + v - j
def simplicity_max(self, q, Q, j):
n = len(Q)
i = Q.index(q) + 1
v = 1
return (n - i) / (n - 1.0) + v - j
def legibility(self, lmin, lmax, lstep):
return 1
def legibility_max(self, lmin, lmax, lstep):
return 1
def extended(
self,
dmin,
dmax,
m,
Q=[1, 5, 2, 2.5, 4, 3],
only_inside=False,
w=[0.25, 0.2, 0.5, 0.05],
):
n = len(Q)
best_score = -2.0
j = 1.0
while j < float("infinity"):
for q in Q:
sm = self.simplicity_max(q, Q, j)
if w[0] * sm + w[1] + w[2] + w[3] < best_score:
j = float("infinity")
break
k = 2.0
while k < float("infinity"):
dm = self.density_max(k, m)
if w[0] * sm + w[1] + w[2] * dm + w[3] < best_score:
break
delta = (dmax - dmin) / (k + 1.0) / j / q
z = math.ceil(math.log(delta, 10))
while z < float("infinity"):
step = j * q * math.pow(10, z)
cm = self.coverage_max(dmin, dmax, step * (k - 1.0))
if w[0] * sm + w[1] * cm + w[2] * dm + w[3] < best_score:
break
min_start = math.floor(dmax / step) * j - (k - 1.0) * j
max_start = math.ceil(dmin / step) * j
if min_start > max_start:
z = z + 1
break
for start in range(int(min_start), int(max_start) + 1):
lmin = start * (step / j)
lmax = lmin + step * (k - 1.0)
lstep = step
s = self.simplicity(q, Q, j, lmin, lmax, lstep)
c = self.coverage(dmin, dmax, lmin, lmax)
d = self.density(k, m, dmin, dmax, lmin, lmax)
l = self.legibility(lmin, lmax, lstep)
score = w[0] * s + w[1] * c + w[2] * d + w[3] * l
if score > best_score and (
not only_inside or (lmin >= dmin and lmax <= dmax)
):
best_score = score
best = (lmin, lmax, lstep, q, k)
z = z + 1
k = k + 1
j = j + 1
return best
def __call__(self):
vmin, vmax = self.range # self.axis.get_view_interval()
fsize = {"x": 5.0, "y": 4.0}
size = fsize[self._axis] # self._figure.get_size_inches()[self._axis]
# density * size gives target number of intervals,
# density * size + 1 gives target number of tick marks,
# the density function converts this back to a density in data units (not inches)
# should probably make this cleaner.
best = self.extended(
vmin,
vmax,
self._density * size + 1.0,
only_inside=True,
w=[0.25, 0.2, 0.5, 0.05],
)
locs = np.arange(best[4]) * best[2] + best[0]
return locs
if __name__ == "__main__":
pass
# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.plot(10*np.random.randn(100), 10*np.random.randn(100), 'o')
#
# xmin, xmax = ax.xaxis.get_data_interval()
# xrange = xmax-xmin
# xmin, xmax = (xmin - xrange * 0.05, xmax + xrange * 0.05)
#
# ymin, ymax = ax.yaxis.get_data_interval()
# yrange = ymax-ymin
# ymin, ymax = (ymin - yrange * 0.05, ymax + yrange * 0.05)
#
# ax.xaxis.set_view_interval(xmin, xmax, ignore=True)
# ax.yaxis.set_view_interval(ymin, ymax, ignore=True)
# ax.xaxis.set_major_locator(Extended(density=0.5, figure=fig, which=0))
# ax.yaxis.set_major_locator(Extended(density=0.5, figure=fig, which=1))
#
# ax.set_title('Talbot, Lin, Hanrahan 2010')
#
# plt.show()

View File

@@ -0,0 +1,36 @@
from __future__ import print_function
from cnmodel.util import ExpFitting
import numpy as np
def test_fit1():
fit = ExpFitting(nexp=1)
x = np.linspace(0.0, 50, 500)
p = [-50.0, 4.0, 5.0]
y = p[0] + p[1] * np.exp(-x / p[2])
res = fit.fit(x, y, fit.fitpars)
pr = [float(res[k].value) for k in res.keys()]
print("\noriginal: ", p)
print("fit res: ", pr)
for i, v in enumerate(p):
assert np.allclose(v, pr[i])
def test_fit2():
fit = ExpFitting(nexp=2)
x = np.linspace(0.0, 50, 500)
p = [
-50.0,
4.0,
5.0,
1.0,
4.5,
] # last term is ratio of the two time constants (t2 = delta*t1)
y = p[0] + p[1] * np.exp(-x / p[2]) + p[3] * np.exp(-x / (p[2] * p[4]))
res = fit.fit(x, y, fit.fitpars)
pr = [float(res[k].value) for k in res.keys()]
print("\noriginal: ", p)
print("fit res: ", pr)
# we can only do this approximately for 2 exp fits
for i, v in enumerate(p): # test each one individually
assert np.allclose(pr[i] / v, 1.0, atol=1e-4, rtol=1e-2)

View File

@@ -0,0 +1,40 @@
import pytest
import numpy as np
# from cnmodel.util.matlab_proc import MatlabProcess
import matlab.engine
def test_matlab():
global proc
try:
# proc = MatlabProcess()
proc = matlab.engine.start_matlab()
except RuntimeError:
# no matlab available; skip this test
pytest.skip("MATLAB unavailable")
base_vcount = proc.who(nargout=1) # .shape[0]
assert len(base_vcount) == 0
e4 = np.array(proc.eye(4, nargout=1))
assert isinstance(e4, np.ndarray)
assert np.all(e4 == np.eye(4))
# assert proc.who().shape[0] == base_vcount
o6_ref = np.array(proc.ones(6, nargout=1)) # _transfer=False))
# o6 = o6_ref.get()
o6 = np.array(proc.ones(6, nargout=1))
assert np.all(o6 == np.ones(6))
# print(proc.who(nargout=1))
# assert proc.who(nargout=1) == base_vcount + 1
del o6_ref
# assert proc.who().shape[0] == base_vcount
proc.close()
if __name__ == "__main__":
test_matlab()

View File

@@ -0,0 +1,98 @@
import numpy as np
from cnmodel.util import sound
def test_conversions():
pa = np.array([3990.5, 20, 0.3639, 2e-5])
db = np.array([166, 120, 85.2, 0])
assert np.allclose(sound.pa_to_dbspl(pa), db, atol=0.1, rtol=0.002)
assert np.allclose(sound.dbspl_to_pa(db), pa, atol=0.1, rtol=0.002)
def test_tonepip():
rate = 100000
dur = 0.1
ps = 0.01
rd = 0.02
pd = 0.08
db = 60
s1 = sound.TonePip(
rate=rate,
duration=dur,
f0=5321,
dbspl=db,
pip_duration=pd,
pip_start=[ps],
ramp_duration=rd,
)
# test array sizes
assert s1.sound.size == s1.time.size == int(dur * rate) + 1
# test for consistency
assert np.allclose(
[s1.sound.min(), s1.sound.mean(), s1.sound.max()],
[-0.028284253158247834, -1.0954891976168953e-10, 0.028284270354167296],
)
# test that we got the requested amplitude
assert np.allclose(s1.measure_dbspl(ps + rd, ps + pd - rd), db, atol=0.1, rtol=0.01)
# test for quiet before and after pip
assert np.all(s1.sound[: int(ps * rate) - 1] == 0)
assert np.all(s1.sound[int((ps + pd) * rate) + 1 :] == 0)
# test the sound can be recreated from its key
key = s1.key()
s2 = sound.create(**key)
assert np.all(s1.time == s2.time)
assert np.all(s1.sound == s2.sound)
def test_noisepip():
rate = 100000
dur = 0.1
ps = 0.01
rd = 0.02
pd = 0.08
db = 60
s1 = sound.NoisePip(
rate=rate,
duration=dur,
seed=184724,
dbspl=db,
pip_duration=pd,
pip_start=[ps],
ramp_duration=rd,
)
# test array sizes
assert s1.sound.size == s1.time.size == int(dur * rate) + 1
# test for consistency
assert np.allclose(
[s1.sound.min(), s1.sound.mean(), s1.sound.max()],
[-0.082260796003197786, -0.00018484322982972046, 0.069160217220832404],
)
# test that we got the requested amplitude
assert np.allclose(s1.measure_dbspl(ps + rd, ps + pd - rd), db, atol=0.1, rtol=0.01)
# test for quiet before and after pip
assert np.all(s1.sound[: int(ps * rate) - 1] == 0)
assert np.all(s1.sound[int((ps + pd) * rate) + 1 :] == 0)
# test the sound can be recreated from its key
key = s1.key()
s2 = sound.create(**key)
# also test new seed works, and does not affect other sounds
key["seed"] += 1
s3 = sound.create(**key)
s3.sound # generate here to advance rng before generating s2
assert np.all(s1.time == s2.time)
assert np.all(s1.sound == s2.sound)
start = int(ps * rate) + 1
end = int((ps + pd) * rate) - 1
assert not np.any(s1.sound[start:end] == s3.sound[start:end])

View File

@@ -0,0 +1,29 @@
import numpy as np
from numpy.testing import assert_raises
from cnmodel.util import stim
from neuron import h
h.dt = 0.025
def test_make_pulse():
params = dict(delay=10, Sfreq=50, dur=1, amp=15, PT=0, NP=5)
assert_raises(Exception, lambda: stim.make_pulse(params))
params["dt"] = 0.025
w, maxt, times = stim.make_pulse(params)
assert w.min() == 0.0
assert w.max() == 15
assert w.dtype == np.float64
triggers = np.argwhere(np.diff(w) > 0)[:, 0] + 1
assert np.all(triggers == times)
assert w.sum() == 15 * len(times) * int(1 / h.dt)
params["PT"] = 100
w, maxt, times = stim.make_pulse(params)
triggers = np.argwhere(np.diff(w) > 0)[:, 0] + 1
assert triggers[-1] - triggers[-2] == 100 / h.dt

205
cnmodel/util/user_tester.py Normal file
View File

@@ -0,0 +1,205 @@
from __future__ import print_function
import os, sys, pickle, pprint
import numpy as np
import pyqtgraph as pg
from .. import AUDIT_TESTS
class UserTester(object):
"""
Base class for testing when a human is required to verify the results.
When a test is passed by the user, its output is saved and used as a basis
for future tests. If future test results do not match the stored results,
then the user is asked to decide whether to fail the test, or pass the
test and store new results.
Subclasses must reimplement run_test() to return a dictionary of results
to store. Optionally, compare_results and audit_result may also be
reimplemented to customize the testing behavior.
By default, test results are stored in a 'test_data' directory relative
to the file that defines the UserTester subclass in use.
"""
data_dir = "test_data"
def __init__(self, key, *args, **kwds):
"""Initialize with a string *key* that provides a short, unique
description of this test. All other arguments are passed to run_test().
*key* is used to determine the file name for storing test results.
"""
self.audit = AUDIT_TESTS
self.key = key
self.rtol = 1e-3
self.args = args
self.assert_test_info(*args, **kwds)
def run_test(self, *args, **kwds):
"""
Exceute the test. All arguments are taken from __init__.
Return a picklable dictionary of test results.
"""
raise NotImplementedError()
def compare_results(self, key, info, expect):
"""
Compare *result* of the current test against the previously stored
result *expect*. If *expect* is None, then no previous result was
stored.
If *result* and *expect* do not match, then raise an exception.
"""
# Check test structures are the same
assert type(info) is type(expect)
if hasattr(info, "__len__"):
assert len(info) == len(expect)
if isinstance(info, dict):
for k in info:
assert k in expect
for k in expect:
assert k in info
self.compare_results(k, info[k], expect[k])
elif isinstance(info, list):
for i in range(len(info)):
self.compare_results(key, info[i], expect[i])
elif isinstance(info, np.ndarray):
assert info.shape == expect.shape
if len(info) == 0:
return
# assert info.dtype == expect.dtype
if info.dtype.fields is None:
intnan = -9223372036854775808 # happens when np.nan is cast to int
inans = np.isnan(info) | (info == intnan)
enans = np.isnan(expect) | (expect == intnan)
assert np.all(inans == enans)
mask = ~inans
if not np.allclose(info[mask], expect[mask], rtol=self.rtol):
print(
"\nComparing data array, shapes match: ",
info.shape == expect.shape,
)
print("Model tested: %s, measure: %s" % (self.key, key))
# print( 'args: ', dir(self.args[0]))
print("Array expected: ", expect[mask])
print("Array received: ", info[mask])
try:
self.args[0].print_all_mechs()
except:
print("args[0] is string: ", self.args[0])
assert np.allclose(info[mask], expect[mask], rtol=self.rtol)
else:
for k in info.dtype.fields.keys():
self.compare_results(k, info[k], expect[k])
elif np.isscalar(info):
if not np.allclose(info, expect, rtol=self.rtol):
print("Comparing Scalar data, model: %s, measure: %s" % (self.key, key))
# print 'args: ', dir(self.args[0])
print(
"Expected: ",
expect,
", received: ",
info,
" relative tolerance: ",
self.rtol,
)
if isinstance(self.args[0], str):
pass
# print ': ', str
else:
self.args[0].print_all_mechs()
assert np.allclose(info, expect, rtol=self.rtol)
else:
try:
assert info == expect
except AssertionError:
raise
except Exception:
raise NotImplementedError(
"Cannot compare objects of type %s" % type(info)
)
def audit_result(self, info, expect):
""" Display results and ask the user to decide whether the test passed.
Return True for pass, False for fail.
If *expect* is None, then no previous test results were stored.
"""
app = pg.mkQApp()
print("\n=== New test results for %s: ===\n" % self.key)
pprint.pprint(info)
# we use DiffTreeWidget to display differences between large data structures, but
# this is not present in mainline pyqtgraph yet
if hasattr(pg, "DiffTreeWidget"):
win = pg.DiffTreeWidget()
else:
from cnmodel.util.difftreewidget import DiffTreeWidget
win = DiffTreeWidget()
win.resize(800, 800)
win.setData(expect, info)
win.show()
print("Store new test results? [y/n]")
yn = raw_input()
win.hide()
return yn.lower().startswith("y")
def assert_test_info(self, *args, **kwds):
"""
Test *cell* and raise exception if the results do not match prior
data.
"""
result = self.run_test(*args, **kwds)
expect = self.load_test_result()
try:
assert expect is not None
self.compare_results(None, result, expect)
except:
if not self.audit:
if expect is None:
raise Exception(
"No prior test results for test '%s'. "
"Run test.py --audit store new test data." % self.key
)
else:
raise
store = self.audit_result(result, expect)
if store:
self.save_test_result(result)
else:
raise Exception("Rejected test results for '%s'" % self.key)
def result_file(self):
"""
Return a file name to be used for storing / retrieving test results
given *self.key*.
"""
modfile = sys.modules[self.__class__.__module__].__file__
path = os.path.dirname(modfile)
return os.path.join(path, self.data_dir, self.key + ".pk")
def load_test_result(self):
"""
Load prior test results for *self.key*.
If there are no prior results, return None.
"""
fn = self.result_file()
if os.path.isfile(fn):
return pickle.load(open(fn, "rb"), encoding="latin1")
return None
def save_test_result(self, result):
"""
Store test results for *self.key*.
Th e*result* argument must be picklable.
"""
fn = self.result_file()
dirname = os.path.dirname(fn)
if not os.path.isdir(dirname):
os.mkdir(dirname)
pickle.dump(result, open(fn, "wb"))