copying to personal repo
This commit is contained in:
69
cnmodel/util/Params.py
Executable file
69
cnmodel/util/Params.py
Executable file
@@ -0,0 +1,69 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
|
||||
|
||||
class Params(object):
|
||||
def __init__(self, **kwds):
|
||||
"""
|
||||
Utility class to create parameter lists
|
||||
create using:
|
||||
p = Params(abc=2.0, defg = 3.0, lunch='sandwich')
|
||||
reference using:
|
||||
p.abc, p.defg, etc.
|
||||
Supports getting the keys, finding whether a key exists, returning the strucure as a simple dictionary,
|
||||
and printing (show) the parameter structure.
|
||||
"""
|
||||
self.__dict__.update(kwds)
|
||||
|
||||
def additem(self, key, value):
|
||||
self.__dict__[key] = value
|
||||
|
||||
def getkeys(self):
|
||||
"""
|
||||
Get the keys in the current dictionary
|
||||
"""
|
||||
return self.__dict__.keys()
|
||||
|
||||
def haskey(self, key):
|
||||
"""
|
||||
Find out if the param list has a specific key in it
|
||||
"""
|
||||
if key in self.__dict__.keys():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def todict(self):
|
||||
"""
|
||||
convert param list to standard dictionary
|
||||
Useful when writing the data
|
||||
"""
|
||||
r = {}
|
||||
for dictelement in self.__dict__:
|
||||
if isinstance(self.__dict__[dictelement], Params):
|
||||
# print 'nested: ', dictelement
|
||||
r[dictelement] = self.__dict__[dictelement].todict()
|
||||
else:
|
||||
r[dictelement] = self.__dict__[dictelement]
|
||||
return r
|
||||
|
||||
def show(self, printFlag=True):
|
||||
"""
|
||||
print the parameter block created in Parameter Init
|
||||
"""
|
||||
print("-------- Parameter Block ----------")
|
||||
for key in self.__dict__.keys():
|
||||
print("%15s = " % (key), eval("self.%s" % key))
|
||||
print("-------- ---------------------- ----------")
|
||||
|
||||
|
||||
class ParamTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
1415
cnmodel/util/PlotHelpers.py
Executable file
1415
cnmodel/util/PlotHelpers.py
Executable file
File diff suppressed because it is too large
Load Diff
9
cnmodel/util/__init__.py
Normal file
9
cnmodel/util/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .stim import *
|
||||
from .find_point import *
|
||||
from .pynrnutilities import *
|
||||
from .nrnutils import *
|
||||
from .expfitting import *
|
||||
from .user_tester import UserTester
|
||||
from .get_anspikes import *
|
||||
from .Params import *
|
||||
from .talbotetalTicks import Extended
|
||||
125
cnmodel/util/ccstim.py
Normal file
125
cnmodel/util/ccstim.py
Normal file
@@ -0,0 +1,125 @@
|
||||
__author__ = "pbmanis"
|
||||
"""
|
||||
ccstim
|
||||
Generate current-clamp (or voltage-clamp) stimulus waveforms from a dictionary
|
||||
used for vectory play modes in current clamp
|
||||
(prior version was called 'makestim')
|
||||
|
||||
Can generate several types of pulses
|
||||
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def ccstim(stim, dt, pulsetype="square"):
|
||||
"""
|
||||
Create stimulus pulse waveforms of different types.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stim : dict
|
||||
a dictionary with keys [required]
|
||||
delay (delay to start of pulse train, msec [all]
|
||||
duration: duration of pulses in train, msec [all]
|
||||
Sfreq: stimulus train frequency (Hz) [timedSpikes]
|
||||
PT: post-train test delay [all]
|
||||
NP: number of pulses in the train [timedSpikes, exp]
|
||||
amp: amplitude of the pulses in the train [all]
|
||||
hypamp: amplitude of prehyperpolarizing pulse [hyp]
|
||||
hypdur: duration of prehyperpolarizing pulse [hyp]
|
||||
spikeTimes" times of spikes [timedSpikes]
|
||||
|
||||
dt : time (microseconds) [required]
|
||||
step time, in microseconds. Required parameter
|
||||
|
||||
pulsetype : string (default: 'square')
|
||||
Type of pulse to generate: one of square, hyp, timedspikes or exp
|
||||
square produces a train of "square" (retangular) pulses levels 0 and ampitude
|
||||
hyp is like square, but precedes the pulse train with a single prepulse
|
||||
of hypamp and hypdur
|
||||
timedspikes is like square, excpet the pulses are generated at times specified
|
||||
in the spikeTimes key in the stim dictionary
|
||||
exp: pulses with an exponential decay.
|
||||
|
||||
TO DO:
|
||||
add pulsetypes, including sine wave, rectified sine wave, etc.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list containing [waveform (numpy array),
|
||||
maxtime(float),
|
||||
timebase (numpy array)]
|
||||
"""
|
||||
|
||||
assert dt is not None
|
||||
assert "delay" in stim.keys()
|
||||
delay = int(np.floor(stim["delay"] / dt))
|
||||
if pulsetype in ["square", "hyp", "exp"]:
|
||||
ipi = int(np.floor((1000.0 / stim["Sfreq"]) / dt))
|
||||
pdur = int(np.floor(stim["duration"] / dt))
|
||||
posttest = int(np.floor(stim["PT"] / dt))
|
||||
if pulsetype not in ["timedSpikes"]:
|
||||
NP = int(stim["NP"])
|
||||
else:
|
||||
NP = len(stim["spikeTimes"])
|
||||
tstims = [0] * NP
|
||||
if pulsetype == "hyp":
|
||||
assert "hypamp" in stim.keys()
|
||||
assert "hypdur" in stim.keys()
|
||||
hypdur = int(np.floor(stim["hypdur"] / dt))
|
||||
delay0 = delay # save original delay
|
||||
|
||||
if pulsetype in ["square", "hyp"]:
|
||||
maxt = dt * (stim["delay"] + (ipi * (NP + 2)) + posttest + pdur * 2)
|
||||
if pulsetype == "hyp":
|
||||
maxt = maxt + dt * stim["hypdur"]
|
||||
delay = delay + hypdur
|
||||
w = np.zeros(int(np.floor(maxt / dt)))
|
||||
for j in range(0, NP):
|
||||
t = (delay + j * ipi) * dt
|
||||
w[delay + ipi * j : delay + (ipi * j) + pdur] = stim["amp"]
|
||||
tstims[j] = delay + ipi * j
|
||||
if stim["PT"] > 0.0:
|
||||
send = delay + ipi * j
|
||||
for i in range(send + posttest, send + posttest + pdur):
|
||||
w[i] = stim["amp"]
|
||||
if pulsetype == "hyp": # fill in the prepulse now
|
||||
for i in range(delay0, delay0 + hypdur):
|
||||
w[i] = stim["hypamp"]
|
||||
|
||||
if pulsetype == "timedSpikes":
|
||||
maxt = np.max(stim["spikeTimes"]) + stim["PT"] + stim["duration"] * 2
|
||||
w = np.zeros(int(np.floor(maxt / dt)))
|
||||
for j in range(len(stim["spikeTimes"])):
|
||||
st = delay + int(np.floor(stim["spikeTimes"][j] / dt))
|
||||
t = st * dt
|
||||
w[st : st + pdur] = stim["amp"]
|
||||
tstims[j] = st
|
||||
|
||||
if stim["PT"] > 0.0:
|
||||
for i in range(st + posttest, st + posttest + pdur):
|
||||
w[i] = stim["amp"]
|
||||
|
||||
if pulsetype == "exp":
|
||||
maxt = dt * (stim["delay"] + (ipi * (NP + 2)) + posttest + pdur * 2)
|
||||
w = np.zeros(int(np.floor(maxt / dt)))
|
||||
|
||||
for j in range(0, NP):
|
||||
for i in range(0, len(w)):
|
||||
if delay + ipi * j + i < len(w):
|
||||
w[delay + ipi * j + i] = w[delay + ipi * j + i] + stim["amp"] * (
|
||||
(1.0 - np.exp(-i / (pdur / 3.0)))
|
||||
* np.exp(-(i - (pdur / 3.0)) / pdur)
|
||||
)
|
||||
tstims[j] = delay + ipi * j
|
||||
if stim["PT"] > 0.0:
|
||||
send = delay + ipi * j
|
||||
for i in range(send + posttest, len(w)):
|
||||
w[i] += (
|
||||
stim["amp"]
|
||||
* (1.0 - np.exp(-i / (pdur / 3.0)))
|
||||
* np.exp(-(i - (pdur / 3.0)) / pdur)
|
||||
)
|
||||
|
||||
return (w, maxt, tstims)
|
||||
658
cnmodel/util/compare_simple_multisynapses.py
Normal file
658
cnmodel/util/compare_simple_multisynapses.py
Normal file
@@ -0,0 +1,658 @@
|
||||
"""
|
||||
Test synaptic connections between two different cell types.
|
||||
|
||||
Usage: python compare_simple_multisynapses.py <pre_celltype> <post_celltype>
|
||||
|
||||
This script:
|
||||
|
||||
1. Creates single pre- and postsynaptic cells
|
||||
2. Creates a single synaptic terminal between the two cells, using the multisite synapse method.
|
||||
3. Stimulates the presynaptic cell by current injection.
|
||||
4. Records and analyzes the resulting post-synaptic events.
|
||||
5. Repeats 3, 4 100 times to get an average postsynaptic event.
|
||||
6. stores the resulting waveform in a pandas database
|
||||
|
||||
This is used mainly to check that the strength, kinetics, and dynamics of
|
||||
each synapse type is working as expected.
|
||||
|
||||
Requires Python 3.6
|
||||
"""
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
# import pyqtgraph as pg
|
||||
import matplotlib.pyplot as mpl
|
||||
import neuron as h
|
||||
import cnmodel.util.PlotHelpers as PH
|
||||
from cnmodel.protocols import SynapseTest
|
||||
from cnmodel import cells
|
||||
from cnmodel.synapses import Synapse
|
||||
import pickle
|
||||
import lmfit
|
||||
|
||||
|
||||
convergence = {
|
||||
"sgc": {
|
||||
"bushy": 1,
|
||||
"tstellate": 1,
|
||||
"dstellate": 1,
|
||||
"octopus": 1,
|
||||
"pyramidal": 1,
|
||||
"tuberculoventral": 1,
|
||||
},
|
||||
"dstellate": {
|
||||
"bushy": 1,
|
||||
"tstellate": 1,
|
||||
"dstellate": 1,
|
||||
"pyramidal": 1,
|
||||
"tuberculoventral": 1,
|
||||
},
|
||||
"tuberculoventral": {
|
||||
"bushy": 1,
|
||||
"tstellate": 1,
|
||||
"pyramidal": 1,
|
||||
"tuberculoventral": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Exp2SynFitting:
|
||||
"""
|
||||
Fit waveform against Exp2SYN function (used by Neuron)
|
||||
THe function is (from the documentation):
|
||||
i = G * (v - e) i(nanoamps), g(micromhos);
|
||||
G = weight * factor * (exp(-t/tau2) - exp(-t/tau1))
|
||||
WHere factor is evaluated when initializing(in neuron; see exp2syn.mod source) as:
|
||||
tp = (tau1*tau2)/(tau2 - tau1) * log(tau2/tau1) (log10)
|
||||
factor = -exp(-tp/tau1) + exp(-tp/tau2)
|
||||
factor = 1/factor
|
||||
|
||||
Parameters
|
||||
----------
|
||||
initpars : dict
|
||||
dict of initial parameters. For example: {'tau1': 0.1,
|
||||
'tau2': 0.3, 'weight': 0.1, 'delay' : 0.0, 'erev': -80.} (erev in mV)
|
||||
bounds : dict
|
||||
dictionary of bounds for each parameter, with a list of lower and upper values.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, initpars=None, bounds=None, functype="exp2syn"):
|
||||
self.fitpars = lmfit.Parameters()
|
||||
if functype == "exp2syn": # future handle other functions like alpha
|
||||
# (Name, Value, Vary, Min, Max, Expr)
|
||||
self.fitpars.add_many(
|
||||
("tau1", initpars["tau1"], True, 0.05, 25.0, None),
|
||||
# ('tau2', initpars['tau2'], True, 0.1, 50., None),
|
||||
("tauratio", initpars["tauratio"], True, 1.0001, 100.0, None),
|
||||
("weight", initpars["weight"], True, 1e-6, 1, None),
|
||||
("erev", initpars["erev"], False), # do not adjust!
|
||||
("v", initpars["v"], False),
|
||||
("delay", initpars["delay"], True, 0.0, 5.0, None),
|
||||
)
|
||||
self.func = self.exp2syn_err
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def fit(self, x, y, p, verbose=False):
|
||||
"""
|
||||
Perform the curve fit against the specified function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : time base for waveform (np array or list)
|
||||
y : waveform (1-d np array or list)
|
||||
p : parameters in lmfit Parameters structure
|
||||
verbose : boolean (default: False)
|
||||
If true, print the parameters in a nice format
|
||||
"""
|
||||
|
||||
kws = {"maxfev": 5000}
|
||||
# print('p: ', p)
|
||||
self.mim = lmfit.minimize(
|
||||
self.func, p, method="least_squares", args=(x, y)
|
||||
) # , kws=kws)
|
||||
if verbose:
|
||||
lmfit.printfuncs.report_fit(self.mim.params)
|
||||
fitpars = self.mim.params
|
||||
return fitpars
|
||||
|
||||
# @staticmethod
|
||||
def exp2syn(self, x, tau1, tauratio, weight, erev, v, delay):
|
||||
"""
|
||||
Compute the exp2syn waveform current as it is done in Neuron
|
||||
|
||||
Note on units:
|
||||
The units are assumed to be "self consistent"
|
||||
Thus, if the time base is in msec, tau1, tau2 and the delay are
|
||||
in msec. erev and v should be in matching units (e.g., mV)
|
||||
Note thtat the weight is not equal to the conducdtance G, because
|
||||
of the scaling of the waveform
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : time base
|
||||
tau1 : rising tau for waveform
|
||||
tau2 : falling tau for waveform
|
||||
weight : amplitude of the waveform (conductance)
|
||||
erev : reversal potential for ionic species (used to compute i)
|
||||
v : holding voltage at which waveform is computed
|
||||
delay : delay to start of function
|
||||
|
||||
Returns
|
||||
-------
|
||||
i, the calucated current trace for these parameters.
|
||||
"""
|
||||
# we handle the requirement that tau2 > tau1 by setting the
|
||||
# expression in lmfit, and using tauratio rather than a direct
|
||||
# tau2.
|
||||
# if tau1/tau2 > 1.0: # make sure tau1 is less than tau2
|
||||
# tau1 = 0.999*tau2
|
||||
tau2 = tauratio * tau1
|
||||
tp = (tau1 * tau2) / (tau2 - tau1) * np.log(tau2 / tau1)
|
||||
factor = -np.exp(-tp / tau1) + np.exp(-tp / tau2)
|
||||
factor = 1.0 / factor
|
||||
G = (
|
||||
weight
|
||||
* factor
|
||||
* (np.exp(-(x - delay) / tau2) - np.exp(-(x - delay) / tau1))
|
||||
)
|
||||
G[x - delay < 0] = 0.0
|
||||
i = G * (v - erev) # i(nanoamps), g(micromhos);
|
||||
|
||||
return i
|
||||
|
||||
def exp2syn_err(self, p, x, y):
|
||||
return np.fabs(
|
||||
y - self.exp2syn(x, **dict([(k, p.value) for k, p in p.items()]))
|
||||
)
|
||||
|
||||
def factor(self, tau1, tauratio, weight):
|
||||
"""
|
||||
calculate tau-scaled weight
|
||||
"""
|
||||
tau2 = tau1 * tauratio
|
||||
tp = (tau1 * tau2) / (tau2 - tau1) * np.log(tau2 / tau1)
|
||||
factor = -np.exp(-tp / tau1) + np.exp(-tp / tau2)
|
||||
factor = 1.0 / factor
|
||||
G = weight * factor
|
||||
return G
|
||||
|
||||
|
||||
def testexp():
|
||||
"""
|
||||
Test the exp2syn fitting function
|
||||
"""
|
||||
pars = {
|
||||
"tau1": 0.1,
|
||||
"tauratio": 2.0,
|
||||
"weight": 0.1,
|
||||
"erev": -70.0,
|
||||
"v": -65.0,
|
||||
"delay": 1,
|
||||
}
|
||||
F = Exp2SynFitting(
|
||||
initpars={
|
||||
"tau1": 0.2,
|
||||
"tauratio": 2.0,
|
||||
"weight": 0.1,
|
||||
"erev": -70.0,
|
||||
"v": -65.0,
|
||||
"delay": 0,
|
||||
}
|
||||
)
|
||||
t = np.arange(0, 10.0, 0.01)
|
||||
p = F.fitpars
|
||||
target = F.exp2syn(
|
||||
t,
|
||||
pars["tau1"],
|
||||
pars["tauratio"],
|
||||
pars["weight"],
|
||||
pars["erev"],
|
||||
pars["v"],
|
||||
pars["delay"],
|
||||
)
|
||||
# print(F.fitpars)
|
||||
pars_fit = F.fit(t, target, F.fitpars)
|
||||
print("\nTest fit result: ")
|
||||
lmfit.printfuncs.report_fit(F.mim.params)
|
||||
print("( tau2 = ", pars_fit["tau1"].value * pars_fit["tauratio"].value, ")")
|
||||
print("target parameters: ", pars)
|
||||
print("\n")
|
||||
|
||||
|
||||
def compute_psc(synapsetype="multisite", celltypes=["sgc", "tstellate"]):
|
||||
"""
|
||||
Compute the PSC between the specified two cell tpes
|
||||
The type of PSC is set by synspase type and must be 'multisite' or 'simple'
|
||||
"""
|
||||
assert synapsetype in ["multisite", "simple"]
|
||||
|
||||
c = []
|
||||
for cellType in celltypes:
|
||||
if cellType == "sgc":
|
||||
cell = cells.SGC.create()
|
||||
elif cellType == "tstellate":
|
||||
cell = cells.TStellate.create(debug=True, ttx=False)
|
||||
elif (
|
||||
cellType == "dstellate"
|
||||
): # Type I-II Rothman model, similiar excitability (Xie/Manis, unpublished)
|
||||
cell = cells.DStellate.create(model="RM03", debug=True, ttx=False)
|
||||
# elif cellType == 'dstellate_eager': # From Eager et al.
|
||||
# cell = cells.DStellate.create(model='Eager', debug=True, ttx=False)
|
||||
elif cellType == "bushy":
|
||||
cell = cells.Bushy.create(debug=True, ttx=True)
|
||||
elif cellType == "octopus":
|
||||
cell = cells.Octopus.create(debug=True, ttx=True)
|
||||
elif cellType == "tuberculoventral":
|
||||
cell = cells.Tuberculoventral.create(debug=True, ttx=True)
|
||||
elif cellType == "pyramidal":
|
||||
cell = cells.Pyramidal.create(debug=True, ttx=True)
|
||||
elif cellType == "cartwheel":
|
||||
cell = cells.Cartwheel.create(debug=True, ttx=True)
|
||||
else:
|
||||
raise ValueError("Unknown cell type '%s'" % cellType)
|
||||
c.append(cell)
|
||||
|
||||
preCell, postCell = c
|
||||
|
||||
print(f"Computing psc for connection {celltypes[0]:s} -> {celltypes[1]:s}")
|
||||
nTerminals = convergence.get(celltypes[0], {}).get(celltypes[1], None)
|
||||
if nTerminals is None:
|
||||
nTerminals = 1
|
||||
print(
|
||||
f"Warning: Unknown convergence for {celltypes[0]:s} -> {celltypes[1]:s}, ASSUMING {nTerminals:d} terminals"
|
||||
)
|
||||
|
||||
if celltypes == ["sgc", "bushy"]:
|
||||
niter = 50
|
||||
else:
|
||||
niter = 200
|
||||
assert synapsetype in ["simple", "multisite"]
|
||||
if synapsetype == "simple":
|
||||
niter = 1
|
||||
st = SynapseTest()
|
||||
dt = 0.010
|
||||
stim = {
|
||||
"NP": 1,
|
||||
"Sfreq": 100.0,
|
||||
"delay": 0.0,
|
||||
"dur": 0.5,
|
||||
"amp": 10.0,
|
||||
"PT": 0.0,
|
||||
"dt": dt,
|
||||
}
|
||||
st.run(
|
||||
preCell.soma,
|
||||
postCell.soma,
|
||||
nTerminals,
|
||||
dt=dt,
|
||||
vclamp=-65.0,
|
||||
iterations=niter,
|
||||
synapsetype=synapsetype,
|
||||
tstop=50.0,
|
||||
stim_params=stim,
|
||||
)
|
||||
# st.show_result() # pyqtgraph plotting -
|
||||
# st.plots['VPre'].setYRange(-70., 10.)
|
||||
# st.plots['EPSC'].setYRange(-2.0, 0.5)
|
||||
# st.plots['latency2080'].setYRange(0., 1.0)
|
||||
# st.plots['halfwidth'].setYRange(0., 1.0)
|
||||
# st.plots['RT'].setYRange(0., 0.2)
|
||||
# st.plots['latency'].setYRange(0., 1.0)
|
||||
# st.plots['latency_distribution'].setYRange(0., 1.0)
|
||||
return st # need to keep st alive in memory
|
||||
|
||||
|
||||
def fit_one(st, stk):
|
||||
"""
|
||||
Fit one trace to the exp2syn function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
st : dict (no default)
|
||||
A dictionary containing (at least) the following keys:
|
||||
't' : the time base for the trace to be fit
|
||||
'i' : a list of arrays or a 2d numpy array of the data to be fit
|
||||
|
||||
Returns
|
||||
-------
|
||||
pars : The fitting parameters
|
||||
The fitted parameters are returned as an lmfit Parameters object, so access
|
||||
individual parameters a pars['parname'].value
|
||||
|
||||
fitted : The fitted waveform, of the same length as the source data waveform
|
||||
"""
|
||||
|
||||
if stk[0] in ["sgc", "tstellate", "granule"]:
|
||||
erev = (
|
||||
0.0
|
||||
) # set erev according to the source cell (excitatory: 0, inhibitory: -80)
|
||||
else:
|
||||
erev = -70.0 # value used in gly_psd
|
||||
print("\nstk, erev: ", stk, erev)
|
||||
F = Exp2SynFitting(
|
||||
initpars={
|
||||
"tau1": 1.0,
|
||||
"tauratio": 5.0,
|
||||
"weight": 0.0001,
|
||||
"erev": erev,
|
||||
"v": -65.0,
|
||||
"delay": 0,
|
||||
}
|
||||
)
|
||||
t = st["t"]
|
||||
p = F.fitpars
|
||||
target = np.mean(np.array(st["i"]), axis=0)
|
||||
# print(F.fitpars)
|
||||
pars = F.fit(t, target, F.fitpars)
|
||||
gw = F.factor(pars["tau1"].value, pars["tauratio"].value, pars["weight"].value)
|
||||
pars.add("GWeight", value=gw, vary=False)
|
||||
fitted = F.exp2syn(
|
||||
st["t"],
|
||||
pars["tau1"],
|
||||
pars["tauratio"],
|
||||
pars["weight"],
|
||||
pars["erev"],
|
||||
pars["v"],
|
||||
pars["delay"],
|
||||
)
|
||||
lmfit.printfuncs.report_fit(F.mim.params)
|
||||
return (pars, fitted)
|
||||
|
||||
|
||||
def fit_all():
|
||||
"""
|
||||
Fit exp2syn against the traces in stm
|
||||
This fits all pre-post cell pairs, and returns the fit
|
||||
|
||||
"""
|
||||
stm = read_pickle("multisite.pkl")
|
||||
fits = {}
|
||||
fitted = {}
|
||||
# print('stm keys: ', stm.keys())
|
||||
# exit()
|
||||
for i, stk in enumerate(stm.keys()): # for each pre-post cell pair
|
||||
fitp, fit = fit_one(stm[stk], stk)
|
||||
fits[stk] = fitp
|
||||
|
||||
fitted[stk] = {"t": stm[stk]["t"], "i": [fit], "pars": fitp}
|
||||
with (open("simple.pkl", "wb")) as fh:
|
||||
pickle.dump(fitted, fh)
|
||||
return fits
|
||||
|
||||
|
||||
def plot_all(stm, sts=None):
|
||||
P = PH.Plotter((3, 5), figsize=(11, 6))
|
||||
ax = P.axarr.ravel()
|
||||
keypairorder = []
|
||||
for i, stk in enumerate(stm.keys()):
|
||||
keypairorder.append(stk)
|
||||
data = stm[stk]
|
||||
idat = np.array(data["i"])
|
||||
ax[i].plot(np.array(data["t"]), np.mean(idat, axis=0), "c-", linewidth=1.5)
|
||||
sd = np.std(idat, axis=0)
|
||||
ax[i].plot(
|
||||
np.array(data["t"]), np.mean(idat, axis=0) + sd, "c--", linewidth=0.5
|
||||
) # for j in range(idat.shape[0]):
|
||||
ax[i].plot(
|
||||
np.array(data["t"]), np.mean(idat, axis=0) - sd, "c--", linewidth=0.5
|
||||
) # for j in range(idat.shape[0]):
|
||||
rel = 0
|
||||
for j in range(idat.shape[0]):
|
||||
if np.min(idat[j, :]) < -1e-2 or np.max(idat[j, :]) > 1e-2:
|
||||
rel += 1
|
||||
print(f"{str(stk):s} {rel:d} of {idat.shape[0]:d} {str(idat.shape):s}")
|
||||
|
||||
# ax[i].plot(data['t'], idat[j], 'k-', linewidth=0.5, alpha=0.25)
|
||||
ax[i].set_title("%s : %s" % (stk[0], stk[1]), fontsize=7)
|
||||
|
||||
if sts is not None: # plot the matching exp2syn on this
|
||||
for i, stks in enumerate(keypairorder):
|
||||
datas = sts[stks]
|
||||
isdat = np.array(datas["i"])
|
||||
ax[i].plot(datas["t"], np.mean(isdat, axis=0), "m-", linewidth=3, alpha=0.5)
|
||||
mpl.show()
|
||||
|
||||
|
||||
def print_exp2syn_fits(st):
|
||||
stkeys = list(st.keys())
|
||||
postcells = [
|
||||
"bushy",
|
||||
"tstellate",
|
||||
"dstellate",
|
||||
"octopus",
|
||||
"pyramidal",
|
||||
"tuberculoventral",
|
||||
"cartwheel",
|
||||
]
|
||||
fmts = {
|
||||
"weight": "{0:<18.6f}",
|
||||
"GWeight": "{0:<18.6f}",
|
||||
"tau1": "{0:<18.3f}",
|
||||
"tau2": "{0:<18.3f}",
|
||||
"delay": "{0:<18.3f}",
|
||||
"erev": "{0:<18.1f}",
|
||||
}
|
||||
pars = list(fmts.keys())
|
||||
for pre in ["sgc", "dstellate", "tuberculoventral"]:
|
||||
firstline = False
|
||||
vrow = dict.fromkeys(pars, False)
|
||||
for v in pars:
|
||||
for post in postcells:
|
||||
if (pre, post) not in stkeys:
|
||||
print("{0:<18s}".format("0"), end="")
|
||||
continue
|
||||
if not firstline:
|
||||
print("\n%s" % pre.upper())
|
||||
print("{0:<18s}".format(" "), end="")
|
||||
for i in range(len(postcells)):
|
||||
print("{0:<18s}".format(postcells[i]), end="")
|
||||
print()
|
||||
firstline = True
|
||||
if firstline:
|
||||
if not vrow[v]:
|
||||
if v == "tauratio":
|
||||
print("{0:<18s}".format("tau2"), end="")
|
||||
else:
|
||||
print("{0:<18s}".format(v), end="")
|
||||
vrow[v] = True
|
||||
# print(fits[(pre, post)].keys())
|
||||
fits = st[(pre, post)]["pars"]
|
||||
# print(v, fits)
|
||||
if v in ["tauratio", "tau2"]:
|
||||
print(
|
||||
fmts[v].format(fits["tau1"].value * fits["tauratio"].value),
|
||||
end="",
|
||||
)
|
||||
else:
|
||||
print(fmts[v].format(fits[v].value), end="")
|
||||
print()
|
||||
|
||||
|
||||
def read_pickle(mode):
|
||||
"""
|
||||
Read the pickled file generated by runall
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mode : str
|
||||
either 'simple' or 'multisite'
|
||||
|
||||
Returns:
|
||||
the resulting data, which is a dictionary
|
||||
"""
|
||||
with (open(Path(mode).with_suffix(".pkl"), "rb")) as fh:
|
||||
st = pickle.load(fh)
|
||||
return st
|
||||
|
||||
|
||||
def run_all():
|
||||
"""
|
||||
Run all of the multisite synapse calculations for each cell pair
|
||||
Save the results in the multisite.pkl file
|
||||
"""
|
||||
st = {}
|
||||
pkst = {}
|
||||
for pre in convergence.keys():
|
||||
for post in convergence[pre]:
|
||||
# if pre == 'sgc' and post in ['bushy', 'tstellate']:
|
||||
sti = compute_psc(synapsetype="multisite", celltypes=[pre, post])
|
||||
st[(pre, post)] = sti
|
||||
# remove neuron objects before pickling
|
||||
pkst[(pre, post)] = {
|
||||
"t": sti["t"],
|
||||
"i": sti.isoma,
|
||||
"v": sti["v_soma"],
|
||||
"pre": sti["v_pre"],
|
||||
}
|
||||
|
||||
with (open(Path("multisite").with_suffix(".pkl"), "wb")) as fh:
|
||||
pickle.dump(pkst, fh)
|
||||
plot_all(pkst)
|
||||
|
||||
|
||||
def run_one(pre, post):
|
||||
"""
|
||||
Run all of the multisite synapse calculations for each cell pair
|
||||
Save the results in the multisite.pkl file
|
||||
"""
|
||||
assert pre in ["sgc", "tstellate", "dstellate", "tuberculoventral"]
|
||||
assert post in ["bushy", "tstellate", "dstellate", "tuberculoventral", "pyramidal"]
|
||||
st = {}
|
||||
pkst = read_pickle("multisite")
|
||||
# if pre == 'sgc' and post in ['bushy', 'tstellate']:
|
||||
sti = compute_psc(synapsetype="multisite", celltypes=[pre, post])
|
||||
st[(pre, post)] = sti
|
||||
# remove neuron objects before pickling
|
||||
pkst[(pre, post)] = {
|
||||
"t": sti["t"],
|
||||
"i": sti.isoma,
|
||||
"v": sti["v_soma"],
|
||||
"pre": sti["v_pre"],
|
||||
}
|
||||
|
||||
with (open(Path("multisite").with_suffix(".pkl"), "wb")) as fh:
|
||||
pickle.dump(pkst, fh)
|
||||
plot_all(pkst)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compare simple and multisite synapses"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mode",
|
||||
type=str,
|
||||
dest="mode",
|
||||
default="None",
|
||||
choices=["simple", "multisite", "both"],
|
||||
help="Select mode [simple, multisite, compare]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--run",
|
||||
action="store_true",
|
||||
dest="run",
|
||||
help="Run multisite models between all cell types",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--runone",
|
||||
action="store_true",
|
||||
dest="runone",
|
||||
help="Run multisite models between specified cell types",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pre",
|
||||
type=str,
|
||||
dest="pre",
|
||||
default="None",
|
||||
help="Select presynaptic cell for runone",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--post",
|
||||
type=str,
|
||||
dest="post",
|
||||
default="None",
|
||||
help="Select postsynaptic cell for runone",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--fit",
|
||||
action="store_true",
|
||||
dest="fit",
|
||||
help="Fit exp2syn waveforms to multisite data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--test",
|
||||
action="store_true",
|
||||
dest="test",
|
||||
help="Run the test on exp2syn fitting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--plot",
|
||||
action="store_true",
|
||||
dest="plot",
|
||||
help="Plot the current comparison between simple and multisite",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list",
|
||||
action="store_true",
|
||||
dest="list",
|
||||
help="List the simple fit parameters",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.test:
|
||||
testexp()
|
||||
exit()
|
||||
|
||||
if args.run:
|
||||
run_all()
|
||||
exit()
|
||||
|
||||
if args.runone:
|
||||
run_one(args.pre, args.post)
|
||||
exit()
|
||||
|
||||
if args.fit:
|
||||
fit_all() # self contained - always fits exp2syn against the current multistie data
|
||||
ds = read_pickle("simple")
|
||||
dm = read_pickle("multisite")
|
||||
plot_all(dm, ds)
|
||||
exit()
|
||||
|
||||
if args.plot:
|
||||
if args.mode in ["simple", "multisite"]:
|
||||
d = read_pickle(args.mode)
|
||||
plot_all(d)
|
||||
exit()
|
||||
elif args.mode in ["both"]:
|
||||
ds = read_pickle("simple")
|
||||
dm = read_pickle("multisite")
|
||||
plot_all(dm, ds)
|
||||
exit()
|
||||
else:
|
||||
print(f"Mode {args.mode:s} is not valid")
|
||||
exit()
|
||||
|
||||
if args.list:
|
||||
ds = read_pickle("simple")
|
||||
print_exp2syn_fits(ds)
|
||||
|
||||
# if sys.flags.interactive == 0:
|
||||
# pg.QtGui.QApplication.exec_()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
131
cnmodel/util/difftreewidget/DataTreeWidget.py
Normal file
131
cnmodel/util/difftreewidget/DataTreeWidget.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pyqtgraph.Qt import QtGui, QtCore
|
||||
from pyqtgraph.pgcollections import OrderedDict
|
||||
from .TableWidget import TableWidget
|
||||
from pyqtgraph.python2_3 import asUnicode
|
||||
import types, traceback
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import metaarray
|
||||
|
||||
HAVE_METAARRAY = True
|
||||
except:
|
||||
HAVE_METAARRAY = False
|
||||
|
||||
__all__ = ["DataTreeWidget"]
|
||||
|
||||
|
||||
class DataTreeWidget(QtGui.QTreeWidget):
|
||||
"""
|
||||
Widget for displaying hierarchical python data structures
|
||||
(eg, nested dicts, lists, and arrays)
|
||||
"""
|
||||
|
||||
def __init__(self, parent=None, data=None):
|
||||
QtGui.QTreeWidget.__init__(self, parent)
|
||||
self.setVerticalScrollMode(self.ScrollPerPixel)
|
||||
self.setData(data)
|
||||
self.setColumnCount(3)
|
||||
self.setHeaderLabels(["key / index", "type", "value"])
|
||||
self.setAlternatingRowColors(True)
|
||||
|
||||
def setData(self, data, hideRoot=False):
|
||||
"""data should be a dictionary."""
|
||||
self.clear()
|
||||
self.widgets = []
|
||||
self.nodes = {}
|
||||
self.buildTree(data, self.invisibleRootItem(), hideRoot=hideRoot)
|
||||
self.expandToDepth(3)
|
||||
self.resizeColumnToContents(0)
|
||||
|
||||
def buildTree(self, data, parent, name="", hideRoot=False, path=()):
|
||||
if hideRoot:
|
||||
node = parent
|
||||
else:
|
||||
node = QtGui.QTreeWidgetItem([name, "", ""])
|
||||
parent.addChild(node)
|
||||
|
||||
# record the path to the node so it can be retrieved later
|
||||
# (this is used by DiffTreeWidget)
|
||||
self.nodes[path] = node
|
||||
|
||||
typeStr, desc, childs, widget = self.parse(data)
|
||||
node.setText(1, typeStr)
|
||||
node.setText(2, desc)
|
||||
|
||||
# Truncate description and add text box if needed
|
||||
if len(desc) > 100:
|
||||
desc = desc[:97] + "..."
|
||||
if widget is None:
|
||||
widget = QtGui.QPlainTextEdit(asUnicode(data))
|
||||
widget.setMaximumHeight(200)
|
||||
widget.setReadOnly(True)
|
||||
|
||||
# Add widget to new subnode
|
||||
if widget is not None:
|
||||
self.widgets.append(widget)
|
||||
subnode = QtGui.QTreeWidgetItem(["", "", ""])
|
||||
node.addChild(subnode)
|
||||
self.setItemWidget(subnode, 0, widget)
|
||||
self.setFirstItemColumnSpanned(subnode, True)
|
||||
|
||||
# recurse to children
|
||||
for key, data in childs.items():
|
||||
self.buildTree(data, node, asUnicode(key), path=path + (key,))
|
||||
|
||||
def parse(self, data):
|
||||
"""
|
||||
Given any python object, return:
|
||||
* type
|
||||
* a short string representation
|
||||
* a dict of sub-objects to be parsed
|
||||
* optional widget to display as sub-node
|
||||
"""
|
||||
# defaults for all objects
|
||||
typeStr = type(data).__name__
|
||||
if typeStr == "instance":
|
||||
typeStr += ": " + data.__class__.__name__
|
||||
widget = None
|
||||
desc = ""
|
||||
childs = {}
|
||||
|
||||
# type-specific changes
|
||||
if isinstance(data, dict):
|
||||
desc = "length=%d" % len(data)
|
||||
if isinstance(data, OrderedDict):
|
||||
childs = data
|
||||
else:
|
||||
childs = OrderedDict(sorted(data.items()))
|
||||
elif isinstance(data, (list, tuple)):
|
||||
desc = "length=%d" % len(data)
|
||||
childs = OrderedDict(enumerate(data))
|
||||
elif HAVE_METAARRAY and (
|
||||
hasattr(data, "implements") and data.implements("MetaArray")
|
||||
):
|
||||
childs = OrderedDict(
|
||||
[("data", data.view(np.ndarray)), ("meta", data.infoCopy())]
|
||||
)
|
||||
elif isinstance(data, np.ndarray):
|
||||
desc = "shape=%s dtype=%s" % (data.shape, data.dtype)
|
||||
table = TableWidget()
|
||||
table.setData(data)
|
||||
table.setMaximumHeight(200)
|
||||
widget = table
|
||||
elif isinstance(
|
||||
data, types.TracebackType
|
||||
): ## convert traceback to a list of strings
|
||||
frames = list(
|
||||
map(str.strip, traceback.format_list(traceback.extract_tb(data)))
|
||||
)
|
||||
# childs = OrderedDict([
|
||||
# (i, {'file': child[0], 'line': child[1], 'function': child[2], 'code': child[3]})
|
||||
# for i, child in enumerate(frames)])
|
||||
# childs = OrderedDict([(i, ch) for i,ch in enumerate(frames)])
|
||||
widget = QtGui.QPlainTextEdit(asUnicode("\n".join(frames)))
|
||||
widget.setMaximumHeight(200)
|
||||
widget.setReadOnly(True)
|
||||
else:
|
||||
desc = asUnicode(data)
|
||||
|
||||
return typeStr, desc, childs, widget
|
||||
170
cnmodel/util/difftreewidget/DiffTreeWidget.py
Normal file
170
cnmodel/util/difftreewidget/DiffTreeWidget.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pyqtgraph.Qt import QtGui, QtCore
|
||||
from pyqtgraph.pgcollections import OrderedDict
|
||||
from .DataTreeWidget import DataTreeWidget
|
||||
import pyqtgraph.functions as fn
|
||||
import types, traceback
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["DiffTreeWidget"]
|
||||
|
||||
|
||||
class DiffTreeWidget(QtGui.QWidget):
|
||||
"""
|
||||
Widget for displaying differences between hierarchical python data structures
|
||||
(eg, nested dicts, lists, and arrays)
|
||||
"""
|
||||
|
||||
def __init__(self, parent=None, a=None, b=None):
|
||||
QtGui.QWidget.__init__(self, parent)
|
||||
self.layout = QtGui.QHBoxLayout()
|
||||
self.setLayout(self.layout)
|
||||
self.trees = [DataTreeWidget(self), DataTreeWidget(self)]
|
||||
for t in self.trees:
|
||||
self.layout.addWidget(t)
|
||||
if a is not None:
|
||||
self.setData(a, b)
|
||||
|
||||
def setData(self, a, b):
|
||||
"""
|
||||
Set the data to be compared in this widget.
|
||||
"""
|
||||
self.data = (a, b)
|
||||
self.trees[0].setData(a)
|
||||
self.trees[1].setData(b)
|
||||
|
||||
return self.compare(a, b)
|
||||
|
||||
def compare(self, a, b, path=()):
|
||||
"""
|
||||
Compare data structure *a* to structure *b*.
|
||||
|
||||
Return True if the objects match completely.
|
||||
Otherwise, return a structure that describes the differences:
|
||||
|
||||
{ 'type': bool
|
||||
'len': bool,
|
||||
'str': bool,
|
||||
'shape': bool,
|
||||
'dtype': bool,
|
||||
'mask': array,
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
bad = (255, 200, 200)
|
||||
diff = []
|
||||
# generate typestr, desc, childs for each object
|
||||
typeA, descA, childsA, _ = self.trees[0].parse(a)
|
||||
typeB, descB, childsB, _ = self.trees[1].parse(b)
|
||||
|
||||
if typeA != typeB:
|
||||
self.setColor(path, 1, bad)
|
||||
if descA != descB:
|
||||
self.setColor(path, 2, bad)
|
||||
|
||||
if isinstance(a, dict) and isinstance(b, dict):
|
||||
keysA = set(a.keys())
|
||||
keysB = set(b.keys())
|
||||
for key in keysA - keysB:
|
||||
self.setColor(path + (key,), 0, bad, tree=0)
|
||||
for key in keysB - keysA:
|
||||
self.setColor(path + (key,), 0, bad, tree=1)
|
||||
for key in keysA & keysB:
|
||||
self.compare(a[key], b[key], path + (key,))
|
||||
|
||||
elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
|
||||
for i in range(max(len(a), len(b))):
|
||||
if len(a) <= i:
|
||||
self.setColor(path + (i,), 0, bad, tree=1)
|
||||
elif len(b) <= i:
|
||||
self.setColor(path + (i,), 0, bad, tree=0)
|
||||
else:
|
||||
self.compare(a[i], b[i], path + (i,))
|
||||
|
||||
elif (
|
||||
isinstance(a, np.ndarray)
|
||||
and isinstance(b, np.ndarray)
|
||||
and a.shape == b.shape
|
||||
):
|
||||
tableNodes = [tree.nodes[path].child(0) for tree in self.trees]
|
||||
if a.dtype.fields is None and b.dtype.fields is None:
|
||||
eq = self.compareArrays(a, b)
|
||||
if not np.all(eq):
|
||||
for n in tableNodes:
|
||||
n.setBackground(0, fn.mkBrush(bad))
|
||||
# for i in np.argwhere(~eq):
|
||||
|
||||
else:
|
||||
if a.dtype == b.dtype:
|
||||
for i, k in enumerate(a.dtype.fields.keys()):
|
||||
eq = self.compareArrays(a[k], b[k])
|
||||
if not np.all(eq):
|
||||
for n in tableNodes:
|
||||
n.setBackground(0, fn.mkBrush(bad))
|
||||
# for j in np.argwhere(~eq):
|
||||
|
||||
# dict: compare keys, then values where keys match
|
||||
# list:
|
||||
# array: compare elementwise for same shape
|
||||
|
||||
def compareArrays(self, a, b):
|
||||
intnan = -9223372036854775808 # happens when np.nan is cast to int
|
||||
anans = np.isnan(a) | (a == intnan)
|
||||
bnans = np.isnan(b) | (b == intnan)
|
||||
eq = anans == bnans
|
||||
mask = ~anans
|
||||
eq[mask] = np.allclose(a[mask], b[mask])
|
||||
return eq
|
||||
|
||||
def setColor(self, path, column, color, tree=None):
|
||||
brush = fn.mkBrush(color)
|
||||
|
||||
# Color only one tree if specified.
|
||||
if tree is None:
|
||||
trees = self.trees
|
||||
else:
|
||||
trees = [self.trees[tree]]
|
||||
|
||||
for tree in trees:
|
||||
item = tree.nodes[path]
|
||||
item.setBackground(column, brush)
|
||||
|
||||
def _compare(self, a, b):
|
||||
"""
|
||||
Compare data structure *a* to structure *b*.
|
||||
"""
|
||||
# Check test structures are the same
|
||||
assert type(info) is type(expect)
|
||||
if hasattr(info, "__len__"):
|
||||
assert len(info) == len(expect)
|
||||
|
||||
if isinstance(info, dict):
|
||||
for k in info:
|
||||
assert k in expect
|
||||
for k in expect:
|
||||
assert k in info
|
||||
self.compare_results(info[k], expect[k])
|
||||
elif isinstance(info, list):
|
||||
for i in range(len(info)):
|
||||
self.compare_results(info[i], expect[i])
|
||||
elif isinstance(info, np.ndarray):
|
||||
assert info.shape == expect.shape
|
||||
assert info.dtype == expect.dtype
|
||||
if info.dtype.fields is None:
|
||||
intnan = -9223372036854775808 # happens when np.nan is cast to int
|
||||
inans = np.isnan(info) | (info == intnan)
|
||||
enans = np.isnan(expect) | (expect == intnan)
|
||||
assert np.all(inans == enans)
|
||||
mask = ~inans
|
||||
assert np.allclose(info[mask], expect[mask])
|
||||
else:
|
||||
for k in info.dtype.fields.keys():
|
||||
self.compare_results(info[k], expect[k])
|
||||
else:
|
||||
try:
|
||||
assert info == expect
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
"Cannot compare objects of type %s" % type(info)
|
||||
)
|
||||
2
cnmodel/util/difftreewidget/README
Normal file
2
cnmodel/util/difftreewidget/README
Normal file
@@ -0,0 +1,2 @@
|
||||
Copied from github.com/campagnola/pyqtgraph datatree-arrays branch; this gives us DiffTreeWidget.
|
||||
This file can be removed if DiffTreeWidget has been merged into pyqtgraph.
|
||||
515
cnmodel/util/difftreewidget/TableWidget.py
Normal file
515
cnmodel/util/difftreewidget/TableWidget.py
Normal file
@@ -0,0 +1,515 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
from pyqtgraph.Qt import QtGui, QtCore
|
||||
from pyqtgraph.python2_3 import asUnicode, basestring
|
||||
import pyqtgraph.metaarray as metaarray
|
||||
|
||||
|
||||
__all__ = ["TableWidget"]
|
||||
|
||||
|
||||
def _defersort(fn):
|
||||
def defersort(self, *args, **kwds):
|
||||
# may be called recursively; only the first call needs to block sorting
|
||||
setSorting = False
|
||||
if self._sorting is None:
|
||||
self._sorting = self.isSortingEnabled()
|
||||
setSorting = True
|
||||
self.setSortingEnabled(False)
|
||||
try:
|
||||
return fn(self, *args, **kwds)
|
||||
finally:
|
||||
if setSorting:
|
||||
self.setSortingEnabled(self._sorting)
|
||||
self._sorting = None
|
||||
|
||||
return defersort
|
||||
|
||||
|
||||
class TableWidget(QtGui.QTableWidget):
|
||||
"""Extends QTableWidget with some useful functions for automatic data handling
|
||||
and copy / export context menu. Can automatically format and display a variety
|
||||
of data types (see :func:`setData() <pyqtgraph.TableWidget.setData>` for more
|
||||
information.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwds):
|
||||
"""
|
||||
All positional arguments are passed to QTableWidget.__init__().
|
||||
|
||||
===================== =================================================
|
||||
**Keyword Arguments**
|
||||
editable (bool) If True, cells in the table can be edited
|
||||
by the user. Default is False.
|
||||
sortable (bool) If True, the table may be soted by
|
||||
clicking on column headers. Note that this also
|
||||
causes rows to appear initially shuffled until
|
||||
a sort column is selected. Default is True.
|
||||
*(added in version 0.9.9)*
|
||||
===================== =================================================
|
||||
"""
|
||||
|
||||
QtGui.QTableWidget.__init__(self, *args)
|
||||
|
||||
self.itemClass = TableWidgetItem
|
||||
|
||||
self.setVerticalScrollMode(self.ScrollPerPixel)
|
||||
self.setSelectionMode(QtGui.QAbstractItemView.ContiguousSelection)
|
||||
self.setSizePolicy(QtGui.QSizePolicy.Preferred, QtGui.QSizePolicy.Preferred)
|
||||
self.clear()
|
||||
|
||||
kwds.setdefault("sortable", True)
|
||||
kwds.setdefault("editable", False)
|
||||
self.setEditable(kwds.pop("editable"))
|
||||
self.setSortingEnabled(kwds.pop("sortable"))
|
||||
|
||||
if len(kwds) > 0:
|
||||
raise TypeError("Invalid keyword arguments '%s'" % kwds.keys())
|
||||
|
||||
self._sorting = None # used when temporarily disabling sorting
|
||||
|
||||
self._formats = {
|
||||
None: None
|
||||
} # stores per-column formats and entire table format
|
||||
self.sortModes = {} # stores per-column sort mode
|
||||
|
||||
self.itemChanged.connect(self.handleItemChanged)
|
||||
|
||||
self.contextMenu = QtGui.QMenu()
|
||||
self.contextMenu.addAction("Copy Selection").triggered.connect(self.copySel)
|
||||
self.contextMenu.addAction("Copy All").triggered.connect(self.copyAll)
|
||||
self.contextMenu.addAction("Save Selection").triggered.connect(self.saveSel)
|
||||
self.contextMenu.addAction("Save All").triggered.connect(self.saveAll)
|
||||
|
||||
def clear(self):
|
||||
"""Clear all contents from the table."""
|
||||
QtGui.QTableWidget.clear(self)
|
||||
self.verticalHeadersSet = False
|
||||
self.horizontalHeadersSet = False
|
||||
self.items = []
|
||||
self.setRowCount(0)
|
||||
self.setColumnCount(0)
|
||||
self.sortModes = {}
|
||||
|
||||
def setData(self, data):
|
||||
"""Set the data displayed in the table.
|
||||
Allowed formats are:
|
||||
|
||||
* numpy arrays
|
||||
* numpy record arrays
|
||||
* metaarrays
|
||||
* list-of-lists [[1,2,3], [4,5,6]]
|
||||
* dict-of-lists {'x': [1,2,3], 'y': [4,5,6]}
|
||||
* list-of-dicts [{'x': 1, 'y': 4}, {'x': 2, 'y': 5}, ...]
|
||||
"""
|
||||
self.clear()
|
||||
self.appendData(data)
|
||||
self.resizeColumnsToContents()
|
||||
|
||||
@_defersort
|
||||
def appendData(self, data):
|
||||
"""
|
||||
Add new rows to the table.
|
||||
|
||||
See :func:`setData() <pyqtgraph.TableWidget.setData>` for accepted
|
||||
data types.
|
||||
"""
|
||||
startRow = self.rowCount()
|
||||
|
||||
fn0, header0 = self.iteratorFn(data)
|
||||
if fn0 is None:
|
||||
self.clear()
|
||||
return
|
||||
it0 = fn0(data)
|
||||
try:
|
||||
first = next(it0)
|
||||
except StopIteration:
|
||||
return
|
||||
fn1, header1 = self.iteratorFn(first)
|
||||
if fn1 is None:
|
||||
self.clear()
|
||||
return
|
||||
|
||||
firstVals = [x for x in fn1(first)]
|
||||
self.setColumnCount(len(firstVals))
|
||||
|
||||
if not self.verticalHeadersSet and header0 is not None:
|
||||
labels = [self.verticalHeaderItem(i).text() for i in range(self.rowCount())]
|
||||
self.setRowCount(startRow + len(header0))
|
||||
self.setVerticalHeaderLabels(labels + header0)
|
||||
self.verticalHeadersSet = True
|
||||
if not self.horizontalHeadersSet and header1 is not None:
|
||||
self.setHorizontalHeaderLabels(header1)
|
||||
self.horizontalHeadersSet = True
|
||||
|
||||
i = startRow
|
||||
self.setRow(i, firstVals)
|
||||
for row in it0:
|
||||
i += 1
|
||||
self.setRow(i, [x for x in fn1(row)])
|
||||
|
||||
if (
|
||||
self._sorting
|
||||
and self.horizontalHeadersSet
|
||||
and self.horizontalHeader().sortIndicatorSection() >= self.columnCount()
|
||||
):
|
||||
self.sortByColumn(0, QtCore.Qt.AscendingOrder)
|
||||
|
||||
def setEditable(self, editable=True):
|
||||
self.editable = editable
|
||||
for item in self.items:
|
||||
item.setEditable(editable)
|
||||
|
||||
def setFormat(self, format, column=None):
|
||||
"""
|
||||
Specify the default text formatting for the entire table, or for a
|
||||
single column if *column* is specified.
|
||||
|
||||
If a string is specified, it is used as a format string for converting
|
||||
float values (and all other types are converted using str). If a
|
||||
function is specified, it will be called with the item as its only
|
||||
argument and must return a string. Setting format = None causes the
|
||||
default formatter to be used instead.
|
||||
|
||||
Added in version 0.9.9.
|
||||
|
||||
"""
|
||||
if (
|
||||
format is not None
|
||||
and not isinstance(format, basestring)
|
||||
and not callable(format)
|
||||
):
|
||||
raise ValueError(
|
||||
"Format argument must string, callable, or None. (got %s)" % format
|
||||
)
|
||||
|
||||
self._formats[column] = format
|
||||
|
||||
if column is None:
|
||||
# update format of all items that do not have a column format
|
||||
# specified
|
||||
for c in range(self.columnCount()):
|
||||
if self._formats.get(c, None) is None:
|
||||
for r in range(self.rowCount()):
|
||||
item = self.item(r, c)
|
||||
if item is None:
|
||||
continue
|
||||
item.setFormat(format)
|
||||
else:
|
||||
# set all items in the column to use this format, or the default
|
||||
# table format if None was specified.
|
||||
if format is None:
|
||||
format = self._formats[None]
|
||||
for r in range(self.rowCount()):
|
||||
item = self.item(r, column)
|
||||
if item is None:
|
||||
continue
|
||||
item.setFormat(format)
|
||||
|
||||
def iteratorFn(self, data):
|
||||
## Return 1) a function that will provide an iterator for data and 2) a list of header strings
|
||||
if isinstance(data, list) or isinstance(data, tuple):
|
||||
return lambda d: d.__iter__(), None
|
||||
elif isinstance(data, dict):
|
||||
return lambda d: iter(d.values()), list(map(asUnicode, data.keys()))
|
||||
elif hasattr(data, "implements") and data.implements("MetaArray"):
|
||||
if data.axisHasColumns(0):
|
||||
header = [
|
||||
asUnicode(data.columnName(0, i)) for i in range(data.shape[0])
|
||||
]
|
||||
elif data.axisHasValues(0):
|
||||
header = list(map(asUnicode, data.xvals(0)))
|
||||
else:
|
||||
header = None
|
||||
return self.iterFirstAxis, header
|
||||
elif isinstance(data, np.ndarray):
|
||||
return self.iterFirstAxis, None
|
||||
elif isinstance(data, np.void):
|
||||
return self.iterate, list(map(asUnicode, data.dtype.names))
|
||||
elif data is None:
|
||||
return (None, None)
|
||||
elif np.isscalar(data):
|
||||
return self.iterateScalar, None
|
||||
else:
|
||||
msg = "Don't know how to iterate over data type: {!s}".format(type(data))
|
||||
raise TypeError(msg)
|
||||
|
||||
def iterFirstAxis(self, data):
|
||||
for i in range(data.shape[0]):
|
||||
yield data[i]
|
||||
|
||||
def iterate(self, data):
|
||||
# for numpy.void, which can be iterated but mysteriously
|
||||
# has no __iter__ (??)
|
||||
for x in data:
|
||||
yield x
|
||||
|
||||
def iterateScalar(self, data):
|
||||
yield data
|
||||
|
||||
def appendRow(self, data):
|
||||
self.appendData([data])
|
||||
|
||||
@_defersort
|
||||
def addRow(self, vals):
|
||||
row = self.rowCount()
|
||||
self.setRowCount(row + 1)
|
||||
self.setRow(row, vals)
|
||||
|
||||
@_defersort
|
||||
def setRow(self, row, vals):
|
||||
if row > self.rowCount() - 1:
|
||||
self.setRowCount(row + 1)
|
||||
for col in range(len(vals)):
|
||||
val = vals[col]
|
||||
item = self.itemClass(val, row)
|
||||
item.setEditable(self.editable)
|
||||
sortMode = self.sortModes.get(col, None)
|
||||
if sortMode is not None:
|
||||
item.setSortMode(sortMode)
|
||||
format = self._formats.get(col, self._formats[None])
|
||||
item.setFormat(format)
|
||||
self.items.append(item)
|
||||
self.setItem(row, col, item)
|
||||
item.setValue(val) # Required--the text-change callback is invoked
|
||||
# when we call setItem.
|
||||
|
||||
def setSortMode(self, column, mode):
|
||||
"""
|
||||
Set the mode used to sort *column*.
|
||||
|
||||
============== ========================================================
|
||||
**Sort Modes**
|
||||
value Compares item.value if available; falls back to text
|
||||
comparison.
|
||||
text Compares item.text()
|
||||
index Compares by the order in which items were inserted.
|
||||
============== ========================================================
|
||||
|
||||
Added in version 0.9.9
|
||||
"""
|
||||
for r in range(self.rowCount()):
|
||||
item = self.item(r, column)
|
||||
if hasattr(item, "setSortMode"):
|
||||
item.setSortMode(mode)
|
||||
self.sortModes[column] = mode
|
||||
|
||||
def sizeHint(self):
|
||||
# based on http://stackoverflow.com/a/7195443/54056
|
||||
width = sum(self.columnWidth(i) for i in range(self.columnCount()))
|
||||
width += self.verticalHeader().sizeHint().width()
|
||||
width += self.verticalScrollBar().sizeHint().width()
|
||||
width += self.frameWidth() * 2
|
||||
height = sum(self.rowHeight(i) for i in range(self.rowCount()))
|
||||
height += self.verticalHeader().sizeHint().height()
|
||||
height += self.horizontalScrollBar().sizeHint().height()
|
||||
return QtCore.QSize(width, height)
|
||||
|
||||
def serialize(self, useSelection=False):
|
||||
"""Convert entire table (or just selected area) into tab-separated text values"""
|
||||
if useSelection:
|
||||
selection = self.selectedRanges()[0]
|
||||
rows = list(range(selection.topRow(), selection.bottomRow() + 1))
|
||||
columns = list(range(selection.leftColumn(), selection.rightColumn() + 1))
|
||||
else:
|
||||
rows = list(range(self.rowCount()))
|
||||
columns = list(range(self.columnCount()))
|
||||
|
||||
data = []
|
||||
if self.horizontalHeadersSet:
|
||||
row = []
|
||||
if self.verticalHeadersSet:
|
||||
row.append(asUnicode(""))
|
||||
|
||||
for c in columns:
|
||||
row.append(asUnicode(self.horizontalHeaderItem(c).text()))
|
||||
data.append(row)
|
||||
|
||||
for r in rows:
|
||||
row = []
|
||||
if self.verticalHeadersSet:
|
||||
row.append(asUnicode(self.verticalHeaderItem(r).text()))
|
||||
for c in columns:
|
||||
item = self.item(r, c)
|
||||
if item is not None:
|
||||
row.append(asUnicode(item.value))
|
||||
else:
|
||||
row.append(asUnicode(""))
|
||||
data.append(row)
|
||||
|
||||
s = ""
|
||||
for row in data:
|
||||
s += "\t".join(row) + "\n"
|
||||
return s
|
||||
|
||||
def copySel(self):
|
||||
"""Copy selected data to clipboard."""
|
||||
QtGui.QApplication.clipboard().setText(self.serialize(useSelection=True))
|
||||
|
||||
def copyAll(self):
|
||||
"""Copy all data to clipboard."""
|
||||
QtGui.QApplication.clipboard().setText(self.serialize(useSelection=False))
|
||||
|
||||
def saveSel(self):
|
||||
"""Save selected data to file."""
|
||||
self.save(self.serialize(useSelection=True))
|
||||
|
||||
def saveAll(self):
|
||||
"""Save all data to file."""
|
||||
self.save(self.serialize(useSelection=False))
|
||||
|
||||
def save(self, data):
|
||||
fileName = QtGui.QFileDialog.getSaveFileName(
|
||||
self, "Save As..", "", "Tab-separated values (*.tsv)"
|
||||
)
|
||||
if fileName == "":
|
||||
return
|
||||
open(fileName, "w").write(data)
|
||||
|
||||
def contextMenuEvent(self, ev):
|
||||
self.contextMenu.popup(ev.globalPos())
|
||||
|
||||
def keyPressEvent(self, ev):
|
||||
if ev.key() == QtCore.Qt.Key_C and ev.modifiers() == QtCore.Qt.ControlModifier:
|
||||
ev.accept()
|
||||
self.copySel()
|
||||
else:
|
||||
QtGui.QTableWidget.keyPressEvent(self, ev)
|
||||
|
||||
def handleItemChanged(self, item):
|
||||
item.itemChanged()
|
||||
|
||||
|
||||
class TableWidgetItem(QtGui.QTableWidgetItem):
|
||||
def __init__(self, val, index, format=None):
|
||||
QtGui.QTableWidgetItem.__init__(self, "")
|
||||
self._blockValueChange = False
|
||||
self._format = None
|
||||
self._defaultFormat = "%0.3g"
|
||||
self.sortMode = "value"
|
||||
self.index = index
|
||||
flags = QtCore.Qt.ItemIsSelectable | QtCore.Qt.ItemIsEnabled
|
||||
self.setFlags(flags)
|
||||
self.setValue(val)
|
||||
self.setFormat(format)
|
||||
|
||||
def setEditable(self, editable):
|
||||
"""
|
||||
Set whether this item is user-editable.
|
||||
"""
|
||||
if editable:
|
||||
self.setFlags(self.flags() | QtCore.Qt.ItemIsEditable)
|
||||
else:
|
||||
self.setFlags(self.flags() & ~QtCore.Qt.ItemIsEditable)
|
||||
|
||||
def setSortMode(self, mode):
|
||||
"""
|
||||
Set the mode used to sort this item against others in its column.
|
||||
|
||||
============== ========================================================
|
||||
**Sort Modes**
|
||||
value Compares item.value if available; falls back to text
|
||||
comparison.
|
||||
text Compares item.text()
|
||||
index Compares by the order in which items were inserted.
|
||||
============== ========================================================
|
||||
"""
|
||||
modes = ("value", "text", "index", None)
|
||||
if mode not in modes:
|
||||
raise ValueError("Sort mode must be one of %s" % str(modes))
|
||||
self.sortMode = mode
|
||||
|
||||
def setFormat(self, fmt):
|
||||
"""Define the conversion from item value to displayed text.
|
||||
|
||||
If a string is specified, it is used as a format string for converting
|
||||
float values (and all other types are converted using str). If a
|
||||
function is specified, it will be called with the item as its only
|
||||
argument and must return a string.
|
||||
|
||||
Added in version 0.9.9.
|
||||
"""
|
||||
if fmt is not None and not isinstance(fmt, basestring) and not callable(fmt):
|
||||
raise ValueError(
|
||||
"Format argument must string, callable, or None. (got %s)" % fmt
|
||||
)
|
||||
self._format = fmt
|
||||
self._updateText()
|
||||
|
||||
def _updateText(self):
|
||||
self._blockValueChange = True
|
||||
try:
|
||||
self._text = self.format()
|
||||
self.setText(self._text)
|
||||
finally:
|
||||
self._blockValueChange = False
|
||||
|
||||
def setValue(self, value):
|
||||
self.value = value
|
||||
self._updateText()
|
||||
|
||||
def itemChanged(self):
|
||||
"""Called when the data of this item has changed."""
|
||||
if self.text() != self._text:
|
||||
self.textChanged()
|
||||
|
||||
def textChanged(self):
|
||||
"""Called when this item's text has changed for any reason."""
|
||||
self._text = self.text()
|
||||
|
||||
if self._blockValueChange:
|
||||
# text change was result of value or format change; do not
|
||||
# propagate.
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
self.value = type(self.value)(self.text())
|
||||
except ValueError:
|
||||
self.value = str(self.text())
|
||||
|
||||
def format(self):
|
||||
if callable(self._format):
|
||||
return self._format(self)
|
||||
if isinstance(self.value, (float, np.floating)):
|
||||
if self._format is None:
|
||||
return self._defaultFormat % self.value
|
||||
else:
|
||||
return self._format % self.value
|
||||
else:
|
||||
return asUnicode(self.value)
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.sortMode == "index" and hasattr(other, "index"):
|
||||
return self.index < other.index
|
||||
if self.sortMode == "value" and hasattr(other, "value"):
|
||||
return self.value < other.value
|
||||
else:
|
||||
return self.text() < other.text()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QtGui.QApplication([])
|
||||
win = QtGui.QMainWindow()
|
||||
t = TableWidget()
|
||||
win.setCentralWidget(t)
|
||||
win.resize(800, 600)
|
||||
win.show()
|
||||
|
||||
ll = [[1, 2, 3, 4, 5]] * 20
|
||||
ld = [{"x": 1, "y": 2, "z": 3}] * 20
|
||||
dl = {"x": list(range(20)), "y": list(range(20)), "z": list(range(20))}
|
||||
|
||||
a = np.ones((20, 5))
|
||||
ra = np.ones((20,), dtype=[("x", int), ("y", int), ("z", int)])
|
||||
|
||||
t.setData(ll)
|
||||
|
||||
ma = metaarray.MetaArray(
|
||||
np.ones((20, 3)),
|
||||
info=[
|
||||
{"values": np.linspace(1, 5, 20)},
|
||||
{"cols": [{"name": "x"}, {"name": "y"}, {"name": "z"}]},
|
||||
],
|
||||
)
|
||||
t.setData(ma)
|
||||
1
cnmodel/util/difftreewidget/__init__.py
Normal file
1
cnmodel/util/difftreewidget/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .DiffTreeWidget import DiffTreeWidget
|
||||
83
cnmodel/util/expfitting.py
Executable file
83
cnmodel/util/expfitting.py
Executable file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python
|
||||
# encoding: utf-8
|
||||
"""
|
||||
expfitting.py
|
||||
Provide single or double exponential fits to data.
|
||||
"""
|
||||
|
||||
import lmfit
|
||||
import numpy as np
|
||||
import scipy.optimize
|
||||
|
||||
|
||||
class ExpFitting:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
nexp : int
|
||||
1 or 2 for single or double exponential fit
|
||||
initpars : dict
|
||||
dict of initial parameters. For example: {'dc': 0.,
|
||||
'a1': 1., 't1': 3, 'a2' : 0.5, 'delta': 3.}, where
|
||||
delta determines the ratio between the time constants.
|
||||
bounds : dict
|
||||
dictionary of bounds for each parameter, with a list of lower and upper values.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, nexp=1, initpars=None, bounds=None):
|
||||
self.fitpars = lmfit.Parameters()
|
||||
if nexp == 1:
|
||||
# (Name, Value, Vary, Min, Max, Expr)
|
||||
self.fitpars.add_many(
|
||||
("dc", 0, True, -100.0, 0.0, None),
|
||||
("a1", 1.0, True, -25.0, 25.0, None),
|
||||
("t1", 10.0, True, 0.1, 50, None),
|
||||
)
|
||||
self.efunc = self.exp1_err
|
||||
elif nexp == 2:
|
||||
self.fitpars.add_many(
|
||||
("dc", 0, True, -100.0, 0.0, None),
|
||||
("a1", 1.0, True, 0.0, 25.0, None),
|
||||
("t1", 10.0, True, 0.1, 50, None),
|
||||
("a2", 1.0, True, 0.0, 25.0, None),
|
||||
("delta", 3.0, True, 3.0, 100.0, None),
|
||||
)
|
||||
if initpars is not None:
|
||||
assert len(initpars) == 5
|
||||
for k, v in initpars.iteritems():
|
||||
self.fitpars[k].value = v
|
||||
if bounds is not None:
|
||||
assert len(bounds) == 5
|
||||
for k, v in bounds.iteritems():
|
||||
self.fitpars[k].min = v[0]
|
||||
self.fitpars[k].max = v[1]
|
||||
|
||||
self.efunc = self.exp2_err
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def fit(self, x, y, p, verbose=False):
|
||||
|
||||
kws = {"maxfev": 5000}
|
||||
mim = lmfit.minimize(
|
||||
self.efunc, p, method="least_squares", args=(x, y)
|
||||
) # , kws=kws)
|
||||
if verbose:
|
||||
lmfit.printfuncs.report_fit(mim.params)
|
||||
fitpars = mim.params
|
||||
return fitpars
|
||||
|
||||
@staticmethod
|
||||
def exp1(x, dc, t1, a1):
|
||||
return dc + a1 * np.exp(-x / t1)
|
||||
|
||||
def exp1_err(self, p, x, y):
|
||||
return np.fabs(y - self.exp1(x, **dict([(k, p.value) for k, p in p.items()])))
|
||||
|
||||
@staticmethod
|
||||
def exp2(x, dc, t1, a1, a2, delta):
|
||||
return dc + a1 * np.exp(-x / t1) + a2 * np.exp(-x / (t1 * delta))
|
||||
|
||||
def exp2_err(self, p, x, y):
|
||||
return np.fabs(y - self.exp2(x, **dict([(k, p.value) for k, p in p.items()])))
|
||||
133
cnmodel/util/filelock.py
Normal file
133
cnmodel/util/filelock.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2009, Evan Fosmark
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
# The views and conclusions contained in the software and documentation are those
|
||||
# of the authors and should not be interpreted as representing official policies,
|
||||
# either expressed or implied, of the FreeBSD Project.
|
||||
|
||||
import os
|
||||
import time
|
||||
import errno
|
||||
|
||||
|
||||
class FileLockException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FileLock(object):
|
||||
""" A file locking mechanism that has context-manager support so
|
||||
you can use it in a with statement. This should be relatively cross
|
||||
compatible as it doesn't rely on msvcrt or fcntl for the locking.
|
||||
"""
|
||||
|
||||
lock_count = {} # lock count for each locked file
|
||||
lock_handles = {} # file handle for each locked file
|
||||
|
||||
def __init__(self, file_name, timeout=10, delay=0.05):
|
||||
""" Prepare the file locker. Specify the file to lock and optionally
|
||||
the maximum timeout and the delay between each attempt to lock.
|
||||
"""
|
||||
self.fd = None
|
||||
self.is_locked = False
|
||||
self.lockfile = os.path.join(os.getcwd(), "%s.lock" % file_name)
|
||||
self.file_name = file_name
|
||||
self.timeout = timeout
|
||||
self.delay = delay
|
||||
|
||||
def acquire(self):
|
||||
""" Acquire the lock, if possible. If the lock is in use, it check again
|
||||
every `wait` seconds. It does this until it either gets the lock or
|
||||
exceeds `timeout` number of seconds, in which case it throws
|
||||
an exception.
|
||||
"""
|
||||
# Don't try to lock the same file more than once
|
||||
if FileLock.lock_count.setdefault(self.lockfile, 0) > 0:
|
||||
self.is_locked = True
|
||||
self.fd = FileLock.lock_handles[self.lockfile]
|
||||
FileLock.lock_count[self.lockfile] += 1
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
# create the cache directory if it does not exist (new installations will not have a cache directory)
|
||||
stimdir = os.path.dirname(self.lockfile)
|
||||
cachedir = os.path.dirname(stimdir)
|
||||
if not os.path.isdir(cachedir): # make sure the cache exists
|
||||
os.mkdir(cachedir)
|
||||
if not os.path.isdir(stimdir): # and the specific stimulus dir
|
||||
os.mkdir(stimdir)
|
||||
self.fd = os.open(self.lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
|
||||
open(self.lockfile, "w").write(str(os.getpid()))
|
||||
break
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
if (time.time() - start_time) >= self.timeout:
|
||||
try:
|
||||
pid = open(self.lockfile).read()
|
||||
except Exception:
|
||||
pid = "[error reading lockfile: %s]" % sys.exc_info()[0]
|
||||
raise FileLockException(
|
||||
"Timeout occured. (%s is locked by pid %s)"
|
||||
% (self.lockfile, pid)
|
||||
)
|
||||
time.sleep(self.delay)
|
||||
|
||||
self.is_locked = True
|
||||
FileLock.lock_count[self.lockfile] += 1
|
||||
FileLock.lock_handles[self.lockfile] = self.fd
|
||||
|
||||
def release(self):
|
||||
""" Get rid of the lock by deleting the lockfile.
|
||||
When working in a `with` statement, this gets automatically
|
||||
called at the end.
|
||||
"""
|
||||
if self.is_locked:
|
||||
self.is_locked = False
|
||||
FileLock.lock_count[self.lockfile] -= 1
|
||||
if FileLock.lock_count[self.lockfile] == 0:
|
||||
os.close(self.fd)
|
||||
os.unlink(self.lockfile)
|
||||
del FileLock.lock_handles[self.lockfile]
|
||||
|
||||
def __enter__(self):
|
||||
""" Activated when used in the with statement.
|
||||
Should automatically acquire a lock to be used in the with block.
|
||||
"""
|
||||
if not self.is_locked:
|
||||
self.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
""" Activated at the end of the with statement.
|
||||
It automatically releases the lock if it isn't locked.
|
||||
"""
|
||||
if self.is_locked:
|
||||
self.release()
|
||||
|
||||
def __del__(self):
|
||||
""" Make sure that the FileLock instance doesn't leave a lockfile
|
||||
lying around.
|
||||
"""
|
||||
self.release()
|
||||
95
cnmodel/util/find_point.py
Normal file
95
cnmodel/util/find_point.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from __future__ import print_function
|
||||
from scipy import interpolate
|
||||
import numpy as np
|
||||
|
||||
|
||||
def find_point(x, y, peakindex, val, direction="left", limits=None):
|
||||
"""
|
||||
Given a waveform defined by *x* and *y* arrays, return the first time
|
||||
at which the waveform crosses (y[peakindex] * val). The search begins at
|
||||
*peakindex* and proceeds in *direction*.
|
||||
|
||||
Optionally, *limits* may specify a smaller search region in the form of
|
||||
(t0, t1, dt).
|
||||
"""
|
||||
# F = interpolate.UnivariateSpline(x, y, s=0) # declare function
|
||||
# To find x at y then do:
|
||||
istart = 0
|
||||
iend = len(y)
|
||||
if limits is not None:
|
||||
istart = int(limits[0] / limits[2])
|
||||
iend = int(limits[1] / limits[2])
|
||||
yToFind = y[peakindex] * val
|
||||
if direction == "left":
|
||||
yreduced = np.array(y[istart:peakindex]) - yToFind
|
||||
try:
|
||||
Fr = interpolate.UnivariateSpline(x[istart:peakindex], yreduced, s=0)
|
||||
except:
|
||||
print("find_point: insufficient time points for analysis")
|
||||
print("arg lengths:", len(x[istart:peakindex]), len(yreduced))
|
||||
print("istart, peakindex: ", istart, peakindex)
|
||||
print("ytofine: ", yToFind)
|
||||
raise
|
||||
res = float("nan")
|
||||
return res
|
||||
res = Fr.roots()
|
||||
if len(res) > 1:
|
||||
res = res[-1]
|
||||
else:
|
||||
yreduced = np.array(y[peakindex:iend]) - yToFind
|
||||
try:
|
||||
Fr = interpolate.UnivariateSpline(x[peakindex:iend], yreduced, s=0)
|
||||
except:
|
||||
print("find_point: insufficient time points for analysis?")
|
||||
print("arg lengths:", len(x[peakindex:iend]), len(yreduced))
|
||||
raise
|
||||
res = float("nan")
|
||||
return res
|
||||
res = Fr.roots()
|
||||
if len(res) > 1:
|
||||
res = res[0]
|
||||
# pdb.set_trace()
|
||||
try:
|
||||
res.pop()
|
||||
except:
|
||||
pass
|
||||
if not res: # tricky - an empty list is False, but does not evaluate to False
|
||||
res = float("nan") # replace with a NaN
|
||||
else:
|
||||
res = float(res) # make sure is just a simple number (no arrays)
|
||||
return res
|
||||
|
||||
|
||||
def find_crossing(data, start=0, direction=1, threshold=0):
|
||||
"""Return the index at which *data* crosses *threshold*, starting
|
||||
at index *start* and proceeding in *direction* (+/-1).
|
||||
|
||||
The value returned is a float indicating the interpolated index
|
||||
position where the data crosses threshold, or NaN if the threshold was
|
||||
never crossed.
|
||||
"""
|
||||
|
||||
# Note: this function is very similar in purpose to find_point, but was
|
||||
# added due to issues with interpolate.UnivariateSpline
|
||||
|
||||
assert direction in (1, -1)
|
||||
cross_rising = data[start] < threshold
|
||||
|
||||
def test(x):
|
||||
if cross_rising:
|
||||
return x > threshold
|
||||
else:
|
||||
return x < threshold
|
||||
|
||||
while True:
|
||||
next_ind = start + direction
|
||||
if next_ind < 0 or next_ind >= len(data):
|
||||
return np.nan
|
||||
|
||||
if test(data[next_ind]):
|
||||
# crossed; return interpolated value
|
||||
s1 = data[next_ind] - threshold
|
||||
s2 = threshold = data[start]
|
||||
return (next_ind * s2 + start * s1) / (s2 + s1)
|
||||
|
||||
start = next_ind
|
||||
219
cnmodel/util/fitting.py
Normal file
219
cnmodel/util/fitting.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import numpy as np
|
||||
import lmfit
|
||||
import pyqtgraph as pg
|
||||
from pyqtgraph.Qt import QtCore, QtGui
|
||||
|
||||
|
||||
class FitModel(lmfit.Model):
|
||||
""" Simple extension of lmfit.Model that allows one-line fitting.
|
||||
|
||||
Example uses:
|
||||
|
||||
# single exponential fit::
|
||||
|
||||
fit = expfitting.Exp1.fit(data,
|
||||
x=time_vals,
|
||||
xoffset=(0, 'fixed'),
|
||||
yoffset=(yoff_guess, -120, 0),
|
||||
amp=(amp_guess, 0, 50),
|
||||
tau=(tau_guess, 0.1, 50))
|
||||
|
||||
# plot the fit::
|
||||
|
||||
fit_curve = fit.eval()
|
||||
plot(time_vals, fit_curve)
|
||||
|
||||
|
||||
# double exponential fit with tau ratio constraint
|
||||
# note that 'tau_ratio' does not appear in the exp2 model;
|
||||
# we can define new parameters here.::
|
||||
|
||||
fit = expfitting.Exp2.fit(data,
|
||||
x=time_vals,
|
||||
xoffset=(0, 'fixed'),
|
||||
yoffset=(yoff_guess, -120, 0),
|
||||
amp1=(amp_guess, 0, 50),
|
||||
tau1=(tau_guess, 0.1, 50),
|
||||
amp2=(-0.5, -50, 0),
|
||||
tau_ratio=(10, 3, 50),
|
||||
tau2='tau1 * tau_ratio'
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def fit(self, data, interactive=False, **params):
|
||||
""" Return a fit of data to this model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array
|
||||
dependent data to fit
|
||||
interactive : bool
|
||||
If True, show a GUI used for interactively exploring fit parameters
|
||||
|
||||
Extra keyword arguments are passed to make_params() if they are model
|
||||
parameter names, or passed directly to Model.fit() for independent
|
||||
variable names.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fit of data to model as an lmfit object
|
||||
"""
|
||||
|
||||
fit_params = {}
|
||||
model_params = {}
|
||||
for k, v in params.items():
|
||||
if k in self.independent_vars or k in [
|
||||
"weights",
|
||||
"method",
|
||||
"scale_covar",
|
||||
"iter_cb",
|
||||
]:
|
||||
fit_params[k] = v
|
||||
else:
|
||||
model_params[k] = v
|
||||
p = self.make_params(**model_params)
|
||||
# print ('params: ', p)
|
||||
# print ('fitparams: ', fit_params)
|
||||
# import matplotlib.pyplot as mpl
|
||||
# mpl.plot(data)
|
||||
# mpl.show()
|
||||
|
||||
fit = lmfit.Model.fit(self, data, params=p, **fit_params)
|
||||
if interactive:
|
||||
self.show_interactive(fit)
|
||||
return fit
|
||||
|
||||
def make_params(self, **params):
|
||||
"""
|
||||
Make parameters used for fitting with this model.
|
||||
|
||||
Keyword arguments are used to generate parameters for the fit. Each
|
||||
parameter may be specified by the following formats:
|
||||
|
||||
param=value :
|
||||
The initial value of the parameter
|
||||
param=(value, 'fixed') :
|
||||
Fixed value for the parameter
|
||||
param=(value, min, max) :
|
||||
Initial value and min, max values, which may be float or None
|
||||
param='expression' :
|
||||
Expression used to compute parameter value. See:
|
||||
http://lmfit.github.io/lmfit-py/constraints.html#constraints-chapter
|
||||
"""
|
||||
p = lmfit.Parameters()
|
||||
for k in self.param_names:
|
||||
p.add(k)
|
||||
|
||||
for param, val in params.items():
|
||||
if param not in p:
|
||||
p.add(param)
|
||||
|
||||
if isinstance(val, str):
|
||||
p[param].expr = val
|
||||
elif np.isscalar(val):
|
||||
p[param].value = val
|
||||
elif isinstance(val, tuple):
|
||||
if len(val) == 2:
|
||||
assert val[1] == "fixed"
|
||||
p[param].value = val[0]
|
||||
p[param].vary = False
|
||||
elif len(val) == 3:
|
||||
p[param].value = val[0]
|
||||
p[param].min = val[1]
|
||||
p[param].max = val[2]
|
||||
else:
|
||||
raise TypeError(
|
||||
"Tuple parameter specifications must be (val, 'fixed')"
|
||||
" or (val, min, max)."
|
||||
)
|
||||
else:
|
||||
raise TypeError("Invalid parameter specification: %r" % val)
|
||||
|
||||
# set initial values for parameters with mathematical constraints
|
||||
# this is to allow fit.eval(**fit.init_params)
|
||||
global_ns = np.__dict__
|
||||
for param, val in params.items():
|
||||
if isinstance(val, str):
|
||||
p[param].value = eval(val, global_ns, p.valuesdict())
|
||||
return p
|
||||
|
||||
def show_interactive(self, fit=None):
|
||||
""" Show an interactive GUI for exploring fit parameters.
|
||||
"""
|
||||
if not hasattr(self, "_interactive_win"):
|
||||
self._interactive_win = FitExplorer(model=self, fit=fit)
|
||||
self._interactive_win.show()
|
||||
|
||||
|
||||
def exp1(x, xoffset, yoffset, tau, amp):
|
||||
return yoffset + amp * np.exp(-(x - xoffset) / tau)
|
||||
|
||||
|
||||
class Exp1(FitModel):
|
||||
""" Single exponential fitting model.
|
||||
|
||||
Parameters are xoffset, yoffset, amp, and tau.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
FitModel.__init__(self, exp1, independent_vars=["x"])
|
||||
|
||||
|
||||
def exp2(x, xoffset, yoffset, tau1, amp1, tau2, amp2):
|
||||
xoff = x - xoffset
|
||||
return yoffset + amp1 * np.exp(-xoff / tau1) + amp2 * np.exp(-xoff / tau2)
|
||||
|
||||
|
||||
class Exp2(FitModel):
|
||||
""" Double exponential fitting model.
|
||||
|
||||
Parameters are xoffset, yoffset, amp1, tau1, amp2, and tau2.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
FitModel.__init__(self, exp2, independent_vars=["x"])
|
||||
|
||||
|
||||
class FitExplorer(QtGui.QWidget):
|
||||
def __init__(self, model, fit):
|
||||
QtGui.QWidget.__init__(self)
|
||||
self.model = model
|
||||
self.fit = fit
|
||||
self.layout = QtGui.QGridLayout()
|
||||
self.setLayout(self.layout)
|
||||
self.splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
|
||||
self.layout.addWidget(self.splitter)
|
||||
self.ptree = pg.parametertree.ParameterTree()
|
||||
self.splitter.addWidget(self.ptree)
|
||||
self.plot = pg.PlotWidget()
|
||||
self.splitter.addWidget(self.plot)
|
||||
|
||||
self.params = pg.parametertree.Parameter.create(
|
||||
name="param_root",
|
||||
type="group",
|
||||
children=[
|
||||
dict(name="fit", type="action"),
|
||||
dict(name="parameters", type="group"),
|
||||
],
|
||||
)
|
||||
|
||||
for k in fit.params:
|
||||
p = pg.parametertree.Parameter.create(
|
||||
name=k, type="float", value=fit.params[k].value
|
||||
)
|
||||
self.params.param("parameters").addChild(p)
|
||||
|
||||
self.ptree.setParameters(self.params)
|
||||
|
||||
self.update_plots()
|
||||
|
||||
self.params.param("parameters").sigTreeStateChanged.connect(self.update_plots)
|
||||
|
||||
def update_plots(self):
|
||||
for k in self.fit.params:
|
||||
self.fit.params[k].value = self.params["parameters", k]
|
||||
|
||||
self.plot.clear()
|
||||
self.plot.plot(self.fit.data)
|
||||
self.plot.plot(self.fit.eval(), pen="y")
|
||||
477
cnmodel/util/get_anspikes.py
Normal file
477
cnmodel/util/get_anspikes.py
Normal file
@@ -0,0 +1,477 @@
|
||||
from __future__ import print_function
|
||||
|
||||
__author__ = "pbmanis"
|
||||
"""
|
||||
ManageANSpikes is a class to read the output of the Zilany et al. 2009 AN model into
|
||||
python, and provides services to access that data.
|
||||
|
||||
Basic usage is to create an instance of the class, and specify the data directory
|
||||
if necessary.
|
||||
|
||||
You may then get the data in the format of a list using one of the "get" routines.
|
||||
The data is pulled from the nearest CF or the specified CF.
|
||||
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
import scipy.io
|
||||
import matplotlib.pyplot as MP
|
||||
|
||||
|
||||
class ManageANSpikes:
|
||||
"""
|
||||
ManageANSpikes is a class to read the output of the Zilany et al. 2009 AN model into
|
||||
python, and provides services to access that data.
|
||||
|
||||
Basic usage is to create an instance of the class, and specify the data directory
|
||||
if necessary.
|
||||
|
||||
You may then get the data in the format of a list using one of the "get" routines.
|
||||
The data is pulled from the nearest CF or the specified CF.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# self.datadir = environment['HOME'] + '/Desktop/Matlab/ZilanyCarney-JASAcode-2009/'
|
||||
self.data_dir = os.environ["HOME"] + "/Desktop/Matlab/ANData/"
|
||||
self.data_read_flag = False
|
||||
self.dataType = "RI"
|
||||
self.set_CF_map(4000, 38000, 25) # the default list
|
||||
self.all_AN = None
|
||||
|
||||
def get_data_dir(self):
|
||||
return self.data_dir
|
||||
|
||||
def set_data_dir(self, directory):
|
||||
if os.path.isdir(directory):
|
||||
self.data_dir = directory
|
||||
else:
|
||||
raise ValueError("ManageANSpikes.set_data_dir: Path %d is not a directory")
|
||||
|
||||
def get_data_read(self):
|
||||
return self.data_read_flag
|
||||
|
||||
def get_CF_map(self):
|
||||
return self.CF_map
|
||||
|
||||
def set_CF_map(self, low, high, nfreq):
|
||||
self.CF_map = np.round(np.logspace(np.log10(low), np.log10(high), nfreq))
|
||||
|
||||
def get_dataType(self):
|
||||
return self.dataType
|
||||
|
||||
def set_dataType(self, datatype):
|
||||
if datatype in ["RI", "PL"]: # only for recognized data types
|
||||
self.dataType = datatype
|
||||
else:
|
||||
raise ValueError(
|
||||
"get_anspikes.set_dataType: unrecognized type %s " % datatype
|
||||
)
|
||||
|
||||
def plot_RI_vs_F(self, freq=10000.0, spontclass="HS", display=False):
|
||||
cfd = {}
|
||||
if display:
|
||||
MP.figure(10)
|
||||
for i, fr in enumerate(self.CF_map):
|
||||
self.read_AN_data(freq=freq, CF=fr, spontclass=spontclass, ignoreflag=True)
|
||||
nsp = np.zeros(len(self.SPLs)) # this is the same regardless
|
||||
for i, db in enumerate(self.SPLs):
|
||||
spkl = self.combine_reps(self.spikelist[i])
|
||||
nsp[i] = len(spkl) # /self.SPLs[i]
|
||||
if display:
|
||||
MP.plot(self.SPLs, nsp)
|
||||
|
||||
def plot_Rate_vs_F(self, freq=10000.0, spontclass="HS", display=False, SPL=0.0):
|
||||
"""
|
||||
assumes we have done a "read all"
|
||||
|
||||
:param freq:
|
||||
:param spontclass:
|
||||
:param display:
|
||||
:param SPL:
|
||||
:return:
|
||||
"""
|
||||
cfd = {}
|
||||
if display:
|
||||
MP.figure(10)
|
||||
for i, db in enumerate(self.SPLs): # plot will have lines at each SPL
|
||||
# retrieve_from_all_AN(self, cf, SPL) # self.read_AN_data(freq=freq, CF=fr, spontclass=spontclass)
|
||||
nsp = np.zeros(len(self.CF_map)) # this is the same regardless
|
||||
for j, fr in enumerate(self.CF_map):
|
||||
spikelist = self.retrieve_from_all_AN(fr, db)
|
||||
spkl = self.combine_reps(spikelist)
|
||||
nsp[j] = len(spkl) / len(spikelist)
|
||||
if display:
|
||||
MP.plot(self.CF_map, nsp)
|
||||
|
||||
def getANatFandSPL(self, spontclass="MS", freq=10000.0, CF=None, SPL=0):
|
||||
"""
|
||||
Get the AN data at a particular frequency and SPL. Note that the tone freq is specified,
|
||||
but the CF of the fiber might be different. If CF is None, we try to get the closest
|
||||
fiber data to the tone Freq.
|
||||
Other arguments are the spontaneous rate and the SPL level.
|
||||
The data must exist, or we epically fail.
|
||||
|
||||
:param spontclass: 'HS', 'MS', or 'LS' for high, middle or low spont groups (1,2,3 in Zilany et al model)
|
||||
:param freq: The stimulus frequency, in Hz
|
||||
:param CF: The 'characteristic frequency' of the desired AN fiber, in Hz
|
||||
:param SPL: The sound pressure level, in dB SPL for the stimulus
|
||||
:return: an array of nReps of spike times.
|
||||
"""
|
||||
|
||||
if CF is None:
|
||||
closest = np.argmin(
|
||||
np.abs(self.CF_map - freq)
|
||||
) # find closest to the stim freq
|
||||
else:
|
||||
closest = np.argmin(
|
||||
np.abs(self.CF_map - CF)
|
||||
) # find closest to the stim freq
|
||||
CF = self.CF_map[closest]
|
||||
print("closest f: ", CF)
|
||||
self.read_AN_data(spontclass=spontclass, freq=10000.0, CF=CF)
|
||||
return self.get_AN_at_SPL(SPL)
|
||||
|
||||
def get_AN_at_SPL(self, spl=None):
|
||||
"""
|
||||
grabs the AN data for the requested SPL for the data set currently loaded
|
||||
The data must exist, or we epically fail.
|
||||
|
||||
:param spl: sound pressure level, in dB SPL
|
||||
:return: spike trains (times) as a list of nReps numpy arrays.
|
||||
"""
|
||||
|
||||
if not self.data_read_flag:
|
||||
print("getANatSPL: No data read yet")
|
||||
return None
|
||||
try:
|
||||
k = int(np.where(self.SPLs == spl)[0])
|
||||
except (ValueError, TypeError) as e:
|
||||
print(
|
||||
"get_anspikes::getANatSPL: spl=%6.1f not in list of spls:" % (spl),
|
||||
self.SPLs,
|
||||
)
|
||||
exit() # no match
|
||||
# now clean up array to keep it iterable upon return
|
||||
spkl = [[]] * len(self.spikelist[k])
|
||||
for i in xrange(len(self.spikelist[k])):
|
||||
try:
|
||||
len(self.spikelist[k][i])
|
||||
spkl[i] = self.spikelist[k][i]
|
||||
except:
|
||||
spkl[i] = [np.array(self.spikelist[k][i])]
|
||||
return spkl # returns all the trials in the list
|
||||
|
||||
def read_all_ANdata(
|
||||
self, freq=10000.0, CFList=None, spontclass="HS", stim="BFTone"
|
||||
):
|
||||
"""
|
||||
Reads a bank of AN data, across frequency and intensity, for a given stimulus frequency
|
||||
Assumptions: the bank of data is consistent in terms of nReps and SPLs
|
||||
|
||||
:param freq: Tone stimulus frequency (if stim is Tone or BFTone)
|
||||
:param CFList: A list of the CFs to read
|
||||
:param spontclass: 'HS', 'MS', or 'LS', corresponding to spont rate groups
|
||||
:param stim: Stimulus type - Tone/BFTone, or Noise
|
||||
:return: Nothing. The data are stored in an array accessible directly or through a selector function
|
||||
"""
|
||||
|
||||
self.all_AN = OrderedDict(
|
||||
[(f, None) for f in CFList]
|
||||
) # access through dictionary with keys as freq
|
||||
for cf in CFList:
|
||||
self.read_AN_data(
|
||||
freq=freq, CF=cf, spontclass=spontclass, stim=stim, setflag=False
|
||||
)
|
||||
self.all_AN[cf] = self.spikelist
|
||||
self.data_read_flag = True # inform and block
|
||||
|
||||
def retrieve_from_all_AN(self, cf, SPL):
|
||||
"""
|
||||
:param cf:
|
||||
:param SPL:
|
||||
:return: spike list (all nreps) at the cf and SPL requested, for the loaded stimulus set
|
||||
|
||||
"""
|
||||
|
||||
if not self.data_read_flag or self.all_AN is None:
|
||||
print("get_anspikes::retrieve_from_all_AN: No data read yet: ")
|
||||
exit()
|
||||
|
||||
ispl = int(np.where(self.SPLs == float(SPL))[0])
|
||||
icf = self.all_AN.keys().index(cf)
|
||||
spikelist = self.all_AN[cf][ispl]
|
||||
spkl = [[]] * len(spikelist)
|
||||
# print len(spikelist)
|
||||
for i in xrange(len(spikelist)): # across trials
|
||||
try:
|
||||
len(spikelist[i])
|
||||
spkl[i] = spikelist[i]
|
||||
except:
|
||||
spkl[i] = [np.array(spikelist[i])]
|
||||
# print 'cf: %f spl: %d nspk: %f' % (cf, SPL, len(spkl[i]))
|
||||
return spikelist
|
||||
|
||||
def read_AN_data(
|
||||
self,
|
||||
freq=10000,
|
||||
CF=5300,
|
||||
spontclass="HS",
|
||||
display=False,
|
||||
stim="BFTone",
|
||||
setflag=True,
|
||||
ignoreflag=True,
|
||||
):
|
||||
"""
|
||||
read responses of auditory nerve model of Zilany et al. (2009).
|
||||
display = True plots the psth's for all ANF channels in the dataset
|
||||
Request response to stimulus at Freq, for fiber with CF
|
||||
and specified spont rate
|
||||
This version is for rate-intensity runs March 2014.
|
||||
Returns: Nothing
|
||||
"""
|
||||
|
||||
if not ignoreflag:
|
||||
assert self.data_read_flag == False
|
||||
# print 'Each instance of ManageANSPikes is allowed ONE data set to manage'
|
||||
# print 'This is a design decision to avoid confusion about which data is in the instance'
|
||||
# exit()
|
||||
if stim in ["BFTone", "Tone"]:
|
||||
fname = "%s_F%06.3f_CF%06.3f_%2s.mat" % (
|
||||
self.dataType,
|
||||
freq / 1000.0,
|
||||
CF / 1000.0,
|
||||
spontclass,
|
||||
)
|
||||
elif stim == "Noise":
|
||||
fname = "%s_Noise_CF%06.3f_%2s.mat" % (
|
||||
self.dataType,
|
||||
CF / 1000.0,
|
||||
spontclass,
|
||||
)
|
||||
|
||||
# print 'Reading: %s' % (fname)
|
||||
try:
|
||||
mfile = scipy.io.loadmat(
|
||||
os.path.join(self.data_dir, fname), squeeze_me=True
|
||||
)
|
||||
except IOError:
|
||||
print("get_anspikes::read_AN_data: Failed to find data file %s" % (fname))
|
||||
print(
|
||||
"Corresponding to Freq: %f CF: %f spontaneous rate class: %s"
|
||||
% (freq, CF, spontclass)
|
||||
)
|
||||
exit()
|
||||
|
||||
n_spl = len(mfile[self.dataType])
|
||||
if display:
|
||||
n = int(np.sqrt(n_spl)) + 1
|
||||
mg = np.meshgrid(range(n), range(n))
|
||||
mg = [mg[0].flatten(), mg[1].flatten()]
|
||||
spkl = [[]] * n_spl
|
||||
spl = np.zeros(n_spl)
|
||||
for k in range(n_spl):
|
||||
spkl[k] = mfile[self.dataType]["data"][k] # get data for one SPL
|
||||
spl[k] = mfile[self.dataType]["SPL"][k] # get SPL for this set of runs
|
||||
# print spkl[k].shape
|
||||
if display:
|
||||
self.display(spkl[k], k, n, mg)
|
||||
|
||||
if display:
|
||||
MP.show()
|
||||
self.spikelist = spkl # save these so we can have a single point to parse them
|
||||
self.SPLs = spl
|
||||
self.n_reps = mfile[self.dataType]["nrep"]
|
||||
if setflag:
|
||||
self.data_read_flag = True
|
||||
# return spkl, spl, mfile['RI']['nrep']
|
||||
|
||||
def getANatSPL(self, spl=None):
|
||||
if not self.dataRead:
|
||||
print("getANatSPL: No data read yet")
|
||||
return None
|
||||
try:
|
||||
k = int(np.where(self.SPLs == spl)[0])
|
||||
except (ValueError, TypeError) as e:
|
||||
print(
|
||||
"get_anspikes::getANatSPL: spl=%6.1f not in list of spls:" % (spl),
|
||||
self.SPLs,
|
||||
)
|
||||
exit() # no match
|
||||
# now clean up array to keep it iterable upon return
|
||||
spkl = [[]] * len(self.spikelist[k])
|
||||
for i in xrange(len(self.spikelist[k])):
|
||||
try:
|
||||
len(self.spikelist[k][i])
|
||||
spkl[i] = self.spikelist[k][i]
|
||||
except:
|
||||
spkl[i] = [np.array(self.spikelist[k][i])]
|
||||
return spkl # returns all the trials in the list
|
||||
|
||||
def get_AN_info(self):
|
||||
"""
|
||||
:return: Dictionary of nReps and SPLs that are in the current data set (instance).
|
||||
"""
|
||||
|
||||
if not self.data_read_flag:
|
||||
print("getANatSPL: No data read yet")
|
||||
return None
|
||||
return {"nReps": self.n_reps, "SPLs": self.SPLs}
|
||||
|
||||
def read_cmrr_data(self, P, signal="S0", display=False):
|
||||
"""
|
||||
read responses of auditory nerve model of Zilany et al. (2009).
|
||||
The required parameters are passed via the class P.
|
||||
This includes:
|
||||
|
||||
the SN (s2m) is the signal-to-masker ratio to select
|
||||
the mode is 'CMR', 'CMD' (deviant) or 'REF' (reference, no flanking bands),
|
||||
'Tone', or 'Noise'
|
||||
the Spont rate group
|
||||
|
||||
signal is either 'S0' (signal present) or 'NS' (no signal)
|
||||
display = True plots the psth's for all ANF channels in the dataset
|
||||
Returns:
|
||||
tuple of (spikelist and frequency list for modulated tones)
|
||||
"""
|
||||
|
||||
if P.modF == 10:
|
||||
datablock = "%s_F4000.0" % (P.mode)
|
||||
else:
|
||||
datablock = "%s_F4000.0_M%06.1f" % (
|
||||
P.mode,
|
||||
P.modF,
|
||||
) # 100 Hz modulation directory
|
||||
# print P.mode
|
||||
fs = os.listdir(self.datadir + datablock)
|
||||
s_sn = "SN%03d" % (P.s2m)
|
||||
fntemplate = "(\S+)_%s_%s_%s_%s.mat" % (s_sn, P.mode, signal, P.SR)
|
||||
p = re.compile(fntemplate)
|
||||
fl = [re.match(p, file) for file in fs]
|
||||
fl = [f.group(0) for f in fl if f != None] # returns list of files matching...
|
||||
fr = [float(f[1:7]) for f in fl] # frequency list
|
||||
if display:
|
||||
self.makeFig()
|
||||
i = 0
|
||||
j = 0
|
||||
spkl = [[]] * len(fl)
|
||||
n = int(np.sqrt(len(fl)))
|
||||
mg = np.meshgrid(range(n), range(n))
|
||||
mg = [mg[0].flatten(), mg[1].flatten()]
|
||||
for k, fi in enumerate(fl):
|
||||
mfile = scipy.io.loadmat(
|
||||
self.datadir + datablock + "/" + fi, squeeze_me=True
|
||||
)
|
||||
spkl[k] = mfile["Cpsth"]
|
||||
if display:
|
||||
self.display(spkl[k], k, n, mg)
|
||||
|
||||
if display:
|
||||
MP.show()
|
||||
return (spkl, fr)
|
||||
|
||||
def make_fig(self):
|
||||
self.fig = MP.figure(1)
|
||||
|
||||
def combine_reps(self, spkl):
|
||||
"""
|
||||
|
||||
Just turns the spike list into one big linear sequence.
|
||||
:param spkl: a spike train (nreps of numpy arrays)
|
||||
:return: all the spikes in one long array.
|
||||
"""
|
||||
|
||||
# print spkl
|
||||
allsp = np.array(self.flatten(spkl))
|
||||
# print allsp.shape
|
||||
if allsp.shape == (): # nothing to show
|
||||
return
|
||||
# print allsp
|
||||
spks = []
|
||||
for i in range(len(allsp)):
|
||||
if allsp[i] == ():
|
||||
continue
|
||||
if isinstance(allsp[i], float):
|
||||
spks.append(allsp[i])
|
||||
else:
|
||||
spks.extend(np.array(allsp[i]))
|
||||
return spks
|
||||
|
||||
def display(self, spkl, k, n, mg):
|
||||
# print spkl
|
||||
spks = self.combine_reps(spkl)
|
||||
if spks != []:
|
||||
spks = np.sort(spks)
|
||||
MP.hist(spks, 100)
|
||||
return
|
||||
|
||||
def flatten(self, x):
|
||||
"""
|
||||
flatten(sequence) -> list
|
||||
|
||||
Returns a single, flat list which contains all elements retrieved
|
||||
from the sequence and all recursively contained sub-sequences
|
||||
(iterables).
|
||||
"""
|
||||
result = []
|
||||
if isinstance(x, float) or len([x]) <= 1:
|
||||
return x
|
||||
for el in x:
|
||||
if hasattr(el, "__iter__") and not isinstance(el, basestring):
|
||||
result.extend(self.flatten(el))
|
||||
else:
|
||||
result.append(el)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from cnmodel.util import Params
|
||||
|
||||
modes = ["CMR", "CMD", "REF"]
|
||||
sr_types = ["H", "M", "L"]
|
||||
sr_type = sr_types[0]
|
||||
sr_names = {
|
||||
"H": "HS",
|
||||
"M": "MS",
|
||||
"L": "LS",
|
||||
} # spontaneous rate groups (AN fiber input selection)
|
||||
# define params for reading/testing. This is a subset of params...
|
||||
P = Params(
|
||||
mode=modes[0],
|
||||
s2m=0,
|
||||
n_rep=0,
|
||||
start_rep=0,
|
||||
sr=sr_names["L"],
|
||||
modulation_freq=10,
|
||||
DS_phase=0,
|
||||
dataset="test",
|
||||
sr_types=sr_types,
|
||||
sr_names=sr_names,
|
||||
fileVersion=1.5,
|
||||
)
|
||||
|
||||
manager = ManageANSpikes() # create instance of the manager
|
||||
print(dir(manager))
|
||||
test = "RI"
|
||||
|
||||
if test == "all":
|
||||
manager.read_all_ANdata(
|
||||
freq=10000.0, CFList=manager.CF_map, spontclass="HS", stim="BFTone"
|
||||
)
|
||||
# spkl = manager.retrieve_from_all_AN(manager.CF_map[0], 40.)
|
||||
# spks = manager.combine_reps(spkl)
|
||||
manager.plot_Rate_vs_F(freq=manager.CF_map[0], display=True, spontclass="HS")
|
||||
MP.show()
|
||||
|
||||
if test == manager.dataType:
|
||||
spikes = manager.getANatFandSPL(spontclass="MS", freq=10000.0, CF=None, SPL=50)
|
||||
spks = manager.combine_reps(spikes)
|
||||
if spks != []:
|
||||
print("spikes.")
|
||||
manager.plot_RI_vs_F(freq=10000.0, display=True, spontclass="MS")
|
||||
MP.figure(11)
|
||||
spks = np.sort(spks)
|
||||
MP.hist(spks, 100)
|
||||
MP.show()
|
||||
402
cnmodel/util/matlab_proc.py
Normal file
402
cnmodel/util/matlab_proc.py
Normal file
@@ -0,0 +1,402 @@
|
||||
from __future__ import print_function
|
||||
|
||||
"""
|
||||
Simple system for interfacing with a MATLAB process using stdin/stdout pipes.
|
||||
|
||||
"""
|
||||
from .process import Process
|
||||
from io import StringIO
|
||||
import scipy.io
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import os, sys, glob, signal
|
||||
import weakref
|
||||
import atexit
|
||||
|
||||
|
||||
class MatlabProcess(object):
|
||||
""" This class starts a new matlab process, allowing remote control of
|
||||
the interpreter and transfer of data between python and matlab.
|
||||
"""
|
||||
|
||||
# Implements a MATLAB CLI that is a bit easier for us to parse
|
||||
_bootstrap = r"""
|
||||
while true
|
||||
fprintf('\n::ready\n');
|
||||
|
||||
% Accumulate command lines until we see ::cmd_done
|
||||
cmd = '';
|
||||
empty_count = 0;
|
||||
while true
|
||||
line = input('', 's');
|
||||
if length(line) == 0
|
||||
% If we encounter many empty lines, assume the host process
|
||||
% has ended. Is there a better way to detect this?
|
||||
empty_count = empty_count + 1;
|
||||
if empty_count > 100
|
||||
fprintf('Shutting down!\n');
|
||||
exit;
|
||||
end
|
||||
end
|
||||
if strcmp(line, '::cmd_done')
|
||||
break
|
||||
end
|
||||
cmd = [cmd, line, sprintf('\n')];
|
||||
end
|
||||
|
||||
% Evaluate command
|
||||
try
|
||||
%fprintf('EVAL: %s\n', cmd);
|
||||
eval(cmd);
|
||||
fprintf('\n::ok\n');
|
||||
catch err
|
||||
fprintf('\n::err\n');
|
||||
fprintf(['::message:', err.message, '\n']);
|
||||
fprintf(['::identifier:', err.identifier, '\n']);
|
||||
for i = 1:length(err.stack)
|
||||
frame = err.stack(i,1);
|
||||
fprintf(['::stack:', frame.name, ' in ', frame.file, ' line ', frame.line]);
|
||||
end
|
||||
end
|
||||
end
|
||||
"""
|
||||
|
||||
def __init__(self, executable=None, **kwds):
|
||||
self.__closed = False
|
||||
self.__refs = weakref.WeakValueDictionary()
|
||||
|
||||
# Decide which executables to try
|
||||
if executable is not None:
|
||||
execs = [executable]
|
||||
else:
|
||||
execs = ["matlab"] # always pick the matlab in the path, if available
|
||||
if sys.platform == "darwin":
|
||||
installed = glob.glob("/Applications/MATLAB_R*")
|
||||
installed.sort(reverse=True)
|
||||
execs.extend([os.path.join(p, "bin", "matlab") for p in installed])
|
||||
|
||||
# try starting each in order until one works
|
||||
self.__proc = None
|
||||
for exe in execs:
|
||||
try:
|
||||
self.__proc = Process([exe, "-nodesktop", "-nosplash"], **kwds)
|
||||
break
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
pass
|
||||
|
||||
# bail out if we couldn't start any
|
||||
if self.__proc is None:
|
||||
raise RuntimeError(
|
||||
"Could not start MATLAB process.\nPaths attempted: %s "
|
||||
"\nLast error: %s" % (str(execs), str(e))
|
||||
)
|
||||
|
||||
# Wait a moment for MATLAB to start up,
|
||||
# read the version string
|
||||
while True:
|
||||
line = self.__proc.stdout.readline()
|
||||
if "Copyright" in line:
|
||||
# next line is version info
|
||||
self.__version_str = self.__proc.stdout.readline().strip()
|
||||
break
|
||||
|
||||
# start input loop
|
||||
self.__proc.stdin.write(self._bootstrap)
|
||||
|
||||
# wait for input loop to be ready
|
||||
while True:
|
||||
line = self.__proc.stdout.readline()
|
||||
if line == "::ready\n":
|
||||
break
|
||||
|
||||
atexit.register(self.close)
|
||||
|
||||
def __call__(self, cmd, parse_result=True):
|
||||
"""
|
||||
Execute the specified statement(s) on the MATLAB interpreter and return
|
||||
the output string or raise an exception if there was an error.
|
||||
"""
|
||||
assert not self.__closed, "MATLAB process has already closed."
|
||||
if cmd[-1] != "\n":
|
||||
cmd += "\n"
|
||||
cmd += "::cmd_done\n"
|
||||
self.__proc.stdout.read()
|
||||
self.__proc.stdin.write(cmd)
|
||||
|
||||
if parse_result:
|
||||
return self._parse_result()
|
||||
|
||||
def _parse_result(self):
|
||||
assert not self.__closed, "MATLAB process has already closed."
|
||||
output = []
|
||||
while True:
|
||||
line = self.__proc.stdout.readline()
|
||||
if line == "::ready\n":
|
||||
break
|
||||
output.append(line)
|
||||
|
||||
for i in reversed(range(len(output))):
|
||||
line = output[i]
|
||||
if line == "::ok\n":
|
||||
return "".join(output[:i])
|
||||
elif line == "::err\n":
|
||||
raise MatlabError(output[i + 1 :], output[:i])
|
||||
|
||||
raise RuntimeError("No success/failure code found in output (printed above).")
|
||||
|
||||
def _get(self, name):
|
||||
"""
|
||||
Transfer an object from MATLAB to Python.
|
||||
"""
|
||||
assert isinstance(name, str)
|
||||
tmp = tempfile.mktemp(suffix=".mat")
|
||||
out = self("save('%s', '%s', '-v7')" % (tmp, name))
|
||||
objs = scipy.io.loadmat(tmp)
|
||||
os.remove(tmp)
|
||||
return objs[name]
|
||||
|
||||
def _set(self, **kwds):
|
||||
"""
|
||||
Transfer an object from Python to MATLAB and assign it to the given
|
||||
variable name.
|
||||
"""
|
||||
tmp = tempfile.mktemp(suffix=".mat")
|
||||
scipy.io.savemat(tmp, kwds)
|
||||
self("load('%s')" % tmp)
|
||||
os.remove(tmp)
|
||||
|
||||
def _get_via_pipe(self, name):
|
||||
"""
|
||||
Transfer an object from MATLAB to Python.
|
||||
|
||||
This method sends data over the pipe, but is less reliable than get().
|
||||
"""
|
||||
assert isinstance(name, str)
|
||||
out = self("save('stdio', '%s', '-v7')" % name)
|
||||
start = stop = None
|
||||
for i, line in enumerate(out):
|
||||
if line.startswith("start_binary"):
|
||||
start = i
|
||||
elif line.startswith("::ok"):
|
||||
stop = i
|
||||
data = "".join(out[start + 1 : stop])
|
||||
io = StringIO(data[:-1])
|
||||
objs = scipy.io.loadmat(io)
|
||||
return objs[name]
|
||||
|
||||
def _set_via_pipe(self, **kwds):
|
||||
"""
|
||||
Transfer an object from Python to MATLAB and assign it to the given
|
||||
variable name.
|
||||
|
||||
This method sends data over the pipe, but is less reliable than set().
|
||||
"""
|
||||
assert not self.__closed, "MATLAB process has already closed."
|
||||
io = StringIO()
|
||||
scipy.io.savemat(io, kwds)
|
||||
io.seek(0)
|
||||
strn = io.read()
|
||||
self.__proc.stdout.read()
|
||||
self.__proc.stdin.write("load('stdio')\n::cmd_done\n")
|
||||
while True:
|
||||
line = self.__proc.stdout.readline()
|
||||
if line == "ack load stdio\n":
|
||||
# now it is safe to send data
|
||||
break
|
||||
self.__proc.stdin.write(strn)
|
||||
self.__proc.stdin.write("\n")
|
||||
while True:
|
||||
line = self.__proc.stdout.readline()
|
||||
if line == "ack load finished\n":
|
||||
break
|
||||
self._parse_result()
|
||||
|
||||
def exist(self, name):
|
||||
if name == "exist":
|
||||
return 5
|
||||
else:
|
||||
for line in self("exist %s" % name).split("\n"):
|
||||
try:
|
||||
return int(line.strip())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name not in self.__refs:
|
||||
ex = self.exist(name)
|
||||
if ex == 0:
|
||||
raise AttributeError("No object named '%s' in matlab workspace." % name)
|
||||
elif ex in (2, 3, 5):
|
||||
r = MatlabFunction(self, name)
|
||||
elif ex == 1:
|
||||
r = self._mkref(name)
|
||||
else:
|
||||
typ = {4: "library file", 6: "P-file", 7: "folder", 8: "class"}[ex]
|
||||
raise TypeError("Variable '%s' has unsupported type '%s'" % (name, typ))
|
||||
self.__refs[name] = r
|
||||
return self.__refs[name]
|
||||
|
||||
def _mkref(self, name):
|
||||
assert name not in self.__refs
|
||||
ref = MatlabReference(self, name)
|
||||
self.__refs[name] = ref
|
||||
return ref
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_MatlabProcess__"):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
self._set(**{name: value})
|
||||
|
||||
def close(self):
|
||||
if self.__closed:
|
||||
return
|
||||
self("exit;\n", parse_result=False)
|
||||
self.__closed = True
|
||||
|
||||
|
||||
class MatlabReference(object):
|
||||
""" Reference to a variable in the matlab workspace.
|
||||
"""
|
||||
|
||||
def __init__(self, proc, name):
|
||||
self._proc = proc
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def get(self):
|
||||
return self._proc._get(self._name)
|
||||
|
||||
def clear(self):
|
||||
self._proc("clear %s;" % self._name)
|
||||
|
||||
def __del__(self):
|
||||
self.clear()
|
||||
|
||||
|
||||
class MatlabFunction(object):
|
||||
"""
|
||||
Proxy to a MATLAB function.
|
||||
|
||||
Calling this object transfers the arguments to MATLAB, invokes the function
|
||||
remotely, and then transfers the return value back to Python.
|
||||
"""
|
||||
|
||||
def __init__(self, proc, name):
|
||||
self._proc = proc
|
||||
self._name = name
|
||||
self._nargout = None
|
||||
|
||||
@property
|
||||
def nargout(self):
|
||||
""" Number of output arguments for this function.
|
||||
|
||||
For some functions, requesting nargout() will fail. In these cases,
|
||||
the nargout property must be set manually before calling the function.
|
||||
"""
|
||||
if self._nargout is None:
|
||||
cmd = "fprintf('%%d\\n', nargout('%s'));" % (self._name)
|
||||
ret = self._proc(cmd)
|
||||
self._nargout = int(ret.strip())
|
||||
return self._nargout
|
||||
|
||||
@nargout.setter
|
||||
def nargout(self, n):
|
||||
self._nargout = n
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
"""
|
||||
Call this function with the given arguments.
|
||||
|
||||
If _transfer is False, then the return values are left in MATLAB and
|
||||
references to these values are returned instead.
|
||||
"""
|
||||
import pyqtgraph as pg
|
||||
|
||||
_transfer = kwds.pop("_transfer", True)
|
||||
assert len(kwds) == 0
|
||||
|
||||
# for generating unique variable names
|
||||
rand = np.random.randint(1e12)
|
||||
|
||||
# store args to temporary variables, excluding those already present
|
||||
# in the workspace
|
||||
argnames = []
|
||||
upload = {}
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(arg, MatlabReference):
|
||||
argnames.append(arg.name)
|
||||
elif np.isscalar(arg):
|
||||
argnames.append(repr(arg))
|
||||
else:
|
||||
argname = "%s_%d_%d" % (self._name, i, rand)
|
||||
argnames.append(argname)
|
||||
upload[argname] = arg
|
||||
if len(upload) > 0:
|
||||
self._proc._set(**upload)
|
||||
|
||||
try:
|
||||
# get number of output args
|
||||
nargs = self.nargout
|
||||
|
||||
# invoke function, fetch return value(s)
|
||||
retvars = ["%s_rval_%d_%d" % (self._name, i, rand) for i in range(nargs)]
|
||||
cmd = "[%s] = %s(%s);" % (",".join(retvars), self._name, ",".join(argnames))
|
||||
self._proc(cmd)
|
||||
if _transfer:
|
||||
ret = [self._proc._get(var) for var in retvars]
|
||||
self._proc("clear %s;" % (" ".join(retvars)))
|
||||
else:
|
||||
ret = [self._proc._mkref(name) for name in retvars]
|
||||
if len(ret) == 1:
|
||||
ret = ret[0]
|
||||
else:
|
||||
ret = tuple(ret)
|
||||
return ret
|
||||
finally:
|
||||
# clear all temp variables
|
||||
clear = list(upload.keys())
|
||||
# if _transfer:
|
||||
# clear += retvars
|
||||
if len(clear) > 0:
|
||||
cmd = "clear %s;" % (" ".join(clear))
|
||||
self._proc(cmd)
|
||||
return ret
|
||||
|
||||
|
||||
class MatlabError(Exception):
|
||||
def __init__(self, error, output):
|
||||
self.output = "".join(output)
|
||||
for line in error:
|
||||
self.stack = []
|
||||
if line.startswith("::message:"):
|
||||
self.message = line[10:].strip()
|
||||
elif line.startswith("::identifier:"):
|
||||
self.identifier = line[13:].strip()
|
||||
elif line.startswith("::stack:"):
|
||||
self.stack.append(line[8:].strip())
|
||||
|
||||
def __repr__(self):
|
||||
return "MatlabError(message=%s, identifier=%s)" % (
|
||||
repr(self.message),
|
||||
repr(self.identifier),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
if len(self.stack) > 0:
|
||||
stack = "\nMATLAB Stack:\n%s\nMATLAB Error: " % "\n".join(self.stack)
|
||||
return stack + self.message
|
||||
else:
|
||||
return self.message
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = MatlabProcess()
|
||||
io = StringIO()
|
||||
scipy.io.savemat(io, {"x": 1})
|
||||
io.seek(0)
|
||||
strn = io.read()
|
||||
218
cnmodel/util/nrnutils.py
Normal file
218
cnmodel/util/nrnutils.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from __future__ import print_function
|
||||
|
||||
"""
|
||||
Wrapper classes to make working with NEURON easier.
|
||||
|
||||
Author: Andrew P. Davison, UNIC, CNRS
|
||||
"""
|
||||
|
||||
__version__ = "0.3.0"
|
||||
|
||||
from neuron import nrn, h, hclass
|
||||
|
||||
h.load_file("stdrun.hoc")
|
||||
|
||||
PROXIMAL = 0
|
||||
DISTAL = 1
|
||||
|
||||
|
||||
class Mechanism(object):
|
||||
"""
|
||||
Examples:
|
||||
>>> leak = Mechanism('pas', {'e': -65, 'g': 0.0002})
|
||||
>>> hh = Mechanism('hh')
|
||||
added set_parameters to allow post-instantiation parameter modification
|
||||
"""
|
||||
|
||||
def __init__(self, name, **parameters):
|
||||
self.name = name
|
||||
self.parameters = parameters
|
||||
|
||||
def set_parameters(self, parameters):
|
||||
self.parameters = parameters
|
||||
|
||||
def insert_into(self, section):
|
||||
section.insert(self.name)
|
||||
for name, value in self.parameters.items():
|
||||
for segment in section:
|
||||
mech = getattr(segment, self.name)
|
||||
setattr(mech, name, value)
|
||||
|
||||
|
||||
class Section(nrn.Section):
|
||||
"""
|
||||
Examples:
|
||||
>>> soma = Section(L=30, diam=30, mechanisms=[hh, leak])
|
||||
>>> apical = Section(L=600, diam=2, nseg=5, mechanisms=[leak],
|
||||
... parent=soma, connection_point=DISTAL)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
L,
|
||||
diam,
|
||||
nseg=1,
|
||||
Ra=100,
|
||||
cm=1,
|
||||
mechanisms=[],
|
||||
parent=None,
|
||||
connection_point=DISTAL,
|
||||
):
|
||||
nrn.Section.__init__(self)
|
||||
# set geometry
|
||||
self.L = L
|
||||
self.diam = diam
|
||||
self.nseg = nseg
|
||||
# set cable properties
|
||||
self.Ra = Ra
|
||||
self.cm = cm
|
||||
# connect to parent section
|
||||
if parent:
|
||||
self.connect(parent, connection_point, PROXIMAL)
|
||||
# add ion channels
|
||||
for mechanism in mechanisms:
|
||||
mechanism.insert_into(self)
|
||||
|
||||
def add_synapses(self, label, type, locations=[0.5], **parameters):
|
||||
if hasattr(self, label):
|
||||
raise Exception("Can't overwrite synapse labels (to keep things simple)")
|
||||
synapse_group = []
|
||||
for location in locations:
|
||||
synapse = getattr(h, type)(location, sec=self)
|
||||
for name, value in parameters.items():
|
||||
setattr(synapse, name, value)
|
||||
synapse_group.append(synapse)
|
||||
if len(synapse_group) == 1:
|
||||
synapse_group = synapse_group[0]
|
||||
setattr(self, label, synapse_group)
|
||||
|
||||
add_synapse = add_synapses # for backwards compatibility
|
||||
|
||||
def plot(self, variable, location=0.5, tmin=0, tmax=5, xmin=-80, xmax=40):
|
||||
import neuron.gui
|
||||
|
||||
self.graph = h.Graph()
|
||||
h.graphList[0].append(self.graph)
|
||||
self.graph.size(tmin, tmax, xmin, xmax)
|
||||
self.graph.addvar("%s(%g)" % (variable, location), sec=self)
|
||||
|
||||
def record_spikes(self, threshold=-30):
|
||||
self.spiketimes = h.Vector()
|
||||
self.spikecount = h.APCount(0.5, sec=self)
|
||||
self.spikecount.thresh = threshold
|
||||
self.spikecount.record(self.spiketimes)
|
||||
|
||||
|
||||
def alias(attribute_path):
|
||||
"""
|
||||
Returns a new property, mapping an attribute nested in an object hierarchy
|
||||
to a simpler name
|
||||
|
||||
For example, suppose that an object of class A has an attribute b which
|
||||
itself has an attribute c which itself has an attribute d. Then placing
|
||||
e = alias('b.c.d')
|
||||
in the class definition of A makes A.e an alias for A.b.c.d
|
||||
"""
|
||||
|
||||
parts = attribute_path.split(".")
|
||||
attr_name = parts[-1]
|
||||
attr_path = parts[:-1]
|
||||
|
||||
def set(self, value):
|
||||
obj = reduce(getattr, [self] + attr_path)
|
||||
setattr(obj, attr_name, value)
|
||||
|
||||
def get(self):
|
||||
obj = reduce(getattr, [self] + attr_path)
|
||||
return getattr(obj, attr_name)
|
||||
|
||||
return property(fset=set, fget=get)
|
||||
|
||||
|
||||
def uniform_property(section_list, attribute_path):
|
||||
"""
|
||||
Define a property that will have a uniform value across a list of sections.
|
||||
|
||||
For example, suppose we define a neuron model as a class A, which contains
|
||||
three compartments: soma, dendrite and axon. Then placing
|
||||
|
||||
gnabar = uniform_property(["soma", "axon"], "hh.gnabar")
|
||||
|
||||
in the class definition of A means that setting a.gnabar (where a is an
|
||||
instance of A) will set the value of hh.gnabar in both the soma and axon, i.e.
|
||||
|
||||
a.gnabar = 0.01
|
||||
|
||||
is equivalent to:
|
||||
|
||||
for sec in [a.soma, a.axon]:
|
||||
for seg in sec:
|
||||
seg.hh.gnabar = 0.01
|
||||
|
||||
"""
|
||||
parts = attribute_path.split(".")
|
||||
attr_name = parts[-1]
|
||||
attr_path = parts[:-1]
|
||||
|
||||
def set(self, value):
|
||||
for sec_name in section_list:
|
||||
sec = getattr(self, sec_name)
|
||||
for seg in sec:
|
||||
obj = reduce(getattr, [seg] + attr_path)
|
||||
setattr(obj, attr_name, value)
|
||||
|
||||
def get(self):
|
||||
sec = getattr(self, section_list[0])
|
||||
obj = reduce(getattr, [sec(0.5)] + attr_path)
|
||||
return getattr(obj, attr_name)
|
||||
|
||||
return property(fset=set, fget=get)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
class SimpleNeuron(object):
|
||||
def __init__(self):
|
||||
# define ion channel parameters
|
||||
leak = Mechanism("pas", e=-65, g=0.0002)
|
||||
hh = Mechanism("hh")
|
||||
# create cable sections
|
||||
self.soma = Section(L=30, diam=30, mechanisms=[hh])
|
||||
self.apical = Section(
|
||||
L=600,
|
||||
diam=2,
|
||||
nseg=5,
|
||||
mechanisms=[leak],
|
||||
parent=self.soma,
|
||||
connection_point=DISTAL,
|
||||
)
|
||||
self.basilar = Section(
|
||||
L=600,
|
||||
diam=2,
|
||||
nseg=5,
|
||||
mechanisms=[leak],
|
||||
parent=self.soma,
|
||||
connection_point=0.5,
|
||||
)
|
||||
self.axon = Section(
|
||||
L=1000, diam=1, nseg=37, mechanisms=[hh], connection_point=0
|
||||
)
|
||||
# synaptic input
|
||||
self.soma.add_synapses("syn", "AlphaSynapse", onset=0.5, gmax=0.05, e=0)
|
||||
|
||||
gnabar = uniform_property(["soma", "axon"], "hh.gnabar")
|
||||
gkbar = uniform_property(["soma", "axon"], "hh.gkbar")
|
||||
|
||||
neuron = SimpleNeuron()
|
||||
neuron.soma.plot("v")
|
||||
neuron.apical.plot("v")
|
||||
|
||||
print("gNa_bar: ", neuron.gnabar)
|
||||
neuron.gnabar = 0.15
|
||||
assert neuron.soma(0.5).hh.gnabar == 0.15
|
||||
|
||||
h.dt = 0.025
|
||||
v_init = -65
|
||||
tstop = 5
|
||||
h.finitialize(v_init)
|
||||
h.run()
|
||||
111
cnmodel/util/process.py
Normal file
111
cnmodel/util/process.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import print_function
|
||||
|
||||
"""
|
||||
Utility class for spawning and controlling CLI processes.
|
||||
See: http://stackoverflow.com/questions/375427/non-blocking-read-on-a-subprocess-pipe-in-python
|
||||
"""
|
||||
import sys, time
|
||||
from subprocess import PIPE, Popen
|
||||
from threading import Thread
|
||||
|
||||
try:
|
||||
from Queue import Queue, Empty
|
||||
except ImportError:
|
||||
from queue import Queue, Empty # python 3.x
|
||||
|
||||
ON_POSIX = "posix" in sys.builtin_module_names
|
||||
|
||||
|
||||
class Process(object):
|
||||
"""
|
||||
Process encapsulates a subprocess with queued stderr/stdout pipes.
|
||||
Wraps most methods from subprocess.Popen.
|
||||
|
||||
For non-blocking reads, use proc.stdout.get_nowait() or
|
||||
proc.stdout.get(timeout=0.1).
|
||||
"""
|
||||
|
||||
def __init__(self, exec_args, cwd=None):
|
||||
self.proc = Popen(
|
||||
exec_args,
|
||||
stdout=PIPE,
|
||||
stdin=PIPE,
|
||||
stderr=PIPE,
|
||||
bufsize=1,
|
||||
close_fds=ON_POSIX,
|
||||
universal_newlines=True,
|
||||
cwd=cwd,
|
||||
)
|
||||
self.stdin = self.proc.stdin
|
||||
# self.stdin = PipePrinter(self.proc.stdin)
|
||||
self.stdout = PipeQueue(self.proc.stdout)
|
||||
self.stderr = PipeQueue(self.proc.stderr)
|
||||
for method in ["poll", "wait", "send_signal", "kill", "terminate"]:
|
||||
setattr(self, method, getattr(self.proc, method))
|
||||
|
||||
|
||||
class PipePrinter(object):
|
||||
""" For debugging writes to a pipe.
|
||||
"""
|
||||
|
||||
def __init__(self, pipe):
|
||||
self._pipe = pipe
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self._pipe, attr)
|
||||
|
||||
def write(self, strn):
|
||||
print("WRITE:" + repr(strn))
|
||||
return self._pipe.write(strn)
|
||||
|
||||
|
||||
class PipeQueue(Queue):
|
||||
"""
|
||||
Queue that starts a second process to monitor a PIPE for new data.
|
||||
This is needed to allow non-blocking pipe reads.
|
||||
"""
|
||||
|
||||
def __init__(self, pipe):
|
||||
Queue.__init__(self)
|
||||
self.thread = Thread(target=self.enqueue_output, args=(pipe, self))
|
||||
self.thread.daemon = True # thread dies with the program
|
||||
self.thread.start()
|
||||
|
||||
@staticmethod
|
||||
def enqueue_output(out, queue):
|
||||
for line in iter(out.readline, b""):
|
||||
queue.put(line)
|
||||
# print "READ: " + repr(line)
|
||||
out.close()
|
||||
|
||||
def read(self):
|
||||
"""
|
||||
Read all available lines from the queue, concatenated into a single string.
|
||||
"""
|
||||
out = ""
|
||||
while True:
|
||||
try:
|
||||
out += self.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
return out
|
||||
|
||||
def readline(self, timeout=None):
|
||||
""" Read a single line from the queue.
|
||||
"""
|
||||
# we break this up into multiple short reads to allow keyboard
|
||||
# interrupts
|
||||
start = time.time()
|
||||
ret = ""
|
||||
while True:
|
||||
if timeout is not None:
|
||||
remaining = start + timeout - time.time()
|
||||
if remaining <= 0:
|
||||
return ""
|
||||
else:
|
||||
remaining = 1
|
||||
|
||||
try:
|
||||
return self.get(timeout=min(0.1, remaining))
|
||||
except Empty:
|
||||
pass
|
||||
734
cnmodel/util/pynrnutilities.py
Executable file
734
cnmodel/util/pynrnutilities.py
Executable file
@@ -0,0 +1,734 @@
|
||||
from __future__ import print_function
|
||||
|
||||
#!/usr/bin/python
|
||||
#
|
||||
# utilities for NEURON, in Python
|
||||
# Module neuron for cnmodel
|
||||
|
||||
import numpy as np
|
||||
import numpy.ma as ma # masked array
|
||||
import re, sys, gc, collections
|
||||
|
||||
import neuron
|
||||
|
||||
|
||||
_mechtype_cache = None
|
||||
|
||||
|
||||
def all_mechanism_types():
|
||||
"""Return a dictionary of all available mechanism types.
|
||||
|
||||
Each dictionary key is the name of a mechanism and each value is
|
||||
another dictionary containing information about the mechanism::
|
||||
|
||||
mechanism_types = {
|
||||
'mech_name1': {
|
||||
'point_process': bool,
|
||||
'artificial_cell': bool,
|
||||
'netcon_target': bool,
|
||||
'has_netevent': bool,
|
||||
'internal_type': int,
|
||||
'globals': {name:size, ...},
|
||||
'parameters': {name:size, ...},
|
||||
'assigned': {name:size, ...},
|
||||
'state': {name:size, ...},
|
||||
},
|
||||
'mech_name2': {...},
|
||||
'mech_name3': {...},
|
||||
...
|
||||
}
|
||||
|
||||
* point_process: False for distributed mechanisms, True for point
|
||||
processes and artificial cells.
|
||||
* artificial_cell: True for artificial cells, False otherwise
|
||||
* netcon_target: True if the mechanism can receive NetCon events
|
||||
* has_netevent: True if the mechanism can emit NetCon events
|
||||
* internal_type: Integer specifying the NEURON internal type index of
|
||||
the mechanism
|
||||
* globals: dict of the name and vector size of the mechanism's global
|
||||
variables
|
||||
* parameters: dict of the name and vector size of the mechanism's
|
||||
parameter variables
|
||||
* assigned: dict of the name and vector size of the mechanism's
|
||||
assigned variables
|
||||
* state: dict of the name and vector size of the mechanism's state
|
||||
variables
|
||||
|
||||
|
||||
Note: The returned data structure is cached; do not modify it.
|
||||
|
||||
For more information on global, parameter, assigned, and state
|
||||
variables see:
|
||||
http://www.neuron.yale.edu/neuron/static/docs/help/neuron/nmodl/nmodl.html
|
||||
"""
|
||||
global _mechtype_cache
|
||||
if _mechtype_cache is None:
|
||||
_mechtype_cache = collections.OrderedDict()
|
||||
mname = neuron.h.ref("")
|
||||
# Iterate over two mechanism types (distributed, point/artificial)
|
||||
for i in [0, 1]:
|
||||
mt = neuron.h.MechanismType(i)
|
||||
nmech = int(mt.count())
|
||||
# Iterate over all mechanisms of this type
|
||||
for j in range(nmech):
|
||||
mt.select(j)
|
||||
mt.selected(mname)
|
||||
|
||||
# General mechanism properties
|
||||
name = mname[0] # convert hoc string ptr to python str
|
||||
|
||||
desc = {
|
||||
"point_process": bool(i),
|
||||
"netcon_target": bool(mt.is_netcon_target(j)),
|
||||
"has_netevent": bool(mt.has_net_event(j)),
|
||||
"artificial_cell": bool(mt.is_artificial(j)),
|
||||
"internal_type": int(mt.internal_type()),
|
||||
}
|
||||
|
||||
# Collect information about 4 different types of variables
|
||||
for k, ptype in [
|
||||
(-1, "globals"),
|
||||
(1, "parameters"),
|
||||
(2, "assigned"),
|
||||
(3, "state"),
|
||||
]:
|
||||
desc[ptype] = {} # collections.OrderedDict()
|
||||
ms = neuron.h.MechanismStandard(name, k)
|
||||
for l in range(int(ms.count())):
|
||||
psize = ms.name(mname, l)
|
||||
pname = mname[0] # parameter name
|
||||
desc[ptype][pname] = int(psize)
|
||||
|
||||
# Assemble everything in one place
|
||||
_mechtype_cache[name] = desc
|
||||
|
||||
return _mechtype_cache
|
||||
|
||||
|
||||
def reset(raiseError=True):
|
||||
"""Introspect the NEURON kernel to verify that no objects are left over
|
||||
from previous simulation runs.
|
||||
"""
|
||||
# Release objects held by an internal buffer
|
||||
# See https://www.neuron.yale.edu/phpBB/viewtopic.php?f=2&t=3221
|
||||
neuron.h.Vector().size()
|
||||
|
||||
# Make sure nothing is hanging around in an old exception or because of
|
||||
# reference cycles
|
||||
|
||||
# sys.exc_clear()
|
||||
gc.collect(2)
|
||||
neuron.h.Vector().size()
|
||||
numsec = 0
|
||||
|
||||
remaining = []
|
||||
n = len(list(neuron.h.allsec()))
|
||||
|
||||
if n > 0:
|
||||
remaining.append((n, "Section"))
|
||||
|
||||
n = len(neuron.h.List("NetCon"))
|
||||
if n > 0:
|
||||
remaining.append((n, "NetCon"))
|
||||
|
||||
# No point processes or artificial cells left
|
||||
for name, typ in all_mechanism_types().items():
|
||||
if typ["artificial_cell"] or typ["point_process"]:
|
||||
n = len(neuron.h.List(name))
|
||||
if n > 0:
|
||||
remaining.append((n, name))
|
||||
|
||||
if (
|
||||
len(remaining) > 0 and raiseError
|
||||
): # note that not raising the error leads to memory leak
|
||||
msg = "Cannot reset--old objects have not been cleared: %s" % ", ".join(
|
||||
["%d %s" % rem for rem in remaining]
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def custom_init(v_init=-60.0):
|
||||
"""
|
||||
Perform a custom initialization of the current model/section.
|
||||
|
||||
This initialization follows the scheme outlined in the
|
||||
NEURON book, 8.4.2, p 197 for initializing to steady state.
|
||||
|
||||
N.B.: For a complex model with dendrites/axons and different channels,
|
||||
this initialization will not find the steady-state the whole cell,
|
||||
leading to initial transient currents. In that case, this initialization
|
||||
should be followed with a 0.1-5 second run (depends on the rates in the
|
||||
various channel mechanisms) with no current injection or active point
|
||||
processes to allow the system to settle to a steady- state. Use either
|
||||
h.svstate or h.batch_save to save states and restore them. Batch is
|
||||
preferred
|
||||
|
||||
Parameters
|
||||
----------
|
||||
v_init : float (default: -60 mV)
|
||||
Voltage to start the initialization process. This should
|
||||
be close to the expected resting state.
|
||||
"""
|
||||
inittime = -1e10
|
||||
tdt = neuron.h.dt # save current step size
|
||||
dtstep = 1e9
|
||||
neuron.h.finitialize(v_init)
|
||||
neuron.h.t = inittime # set time to large negative value (avoid activating
|
||||
# point processes, we hope)
|
||||
tmp = neuron.h.cvode.active() # check state of variable step integrator
|
||||
if tmp != 0: # turn off CVode variable step integrator if it was active
|
||||
neuron.h.cvode.active(0) # now just use backward Euler with large step
|
||||
neuron.h.dt = dtstep
|
||||
n = 0
|
||||
while neuron.h.t < -1e9: # Step forward
|
||||
neuron.h.fadvance()
|
||||
n += 1
|
||||
# print('advances: ', n)
|
||||
if tmp != 0:
|
||||
neuron.h.cvode.active(1) # restore integrator
|
||||
neuron.h.t = 0
|
||||
if neuron.h.cvode.active():
|
||||
neuron.h.cvode.re_init() # update d(state)/dt and currents
|
||||
else:
|
||||
neuron.h.fcurrent() # recalculate currents
|
||||
neuron.h.frecord_init() # save new state variables
|
||||
neuron.h.dt = tdt # restore original time step
|
||||
|
||||
|
||||
# routine to convert conductances from nS as given elsewhere
|
||||
# to mho/cm2 as required by NEURON 1/28/99 P. Manis
|
||||
# units: nano siemens, soma area in um^2
|
||||
#
|
||||
def nstomho(ns, somaarea, refarea=None):
|
||||
if refarea == None:
|
||||
return 1e-9 * float(ns) / float(somaarea)
|
||||
else:
|
||||
return 1e9 * float(ns) / float(refarea)
|
||||
|
||||
|
||||
def mho2ns(mho, somaarea):
|
||||
return float(mho) * somaarea / 1e-9
|
||||
|
||||
|
||||
def spherearea(dia):
|
||||
"""
|
||||
given diameter in microns, return sphere area in cm2
|
||||
"""
|
||||
r = dia * 1e-4 # convert to cm
|
||||
return 4 * np.pi * r ** 2
|
||||
|
||||
|
||||
def get_sections(h):
|
||||
"""
|
||||
go through all the sections and find the names of the sections and all of their
|
||||
parts (ids). Returns a dict, of sec: [id0, id1...]
|
||||
|
||||
"""
|
||||
secnames = {}
|
||||
resec = re.compile("(\w+)\[(\d*)\]")
|
||||
for sec in h.allsec():
|
||||
g = resec.match(sec.name())
|
||||
if g.group(1) not in secnames.keys():
|
||||
secnames[g.group(1)] = [int(g.group(2))]
|
||||
else:
|
||||
secnames[g.group(1)].append(int(g.group(2)))
|
||||
return secnames
|
||||
|
||||
|
||||
def all_objects():
|
||||
""" Return a dict of all objects known to NEURON.
|
||||
|
||||
Keys are 'Section', 'Segment', 'Mechanism', 'Vector', 'PointProcess',
|
||||
'NetCon', ...
|
||||
"""
|
||||
objs = {}
|
||||
objs["Section"] = list(h.all_sec())
|
||||
objs["Segment"] = []
|
||||
for sec in objs["Section"]:
|
||||
objs["Segment"].extend(list(sec.allseg()))
|
||||
objs["PointProcess"] = []
|
||||
for seg in objs["Segment"]:
|
||||
objs["PointProcess"].extend(list(seg.point_processes()))
|
||||
|
||||
return objs
|
||||
|
||||
|
||||
def alpha(alpha=0.1, delay=1, amp=1.0, tdur=50.0, dt=0.010):
|
||||
tvec = np.arange(0, tdur, dt)
|
||||
aw = np.zeros(tvec.shape)
|
||||
i = 0
|
||||
for t in tvec:
|
||||
if t > delay:
|
||||
aw[i] = (
|
||||
amp * (t - delay) * (1.0 / alpha) * np.exp(-(t - delay) / alpha)
|
||||
) # alpha waveform time course
|
||||
else:
|
||||
aw[i] = 0.0
|
||||
i += 1
|
||||
return (aw, tvec)
|
||||
|
||||
|
||||
def syns(
|
||||
alpha=0.1,
|
||||
rate=10,
|
||||
delay=0,
|
||||
dur=50,
|
||||
amp=1.0,
|
||||
dt=0.020,
|
||||
N=1,
|
||||
mindur=120,
|
||||
makewave=True,
|
||||
):
|
||||
""" Calculate a poisson train of alpha waves
|
||||
with mean rate rate, with a delay and duration (in mseco) dt in msec.
|
||||
N specifies the number of such independent waveforms to sum """
|
||||
deadtime = 0.7
|
||||
if dur + delay < mindur:
|
||||
tvec = np.arange(0.0, mindur, dt)
|
||||
else:
|
||||
tvec = np.arange(0.0, dur + delay, dt)
|
||||
npts = len(tvec)
|
||||
ta = np.arange(0.0, 20.0, dt)
|
||||
aw = ta * alpha * np.exp(-ta / alpha) / alpha # alpha waveform time course
|
||||
spt = [[]] * N # list of spike times
|
||||
wave = np.array([]) # waveform
|
||||
sptime = []
|
||||
for j in range(0, N):
|
||||
done = False
|
||||
t = 0.0
|
||||
nsp = 0
|
||||
while not done:
|
||||
a = np.random.sample(1)
|
||||
if t < delay:
|
||||
t = delay
|
||||
continue
|
||||
if t >= delay and t <= (delay + dur):
|
||||
ti = -np.log(a) / (
|
||||
rate / 1000.0
|
||||
) # convert to exponential distribution with rate
|
||||
if ti < deadtime:
|
||||
continue
|
||||
t = t + ti # running time
|
||||
if t > delay + dur:
|
||||
done = True
|
||||
continue
|
||||
if nsp is 0:
|
||||
sptime = t
|
||||
nsp = nsp + 1
|
||||
else:
|
||||
sptime = np.append(sptime, t)
|
||||
nsp = nsp + 1
|
||||
if j is 0:
|
||||
wavej = np.zeros(len(tvec))
|
||||
for i in range(0, len(sptime)):
|
||||
st = int(sptime[i] / dt)
|
||||
wavej[st] = wavej[st] + 1
|
||||
spt[j] = sptime
|
||||
|
||||
if makewave:
|
||||
w = np.convolve(wavej, aw / max(aw)) * amp
|
||||
if len(w) < npts:
|
||||
w = np.append(w, np.zeros(npts - len(w)))
|
||||
if len(w) > npts:
|
||||
w = w[0:npts]
|
||||
if j is 0:
|
||||
wave = w
|
||||
else:
|
||||
wave = wave + w
|
||||
return (spt, wave, tvec, N)
|
||||
|
||||
|
||||
def an_syn(
|
||||
alpha=0.1,
|
||||
spont=10,
|
||||
driven=100,
|
||||
delay=50,
|
||||
dur=100,
|
||||
post=20,
|
||||
amp=0.1,
|
||||
dt=0.020,
|
||||
N=1,
|
||||
makewave=True,
|
||||
):
|
||||
# constants for AN:
|
||||
deadtime = 0.7 # min time between spikes, msec
|
||||
trise = 0.2 # rise rate, ms
|
||||
tfall = 0.5 # fall rate, ms
|
||||
rss = driven / 1000.0 # spikes/millisecond
|
||||
rr = 3 * rss # transient driven rate
|
||||
rst = rss
|
||||
taur = 3 # rapid decay, msec
|
||||
taust = 10
|
||||
ton = delay # msec
|
||||
stim_end = ton + dur
|
||||
trace_end = stim_end + post
|
||||
tvec = np.arange(0.0, trace_end, dt) # dt is in msec, so tvec is in milliseconds
|
||||
ta = np.arange(0.0, 20.0, dt)
|
||||
aw = ta * alpha * np.exp(-ta / alpha) / alpha # alpha waveform time course
|
||||
spt = [[]] * N # list of spike times
|
||||
wave = [[]] * N # waveform
|
||||
for j in range(0, N): # for each
|
||||
done = False
|
||||
sptime = []
|
||||
qe = 0
|
||||
nsp = 0
|
||||
i = int(0)
|
||||
if spont <= 0:
|
||||
q = 1e6
|
||||
else:
|
||||
q = 1000.0 / spont # q is in msec (spont in spikes/second)
|
||||
t = 0.0
|
||||
while not done:
|
||||
a = np.random.sample(1)
|
||||
if t < ton:
|
||||
if spont <= 0:
|
||||
t = ton
|
||||
continue
|
||||
ti = -(np.log(a) / (spont / 1000.0)) # convert to msec
|
||||
if ti < deadtime: # delete intervals less than deadtime
|
||||
continue
|
||||
t = t + ti
|
||||
if t > ton: # if the interval would step us to the stimulus onset
|
||||
t = ton # then set to to the stimulus onset
|
||||
continue
|
||||
if t >= ton and t < stim_end:
|
||||
if t > ton:
|
||||
rise = 1.0 - np.exp(-(t - ton) / trise)
|
||||
else:
|
||||
rise = 1.0
|
||||
ra = rr * np.exp(-(t - ton) / taur)
|
||||
rs = rst * np.exp(-(t - ton) / taust)
|
||||
q = rise * (ra + rs + rss)
|
||||
ti = -np.log(a) / (q + spont / 1000) # random.negexp(1000/q)
|
||||
if ti < deadtime:
|
||||
continue
|
||||
t = t + ti
|
||||
if t > stim_end: # only include interval if it falls inside window
|
||||
t = stim_end
|
||||
continue
|
||||
if t >= stim_end and t <= trace_end:
|
||||
if spont <= 0.0:
|
||||
t = trace_end
|
||||
continue
|
||||
if qe is 0: # have not calculated the new qe at end of stimulus
|
||||
rise = 1.0 - np.exp(-(stim_end - ton) / trise)
|
||||
ra = rr * np.exp(-(stim_end - ton) / taur)
|
||||
rs = rst * np.exp(-(stim_end - ton) / taust)
|
||||
qe = rise * (ra + rs + rss) # calculate the rate at the end
|
||||
fall = np.exp(-(t - stim_end) / tfall)
|
||||
q = qe * fall
|
||||
ti = -np.log(a) / (
|
||||
q + spont / 1000.0
|
||||
) # keeps rate from falling below spont rate
|
||||
if ti < deadtime:
|
||||
continue
|
||||
t = t + ti
|
||||
if t >= trace_end:
|
||||
done = True
|
||||
continue
|
||||
# now add the spike time to the list
|
||||
if nsp is 0:
|
||||
sptime = t
|
||||
nsp = nsp + 1
|
||||
else:
|
||||
sptime = np.append(sptime, t)
|
||||
nsp = nsp + 1
|
||||
# end of for loop on i
|
||||
if j is 0:
|
||||
wavej = np.zeros(len(tvec))
|
||||
for i in range(0, len(sptime)):
|
||||
st = int(sptime[i] / dt)
|
||||
wavej[st] = wavej[st] + 1
|
||||
spt[j] = sptime
|
||||
npts = len(tvec)
|
||||
if makewave:
|
||||
w = np.convolve(wavej, aw / max(aw)) * amp
|
||||
wave[j] = w[0:npts]
|
||||
return (spt, wave, tvec, N)
|
||||
|
||||
|
||||
def findspikes(t, v, thresh):
|
||||
""" findspikes identifies the times of action potential in the trace v, with the
|
||||
times in t. An action potential is simply timed at the first point that exceeds
|
||||
the threshold.
|
||||
"""
|
||||
tm = np.array(t)
|
||||
s0 = (
|
||||
np.array(v) > thresh
|
||||
) # np.where(v > thresh) # np.array(v) > thresh # find points above threshold
|
||||
|
||||
# print ('v: ', v)
|
||||
dsp = tm[s0]
|
||||
if dsp.shape[0] == 1:
|
||||
dsp = np.array(dsp)
|
||||
sd = np.append(True, np.diff(dsp) > 1.0) # find first points of spikes
|
||||
if len(dsp) > 0:
|
||||
sp = dsp[sd]
|
||||
else:
|
||||
sp = []
|
||||
return sp # list of spike times.
|
||||
|
||||
|
||||
def measure(mode, x, y, x0, x1):
|
||||
""" return the mean and standard deviation of y in the window x0 to x1
|
||||
"""
|
||||
xm = ma.masked_outside(x, x0, x1)
|
||||
ym = ma.array(y, mask=ma.getmask(xm))
|
||||
if mode == "mean":
|
||||
r1 = ma.mean(ym)
|
||||
r2 = ma.std(ym)
|
||||
if mode == "max":
|
||||
r1 = ma.max(ym)
|
||||
r2 = 0
|
||||
if mode == "min":
|
||||
r1 = ma.min(ym)
|
||||
r2 = 0
|
||||
if mode == "median":
|
||||
r1 = ma.median(ym)
|
||||
r2 = 0
|
||||
if mode == "p2p": # peak to peak
|
||||
r1 = ma.ptp(ym)
|
||||
r2 = 0
|
||||
return (r1, r2)
|
||||
|
||||
|
||||
def mask(x, xm, x0, x1):
|
||||
xmask = ma.masked_outside(xm, x0, x1)
|
||||
xnew = ma.array(x, mask=ma.getmask(xmask))
|
||||
return xnew.compressed()
|
||||
|
||||
|
||||
def vector_strength(spikes, freq):
|
||||
"""
|
||||
Calculate vector strength and related parameters from a spike train, for the specified frequency
|
||||
:param spikes: Spike train, in msec.
|
||||
:param freq: Stimulus frequency in Hz
|
||||
:return: a dictionary containing:
|
||||
|
||||
r: vector strength
|
||||
n: number of spikes
|
||||
R: Rayleigh coefficient
|
||||
p: p value (is distribution not flat?)
|
||||
ph: the circularized spike train over period of the stimulus freq, freq, in radians
|
||||
d: the "dispersion" computed according to Ashida et al., 2010, etc.
|
||||
"""
|
||||
|
||||
per = 1e3 / freq # convert from Hz to period in msec
|
||||
ph = 2 * np.pi * np.fmod(spikes, per) / (per) # convert to radians within a cycle
|
||||
c = np.sum(np.cos(ph)) ** 2
|
||||
s = np.sum(np.sin(ph)) ** 2
|
||||
vs = (1.0 / len(ph)) * np.sqrt(c + s) # standard vector strength computation
|
||||
n = len(spikes)
|
||||
R = n * vs # Raleigh coefficient
|
||||
Rp = np.exp(-n * vs * vs) # p value for n > 50 (see Ashida et al. 2010).
|
||||
d = np.sqrt(2.0 * (1 - vs)) / (2 * np.pi * freq)
|
||||
return {"r": vs, "n": n, "R": R, "p": Rp, "ph": ph, "d": d}
|
||||
|
||||
|
||||
def isi_cv2(splist, binwidth=1, t0=0, t1=300, tgrace=25):
|
||||
""" compute the cv and regularity according to Young et al., J. Neurophys, 60: 1, 1988.
|
||||
Analysis is limited to isi's starting at or after t0 but before t1, and ending completely
|
||||
before t1 + tgrace(to avoid end effects). t1 should correspond to the
|
||||
the end of the stimulus
|
||||
VERSION using dictionary for cvisi
|
||||
"""
|
||||
cvisit = np.arange(0, t1, binwidth) # build time bins
|
||||
cvisi = {} # isi is dictionary, since each bin may have different length
|
||||
for i in range(0, len(splist)): # for all the traces
|
||||
isit = splist[i] # get the spike times for this trial [1:-1]
|
||||
if len(isit) <= 1: # need 2 spikes to get an interval
|
||||
continue
|
||||
isib = np.floor(isit[0:-2] / binwidth) # discreetize
|
||||
isii = np.diff(splist[i]) # isis.
|
||||
for j in range(0, len(isib)): # loop over possible start time bins
|
||||
if (
|
||||
isit[j] < t0 or isit[j] > t1 or isit[j + 1] > t1 + tgrace
|
||||
): # start time and interval in the window
|
||||
continue
|
||||
if isib[j] in cvisi:
|
||||
print("spike in bin: %d" % (isib[j]))
|
||||
cvisi[isib[j]] = np.append(
|
||||
cvisi[isib[j]], isii[j]
|
||||
) # and add the isi in that bin
|
||||
else:
|
||||
cvisi[isib[j]] = isii[j] # create it
|
||||
cvm = np.array([]) # set up numpy arrays for mean, std and time for cv analysis
|
||||
cvs = np.array([])
|
||||
cvt = np.array([])
|
||||
for i in cvisi.keys(): # for each entry (possible bin)
|
||||
c = [cvisi[i]]
|
||||
s = c.shape
|
||||
# print c
|
||||
if len(s) > 1 and s[1] >= 3: # require 3 spikes in a bin for statistics
|
||||
cvm = np.append(cvm, np.mean(c))
|
||||
cvs = np.append(cvs, np.std(c))
|
||||
cvt = np.append(cvt, i * binwidth)
|
||||
return (cvisit, cvisi, cvt, cvm, cvs)
|
||||
|
||||
|
||||
def isi_cv(splist, binwidth=1, t0=0, t1=300, tgrace=25):
|
||||
""" compute the cv and regularity according to Young et al., J. Neurophys, 60: 1, 1988.
|
||||
Analysis is limited to isi's starting at or after t0 but before t1, and ending completely
|
||||
before t1 + tgrace(to avoid end effects). t1 should correspond to the
|
||||
the end of the stimulus
|
||||
Version using a list of numpy arrays for cvisi
|
||||
"""
|
||||
cvisit = np.arange(0, t1, binwidth) # build time bins
|
||||
cvisi = [[]] * len(cvisit)
|
||||
for i in range(0, len(splist)): # for all the traces
|
||||
if len(splist[i]) < 2: # need at least 2 spikes
|
||||
continue
|
||||
isib = np.floor(
|
||||
splist[i][0:-2] / binwidth
|
||||
) # begining spike times for each interval
|
||||
isii = np.diff(splist[i]) # associated intervals
|
||||
for j in range(0, len(isib)): # loop over spikes
|
||||
if (
|
||||
splist[i][j] < t0 or splist[i][j] > t1 or splist[i][j + 1] > t1 + tgrace
|
||||
): # start time and interval in the window
|
||||
continue
|
||||
cvisi[int(isib[j])] = np.append(
|
||||
cvisi[int(isib[j])], isii[j]
|
||||
) # and add the isi in that bin
|
||||
cvm = np.array([]) # set up numpy arrays for mean, std and time for cv analysis
|
||||
cvs = np.array([])
|
||||
cvt = np.array([])
|
||||
for i in range(0, len(cvisi)): # for each entry (possible bin)
|
||||
c = cvisi[i]
|
||||
if len(c) >= 3: # require 3 spikes in a bin for statistics
|
||||
cvm = np.append(cvm, np.mean(c))
|
||||
cvs = np.append(cvs, np.std(c))
|
||||
cvt = np.append(cvt, i * binwidth)
|
||||
return (cvisit, cvisi, cvt, cvm, cvs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
test = "isicv"
|
||||
|
||||
if test == "isicv":
|
||||
""" this test is not perfect. Given an ISI, we calculate spike times
|
||||
by drawing from a normal distribution whose standard deviation varies
|
||||
with time, from 0 (regular) to 1 (irregular). As coded, the standard
|
||||
deviation never reaches the target value because spikes fall before or
|
||||
after previous spikes (thus reducing the stdev). Nonetheless, this shows
|
||||
that the CV calculation works correctly. """
|
||||
nreps = 500
|
||||
# cv will be 0 for first 50 msec, 0.5 for next 50 msec, and 1 for next 50 msec
|
||||
d = [[]] * nreps
|
||||
isi = 5.0 # mean isi
|
||||
# we create 100 msec of data where the CV goes from 0 to 1
|
||||
maxt = 100.0
|
||||
for i in range(nreps):
|
||||
for j in range(int(maxt / isi) + 1):
|
||||
t = float(j) * isi
|
||||
sd = float(j) / isi
|
||||
if sd == 0.0:
|
||||
d[i] = np.append(d[i], t)
|
||||
else:
|
||||
d[i] = np.append(d[i], np.random.normal(t, sd, 1))
|
||||
for j in range(1, 10): # add more intervals at the end
|
||||
te = t + float(j) * isi
|
||||
d[i] = np.append(d[i], np.random.normal(te, sd, 1))
|
||||
d[i] = np.sort(d[i])
|
||||
# print d[i]
|
||||
# print diff(d[i])
|
||||
sh = np.array([])
|
||||
for i in range(len(d)):
|
||||
sh = np.append(sh, np.array(d[i]))
|
||||
(hist, bins) = np.histogram(sh, bins=250, range=(0, 250), new=True)
|
||||
if len(bins) > len(hist):
|
||||
bins = bins[0 : len(hist)]
|
||||
|
||||
pl.figure(1)
|
||||
pl.subplot(2, 2, 1)
|
||||
pl.plot(bins, hist)
|
||||
|
||||
(cvisit, cvisi, cvt, cvm, cvs) = isi_cv(
|
||||
d, binwidth=0.5, t0=0, t1=100, tgrace=25
|
||||
)
|
||||
order = np.argsort(cvt)
|
||||
cvt = cvt[order]
|
||||
cvs = cvs[order]
|
||||
cvm = cvm[order]
|
||||
pl.subplot(2, 2, 2)
|
||||
pl.plot(cvt, cvm)
|
||||
pl.hold(True)
|
||||
pl.plot(cvt, cvs)
|
||||
pl.subplot(2, 2, 4)
|
||||
pl.plot(cvt, cvs / cvm)
|
||||
pl.show()
|
||||
|
||||
if test == "measure":
|
||||
x = np.arange(0, 100, 0.1)
|
||||
s = np.shape(x)
|
||||
y = np.random.randn(s[0])
|
||||
for i in range(0, 4):
|
||||
print("\ni is : %d" % (i))
|
||||
x0 = i * 20
|
||||
x1 = x0 + 20
|
||||
(r0, r1) = measure("mean", x, y, x0, x1)
|
||||
print("mean: %f std: %f [0, 20]" % (r0, r1))
|
||||
(r0, r1) = measure("max", x, y, x0, x1)
|
||||
print("max: %f std: %f [0, 20]" % (r0, r1))
|
||||
(r0, r1) = measure("min", x, y, x0, x1)
|
||||
print("min: %f std: %f [0, 20]" % (r0, r1))
|
||||
(r0, r1) = measure("median", x, y, x0, x1)
|
||||
print("median: %f std: %f [0, 20]" % (r0, r1))
|
||||
(r0, r1) = measure("p2p", x, y, x0, x1)
|
||||
print("peak to peak: %f std: %f [0, 20]" % (r0, r1))
|
||||
|
||||
if test == "an_syn":
|
||||
(s, w, t, n) = an_syn(N=50, spont=50, driven=150, post=100, makewave=True)
|
||||
sh = np.array([])
|
||||
for i in range(len(s)):
|
||||
sh = np.append(sh, np.array(s[i]))
|
||||
(hist, bins) = np.histogram(sh, bins=250, range=(0, 250), new=True)
|
||||
if len(bins) > len(hist):
|
||||
bins = bins[0 : len(hist)]
|
||||
|
||||
import pylab as pl
|
||||
|
||||
pl.figure(1)
|
||||
pl.subplot(2, 2, 1)
|
||||
pl.plot(bins, hist)
|
||||
|
||||
pl.subplot(2, 2, 3)
|
||||
for i in range(len(w)):
|
||||
pl.plot(t, w[i])
|
||||
pl.hold = True
|
||||
(cvisit, cvisi, cvt, cvm, cvs) = isi_cv(s)
|
||||
order = np.argsort(cvt)
|
||||
cvt = cvt[order]
|
||||
cvs = cvs[order]
|
||||
cvm = cvm[order]
|
||||
pl.subplot(2, 2, 2)
|
||||
pl.plot(cvt, cvs / cvm)
|
||||
pl.show()
|
||||
|
||||
if test == "syns":
|
||||
(s, w, t, n) = syns(rate=20, delay=0, dur=100.0, N=5, makewave=True)
|
||||
sh = np.array([])
|
||||
for i in range(len(s)):
|
||||
sh = np.append(sh, np.array(s[i]))
|
||||
(hist, bins) = np.histogram(sh, bins=250, range=(0, 250), new=True)
|
||||
if len(bins) > len(hist):
|
||||
bins = bins[0 : len(hist)]
|
||||
|
||||
import pylab as pl
|
||||
|
||||
pl.figure(1)
|
||||
pl.subplot(2, 2, 1)
|
||||
pl.plot(bins, hist)
|
||||
|
||||
pl.subplot(2, 2, 3)
|
||||
pl.plot(t, w)
|
||||
pl.hold = True
|
||||
(cvisit, cvisi, cvt, cvm, cvs) = isi_cv(s)
|
||||
order = np.argsort(cvt)
|
||||
cvt = cvt[order]
|
||||
cvs = cvs[order]
|
||||
cvm = cvm[order]
|
||||
pl.subplot(2, 2, 2)
|
||||
pl.plot(cvt, cvs / cvm)
|
||||
pl.show()
|
||||
1505
cnmodel/util/pyqtgraphPlotHelpers.py
Executable file
1505
cnmodel/util/pyqtgraphPlotHelpers.py
Executable file
File diff suppressed because it is too large
Load Diff
28
cnmodel/util/random_seed.py
Normal file
28
cnmodel/util/random_seed.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import numpy as np
|
||||
import hashlib, struct
|
||||
|
||||
_current_seed = 0
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed to be used globally. If a string is supplied, it
|
||||
will be converted to int using hash().
|
||||
|
||||
This immediately seeds the numpy RNG. Any other RNGs must be seeded using
|
||||
current_seed()
|
||||
"""
|
||||
if isinstance(seed, str):
|
||||
seed = struct.unpack("=I", hashlib.md5(seed.encode("utf-8")).digest()[:4])[0]
|
||||
np.random.seed(seed)
|
||||
assert seed < 2 ** 64 # neuron RNG fails if seed is too large
|
||||
global _current_seed
|
||||
_current_seed = seed
|
||||
return seed
|
||||
|
||||
|
||||
def current_seed():
|
||||
"""
|
||||
Return the currently-set global random seed.
|
||||
"""
|
||||
return _current_seed
|
||||
1415
cnmodel/util/sound.py
Normal file
1415
cnmodel/util/sound.py
Normal file
File diff suppressed because it is too large
Load Diff
87
cnmodel/util/stim.py
Normal file
87
cnmodel/util/stim.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_pulse(stim):
|
||||
"""
|
||||
Generate a pulse train for current / voltage command. Returns a tuple.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stim : dict
|
||||
Holds parameters that determine stimulus shape:
|
||||
|
||||
* delay : time before first pulse
|
||||
* Sfreq : frequency of pulses
|
||||
* dur : duration of one pulse or main pulse
|
||||
* predur : duration of prepulse (default should be 0 for no prepulse)
|
||||
* amp : pulse amplitude
|
||||
* preamp : amplitude of prepulse
|
||||
* PT : delay between end of train and test pulse (0 for no test)
|
||||
* NP : number of pulses
|
||||
* hold : holding level (optional)
|
||||
* dt : timestep
|
||||
|
||||
Returns
|
||||
-------
|
||||
w : stimulus waveform
|
||||
maxt : duration of waveform
|
||||
tstims : index of each pulse in the train
|
||||
"""
|
||||
defaults = {
|
||||
"delay": 10,
|
||||
"Sfreq": 50,
|
||||
"dur": 100,
|
||||
"predur": 0.0,
|
||||
"post": 50.0,
|
||||
"amp": None,
|
||||
"preamp": 0.0,
|
||||
"PT": 0,
|
||||
"NP": 1,
|
||||
"hold": 0.0,
|
||||
"dt": None,
|
||||
}
|
||||
for k in stim:
|
||||
if k not in defaults:
|
||||
raise Exception("Stim parameter '%s' not accepted." % k)
|
||||
defaults.update(stim)
|
||||
stim = defaults
|
||||
for k, v in stim.items():
|
||||
if v is None:
|
||||
raise Exception("Must specify stim parameter '%s'." % k)
|
||||
|
||||
dt = stim["dt"]
|
||||
delay = int(np.floor(stim["delay"] / dt))
|
||||
ipi = int(np.floor((1000.0 / stim["Sfreq"]) / dt))
|
||||
pdur = int(np.floor(stim["dur"] / dt))
|
||||
posttest = int(np.floor(stim["PT"] / dt))
|
||||
ndur = 5
|
||||
if stim["predur"] > 0.0:
|
||||
predur = int(np.floor(stim["predur"] / dt))
|
||||
else:
|
||||
predur = 0.0
|
||||
if stim["PT"] == 0:
|
||||
ndur = 1
|
||||
|
||||
maxt = dt * (delay + predur + (ipi * (stim["NP"] + 3)) + posttest + pdur * ndur)
|
||||
hold = stim.get("hold", None)
|
||||
|
||||
w = np.zeros(int(np.floor(maxt / dt)))
|
||||
if hold is not None:
|
||||
w += hold
|
||||
|
||||
# make pulse
|
||||
tstims = [0] * int(stim["NP"])
|
||||
for j in range(0, int(stim["NP"])):
|
||||
prestart = delay
|
||||
start = int(prestart + predur + j * ipi)
|
||||
if predur > 0.0:
|
||||
w[prestart : prestart + predur] = stim["preamp"]
|
||||
w[start : start + pdur] = stim["amp"]
|
||||
tstims[j] = start
|
||||
if stim["PT"] > 0.0:
|
||||
for i in range(start + posttest, start + posttest + pdur):
|
||||
w[i] = stim["amp"]
|
||||
w = np.append(w, 0.0)
|
||||
maxt = maxt + dt
|
||||
return (w, maxt, tstims)
|
||||
208
cnmodel/util/talbotetalTicks.py
Normal file
208
cnmodel/util/talbotetalTicks.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
# import matplotlib
|
||||
# import matplotlib.pyplot as plt
|
||||
# import matplotlib.ticker as tckr
|
||||
# import matplotlib.transforms as mtransforms
|
||||
# import matplotlib.mlab as mlab
|
||||
|
||||
# An alpha version of the Talbot, Lin, Hanrahan tick mark generator for matplotlib.
|
||||
# Described in "An Extension of Wilkinson's Algorithm for Positioning Tick Labels on Axes"
|
||||
# by Justin Talbot, Sharon Lin, and Pat Hanrahan, InfoVis 2010.
|
||||
|
||||
# Implementation by Justin Talbot
|
||||
# This implementation is in the public domain.
|
||||
# Report bugs to jtalbot@stanford.edu
|
||||
|
||||
# A shortcoming:
|
||||
# The weights used in the paper were designed for static plots where the extent of
|
||||
# the tick marks unioned with the extent of the data defines the extent of the plot.
|
||||
# In a plot where the extent of the plot is defined by the user (e.g. an interactive
|
||||
# plot supporting panning and zooming), the weights don't work as well. In particular,
|
||||
# you would want to retune them assuming that the tick labels must be inside
|
||||
# the provided view range. You probably want higher weighting on simplicity and lower
|
||||
# on coverage and possibly density. But I haven't experimented in any detail with this.
|
||||
#
|
||||
# If you do intend on using this for static plots in matplotlib, you should set
|
||||
# only_inside to False in the call to Extended.extended. And then you should
|
||||
# manually set your view extent to include the min and max ticks if they are outside
|
||||
# the data range. This should produce the same results as the paper.
|
||||
|
||||
# class Extended(tckr.Locator):
|
||||
class Extended:
|
||||
# density is labels per inch
|
||||
def __init__(self, density=1, steps=None, figure=None, range=(0, 1), axis="x"):
|
||||
"""
|
||||
Keyword args:
|
||||
"""
|
||||
self._density = density
|
||||
self._figure = figure
|
||||
self._axis = axis
|
||||
self.range = range
|
||||
|
||||
if steps is None:
|
||||
self._steps = [1, 5, 2, 2.5, 4, 3]
|
||||
else:
|
||||
self._steps = steps
|
||||
|
||||
def coverage(self, dmin, dmax, lmin, lmax):
|
||||
range = dmax - dmin
|
||||
return 1 - 0.5 * (
|
||||
math.pow(dmax - lmax, 2) + math.pow(dmin - lmin, 2)
|
||||
) / math.pow(0.1 * range, 2)
|
||||
|
||||
def coverage_max(self, dmin, dmax, span):
|
||||
range = dmax - dmin
|
||||
if span > range:
|
||||
half = (span - range) / 2.0
|
||||
return 1 - math.pow(half, 2) / math.pow(0.1 * range, 2)
|
||||
else:
|
||||
return 1
|
||||
|
||||
def density(self, k, m, dmin, dmax, lmin, lmax):
|
||||
r = (k - 1.0) / (lmax - lmin)
|
||||
rt = (m - 1.0) / (max(lmax, dmax) - min(lmin, dmin))
|
||||
return 2 - max(r / rt, rt / r)
|
||||
|
||||
def density_max(self, k, m):
|
||||
if k >= m:
|
||||
return 2 - (k - 1.0) / (m - 1.0)
|
||||
else:
|
||||
return 1
|
||||
|
||||
def simplicity(self, q, Q, j, lmin, lmax, lstep):
|
||||
eps = 1e-10
|
||||
n = len(Q)
|
||||
i = Q.index(q) + 1
|
||||
v = (
|
||||
1
|
||||
if (
|
||||
(lmin % lstep < eps or (lstep - lmin % lstep) < eps)
|
||||
and lmin <= 0
|
||||
and lmax >= 0
|
||||
)
|
||||
else 0
|
||||
)
|
||||
return (n - i) / (n - 1.0) + v - j
|
||||
|
||||
def simplicity_max(self, q, Q, j):
|
||||
n = len(Q)
|
||||
i = Q.index(q) + 1
|
||||
v = 1
|
||||
return (n - i) / (n - 1.0) + v - j
|
||||
|
||||
def legibility(self, lmin, lmax, lstep):
|
||||
return 1
|
||||
|
||||
def legibility_max(self, lmin, lmax, lstep):
|
||||
return 1
|
||||
|
||||
def extended(
|
||||
self,
|
||||
dmin,
|
||||
dmax,
|
||||
m,
|
||||
Q=[1, 5, 2, 2.5, 4, 3],
|
||||
only_inside=False,
|
||||
w=[0.25, 0.2, 0.5, 0.05],
|
||||
):
|
||||
n = len(Q)
|
||||
best_score = -2.0
|
||||
|
||||
j = 1.0
|
||||
while j < float("infinity"):
|
||||
for q in Q:
|
||||
sm = self.simplicity_max(q, Q, j)
|
||||
|
||||
if w[0] * sm + w[1] + w[2] + w[3] < best_score:
|
||||
j = float("infinity")
|
||||
break
|
||||
|
||||
k = 2.0
|
||||
while k < float("infinity"):
|
||||
dm = self.density_max(k, m)
|
||||
|
||||
if w[0] * sm + w[1] + w[2] * dm + w[3] < best_score:
|
||||
break
|
||||
|
||||
delta = (dmax - dmin) / (k + 1.0) / j / q
|
||||
z = math.ceil(math.log(delta, 10))
|
||||
|
||||
while z < float("infinity"):
|
||||
step = j * q * math.pow(10, z)
|
||||
cm = self.coverage_max(dmin, dmax, step * (k - 1.0))
|
||||
|
||||
if w[0] * sm + w[1] * cm + w[2] * dm + w[3] < best_score:
|
||||
break
|
||||
|
||||
min_start = math.floor(dmax / step) * j - (k - 1.0) * j
|
||||
max_start = math.ceil(dmin / step) * j
|
||||
|
||||
if min_start > max_start:
|
||||
z = z + 1
|
||||
break
|
||||
|
||||
for start in range(int(min_start), int(max_start) + 1):
|
||||
lmin = start * (step / j)
|
||||
lmax = lmin + step * (k - 1.0)
|
||||
lstep = step
|
||||
|
||||
s = self.simplicity(q, Q, j, lmin, lmax, lstep)
|
||||
c = self.coverage(dmin, dmax, lmin, lmax)
|
||||
d = self.density(k, m, dmin, dmax, lmin, lmax)
|
||||
l = self.legibility(lmin, lmax, lstep)
|
||||
|
||||
score = w[0] * s + w[1] * c + w[2] * d + w[3] * l
|
||||
|
||||
if score > best_score and (
|
||||
not only_inside or (lmin >= dmin and lmax <= dmax)
|
||||
):
|
||||
best_score = score
|
||||
best = (lmin, lmax, lstep, q, k)
|
||||
z = z + 1
|
||||
k = k + 1
|
||||
j = j + 1
|
||||
return best
|
||||
|
||||
def __call__(self):
|
||||
vmin, vmax = self.range # self.axis.get_view_interval()
|
||||
fsize = {"x": 5.0, "y": 4.0}
|
||||
size = fsize[self._axis] # self._figure.get_size_inches()[self._axis]
|
||||
# density * size gives target number of intervals,
|
||||
# density * size + 1 gives target number of tick marks,
|
||||
# the density function converts this back to a density in data units (not inches)
|
||||
# should probably make this cleaner.
|
||||
best = self.extended(
|
||||
vmin,
|
||||
vmax,
|
||||
self._density * size + 1.0,
|
||||
only_inside=True,
|
||||
w=[0.25, 0.2, 0.5, 0.05],
|
||||
)
|
||||
locs = np.arange(best[4]) * best[2] + best[0]
|
||||
return locs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
# fig = plt.figure()
|
||||
# ax = fig.add_subplot(111)
|
||||
# ax.plot(10*np.random.randn(100), 10*np.random.randn(100), 'o')
|
||||
#
|
||||
# xmin, xmax = ax.xaxis.get_data_interval()
|
||||
# xrange = xmax-xmin
|
||||
# xmin, xmax = (xmin - xrange * 0.05, xmax + xrange * 0.05)
|
||||
#
|
||||
# ymin, ymax = ax.yaxis.get_data_interval()
|
||||
# yrange = ymax-ymin
|
||||
# ymin, ymax = (ymin - yrange * 0.05, ymax + yrange * 0.05)
|
||||
#
|
||||
# ax.xaxis.set_view_interval(xmin, xmax, ignore=True)
|
||||
# ax.yaxis.set_view_interval(ymin, ymax, ignore=True)
|
||||
# ax.xaxis.set_major_locator(Extended(density=0.5, figure=fig, which=0))
|
||||
# ax.yaxis.set_major_locator(Extended(density=0.5, figure=fig, which=1))
|
||||
#
|
||||
# ax.set_title('Talbot, Lin, Hanrahan 2010')
|
||||
#
|
||||
# plt.show()
|
||||
36
cnmodel/util/tests/test_expfitting.py
Normal file
36
cnmodel/util/tests/test_expfitting.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import print_function
|
||||
from cnmodel.util import ExpFitting
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_fit1():
|
||||
fit = ExpFitting(nexp=1)
|
||||
x = np.linspace(0.0, 50, 500)
|
||||
p = [-50.0, 4.0, 5.0]
|
||||
y = p[0] + p[1] * np.exp(-x / p[2])
|
||||
res = fit.fit(x, y, fit.fitpars)
|
||||
pr = [float(res[k].value) for k in res.keys()]
|
||||
print("\noriginal: ", p)
|
||||
print("fit res: ", pr)
|
||||
for i, v in enumerate(p):
|
||||
assert np.allclose(v, pr[i])
|
||||
|
||||
|
||||
def test_fit2():
|
||||
fit = ExpFitting(nexp=2)
|
||||
x = np.linspace(0.0, 50, 500)
|
||||
p = [
|
||||
-50.0,
|
||||
4.0,
|
||||
5.0,
|
||||
1.0,
|
||||
4.5,
|
||||
] # last term is ratio of the two time constants (t2 = delta*t1)
|
||||
y = p[0] + p[1] * np.exp(-x / p[2]) + p[3] * np.exp(-x / (p[2] * p[4]))
|
||||
res = fit.fit(x, y, fit.fitpars)
|
||||
pr = [float(res[k].value) for k in res.keys()]
|
||||
print("\noriginal: ", p)
|
||||
print("fit res: ", pr)
|
||||
# we can only do this approximately for 2 exp fits
|
||||
for i, v in enumerate(p): # test each one individually
|
||||
assert np.allclose(pr[i] / v, 1.0, atol=1e-4, rtol=1e-2)
|
||||
40
cnmodel/util/tests/test_matlab.py
Normal file
40
cnmodel/util/tests/test_matlab.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
# from cnmodel.util.matlab_proc import MatlabProcess
|
||||
import matlab.engine
|
||||
|
||||
|
||||
def test_matlab():
|
||||
global proc
|
||||
try:
|
||||
# proc = MatlabProcess()
|
||||
proc = matlab.engine.start_matlab()
|
||||
except RuntimeError:
|
||||
# no matlab available; skip this test
|
||||
pytest.skip("MATLAB unavailable")
|
||||
|
||||
base_vcount = proc.who(nargout=1) # .shape[0]
|
||||
assert len(base_vcount) == 0
|
||||
|
||||
e4 = np.array(proc.eye(4, nargout=1))
|
||||
|
||||
assert isinstance(e4, np.ndarray)
|
||||
assert np.all(e4 == np.eye(4))
|
||||
# assert proc.who().shape[0] == base_vcount
|
||||
|
||||
o6_ref = np.array(proc.ones(6, nargout=1)) # _transfer=False))
|
||||
# o6 = o6_ref.get()
|
||||
o6 = np.array(proc.ones(6, nargout=1))
|
||||
assert np.all(o6 == np.ones(6))
|
||||
# print(proc.who(nargout=1))
|
||||
# assert proc.who(nargout=1) == base_vcount + 1
|
||||
|
||||
del o6_ref
|
||||
# assert proc.who().shape[0] == base_vcount
|
||||
|
||||
proc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_matlab()
|
||||
98
cnmodel/util/tests/test_sound.py
Normal file
98
cnmodel/util/tests/test_sound.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
from cnmodel.util import sound
|
||||
|
||||
|
||||
def test_conversions():
|
||||
pa = np.array([3990.5, 20, 0.3639, 2e-5])
|
||||
db = np.array([166, 120, 85.2, 0])
|
||||
|
||||
assert np.allclose(sound.pa_to_dbspl(pa), db, atol=0.1, rtol=0.002)
|
||||
assert np.allclose(sound.dbspl_to_pa(db), pa, atol=0.1, rtol=0.002)
|
||||
|
||||
|
||||
def test_tonepip():
|
||||
rate = 100000
|
||||
dur = 0.1
|
||||
ps = 0.01
|
||||
rd = 0.02
|
||||
pd = 0.08
|
||||
db = 60
|
||||
s1 = sound.TonePip(
|
||||
rate=rate,
|
||||
duration=dur,
|
||||
f0=5321,
|
||||
dbspl=db,
|
||||
pip_duration=pd,
|
||||
pip_start=[ps],
|
||||
ramp_duration=rd,
|
||||
)
|
||||
|
||||
# test array sizes
|
||||
assert s1.sound.size == s1.time.size == int(dur * rate) + 1
|
||||
|
||||
# test for consistency
|
||||
assert np.allclose(
|
||||
[s1.sound.min(), s1.sound.mean(), s1.sound.max()],
|
||||
[-0.028284253158247834, -1.0954891976168953e-10, 0.028284270354167296],
|
||||
)
|
||||
|
||||
# test that we got the requested amplitude
|
||||
assert np.allclose(s1.measure_dbspl(ps + rd, ps + pd - rd), db, atol=0.1, rtol=0.01)
|
||||
|
||||
# test for quiet before and after pip
|
||||
assert np.all(s1.sound[: int(ps * rate) - 1] == 0)
|
||||
assert np.all(s1.sound[int((ps + pd) * rate) + 1 :] == 0)
|
||||
|
||||
# test the sound can be recreated from its key
|
||||
key = s1.key()
|
||||
s2 = sound.create(**key)
|
||||
assert np.all(s1.time == s2.time)
|
||||
assert np.all(s1.sound == s2.sound)
|
||||
|
||||
|
||||
def test_noisepip():
|
||||
rate = 100000
|
||||
dur = 0.1
|
||||
ps = 0.01
|
||||
rd = 0.02
|
||||
pd = 0.08
|
||||
db = 60
|
||||
s1 = sound.NoisePip(
|
||||
rate=rate,
|
||||
duration=dur,
|
||||
seed=184724,
|
||||
dbspl=db,
|
||||
pip_duration=pd,
|
||||
pip_start=[ps],
|
||||
ramp_duration=rd,
|
||||
)
|
||||
|
||||
# test array sizes
|
||||
assert s1.sound.size == s1.time.size == int(dur * rate) + 1
|
||||
|
||||
# test for consistency
|
||||
assert np.allclose(
|
||||
[s1.sound.min(), s1.sound.mean(), s1.sound.max()],
|
||||
[-0.082260796003197786, -0.00018484322982972046, 0.069160217220832404],
|
||||
)
|
||||
|
||||
# test that we got the requested amplitude
|
||||
assert np.allclose(s1.measure_dbspl(ps + rd, ps + pd - rd), db, atol=0.1, rtol=0.01)
|
||||
|
||||
# test for quiet before and after pip
|
||||
assert np.all(s1.sound[: int(ps * rate) - 1] == 0)
|
||||
assert np.all(s1.sound[int((ps + pd) * rate) + 1 :] == 0)
|
||||
|
||||
# test the sound can be recreated from its key
|
||||
key = s1.key()
|
||||
s2 = sound.create(**key)
|
||||
# also test new seed works, and does not affect other sounds
|
||||
key["seed"] += 1
|
||||
s3 = sound.create(**key)
|
||||
s3.sound # generate here to advance rng before generating s2
|
||||
|
||||
assert np.all(s1.time == s2.time)
|
||||
assert np.all(s1.sound == s2.sound)
|
||||
start = int(ps * rate) + 1
|
||||
end = int((ps + pd) * rate) - 1
|
||||
assert not np.any(s1.sound[start:end] == s3.sound[start:end])
|
||||
29
cnmodel/util/tests/test_stim.py
Normal file
29
cnmodel/util/tests/test_stim.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import numpy as np
|
||||
from numpy.testing import assert_raises
|
||||
|
||||
from cnmodel.util import stim
|
||||
from neuron import h
|
||||
|
||||
h.dt = 0.025
|
||||
|
||||
|
||||
def test_make_pulse():
|
||||
params = dict(delay=10, Sfreq=50, dur=1, amp=15, PT=0, NP=5)
|
||||
|
||||
assert_raises(Exception, lambda: stim.make_pulse(params))
|
||||
params["dt"] = 0.025
|
||||
|
||||
w, maxt, times = stim.make_pulse(params)
|
||||
|
||||
assert w.min() == 0.0
|
||||
assert w.max() == 15
|
||||
assert w.dtype == np.float64
|
||||
|
||||
triggers = np.argwhere(np.diff(w) > 0)[:, 0] + 1
|
||||
assert np.all(triggers == times)
|
||||
assert w.sum() == 15 * len(times) * int(1 / h.dt)
|
||||
|
||||
params["PT"] = 100
|
||||
w, maxt, times = stim.make_pulse(params)
|
||||
triggers = np.argwhere(np.diff(w) > 0)[:, 0] + 1
|
||||
assert triggers[-1] - triggers[-2] == 100 / h.dt
|
||||
205
cnmodel/util/user_tester.py
Normal file
205
cnmodel/util/user_tester.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from __future__ import print_function
|
||||
import os, sys, pickle, pprint
|
||||
import numpy as np
|
||||
import pyqtgraph as pg
|
||||
from .. import AUDIT_TESTS
|
||||
|
||||
|
||||
class UserTester(object):
|
||||
"""
|
||||
Base class for testing when a human is required to verify the results.
|
||||
|
||||
When a test is passed by the user, its output is saved and used as a basis
|
||||
for future tests. If future test results do not match the stored results,
|
||||
then the user is asked to decide whether to fail the test, or pass the
|
||||
test and store new results.
|
||||
|
||||
Subclasses must reimplement run_test() to return a dictionary of results
|
||||
to store. Optionally, compare_results and audit_result may also be
|
||||
reimplemented to customize the testing behavior.
|
||||
|
||||
By default, test results are stored in a 'test_data' directory relative
|
||||
to the file that defines the UserTester subclass in use.
|
||||
"""
|
||||
|
||||
data_dir = "test_data"
|
||||
|
||||
def __init__(self, key, *args, **kwds):
|
||||
"""Initialize with a string *key* that provides a short, unique
|
||||
description of this test. All other arguments are passed to run_test().
|
||||
|
||||
*key* is used to determine the file name for storing test results.
|
||||
"""
|
||||
self.audit = AUDIT_TESTS
|
||||
self.key = key
|
||||
self.rtol = 1e-3
|
||||
self.args = args
|
||||
self.assert_test_info(*args, **kwds)
|
||||
|
||||
def run_test(self, *args, **kwds):
|
||||
"""
|
||||
Exceute the test. All arguments are taken from __init__.
|
||||
Return a picklable dictionary of test results.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def compare_results(self, key, info, expect):
|
||||
"""
|
||||
Compare *result* of the current test against the previously stored
|
||||
result *expect*. If *expect* is None, then no previous result was
|
||||
stored.
|
||||
|
||||
If *result* and *expect* do not match, then raise an exception.
|
||||
"""
|
||||
# Check test structures are the same
|
||||
assert type(info) is type(expect)
|
||||
if hasattr(info, "__len__"):
|
||||
assert len(info) == len(expect)
|
||||
|
||||
if isinstance(info, dict):
|
||||
for k in info:
|
||||
assert k in expect
|
||||
for k in expect:
|
||||
assert k in info
|
||||
self.compare_results(k, info[k], expect[k])
|
||||
elif isinstance(info, list):
|
||||
for i in range(len(info)):
|
||||
self.compare_results(key, info[i], expect[i])
|
||||
elif isinstance(info, np.ndarray):
|
||||
assert info.shape == expect.shape
|
||||
if len(info) == 0:
|
||||
return
|
||||
# assert info.dtype == expect.dtype
|
||||
if info.dtype.fields is None:
|
||||
intnan = -9223372036854775808 # happens when np.nan is cast to int
|
||||
inans = np.isnan(info) | (info == intnan)
|
||||
enans = np.isnan(expect) | (expect == intnan)
|
||||
assert np.all(inans == enans)
|
||||
mask = ~inans
|
||||
if not np.allclose(info[mask], expect[mask], rtol=self.rtol):
|
||||
print(
|
||||
"\nComparing data array, shapes match: ",
|
||||
info.shape == expect.shape,
|
||||
)
|
||||
print("Model tested: %s, measure: %s" % (self.key, key))
|
||||
# print( 'args: ', dir(self.args[0]))
|
||||
print("Array expected: ", expect[mask])
|
||||
print("Array received: ", info[mask])
|
||||
try:
|
||||
self.args[0].print_all_mechs()
|
||||
except:
|
||||
print("args[0] is string: ", self.args[0])
|
||||
assert np.allclose(info[mask], expect[mask], rtol=self.rtol)
|
||||
else:
|
||||
for k in info.dtype.fields.keys():
|
||||
self.compare_results(k, info[k], expect[k])
|
||||
elif np.isscalar(info):
|
||||
if not np.allclose(info, expect, rtol=self.rtol):
|
||||
print("Comparing Scalar data, model: %s, measure: %s" % (self.key, key))
|
||||
# print 'args: ', dir(self.args[0])
|
||||
print(
|
||||
"Expected: ",
|
||||
expect,
|
||||
", received: ",
|
||||
info,
|
||||
" relative tolerance: ",
|
||||
self.rtol,
|
||||
)
|
||||
if isinstance(self.args[0], str):
|
||||
pass
|
||||
# print ': ', str
|
||||
else:
|
||||
self.args[0].print_all_mechs()
|
||||
assert np.allclose(info, expect, rtol=self.rtol)
|
||||
else:
|
||||
try:
|
||||
assert info == expect
|
||||
except AssertionError:
|
||||
raise
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
"Cannot compare objects of type %s" % type(info)
|
||||
)
|
||||
|
||||
def audit_result(self, info, expect):
|
||||
""" Display results and ask the user to decide whether the test passed.
|
||||
Return True for pass, False for fail.
|
||||
|
||||
If *expect* is None, then no previous test results were stored.
|
||||
"""
|
||||
app = pg.mkQApp()
|
||||
print("\n=== New test results for %s: ===\n" % self.key)
|
||||
pprint.pprint(info)
|
||||
|
||||
# we use DiffTreeWidget to display differences between large data structures, but
|
||||
# this is not present in mainline pyqtgraph yet
|
||||
if hasattr(pg, "DiffTreeWidget"):
|
||||
win = pg.DiffTreeWidget()
|
||||
else:
|
||||
from cnmodel.util.difftreewidget import DiffTreeWidget
|
||||
|
||||
win = DiffTreeWidget()
|
||||
|
||||
win.resize(800, 800)
|
||||
win.setData(expect, info)
|
||||
win.show()
|
||||
print("Store new test results? [y/n]")
|
||||
yn = raw_input()
|
||||
win.hide()
|
||||
return yn.lower().startswith("y")
|
||||
|
||||
def assert_test_info(self, *args, **kwds):
|
||||
"""
|
||||
Test *cell* and raise exception if the results do not match prior
|
||||
data.
|
||||
"""
|
||||
result = self.run_test(*args, **kwds)
|
||||
expect = self.load_test_result()
|
||||
try:
|
||||
assert expect is not None
|
||||
self.compare_results(None, result, expect)
|
||||
except:
|
||||
if not self.audit:
|
||||
if expect is None:
|
||||
raise Exception(
|
||||
"No prior test results for test '%s'. "
|
||||
"Run test.py --audit store new test data." % self.key
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
store = self.audit_result(result, expect)
|
||||
if store:
|
||||
self.save_test_result(result)
|
||||
else:
|
||||
raise Exception("Rejected test results for '%s'" % self.key)
|
||||
|
||||
def result_file(self):
|
||||
"""
|
||||
Return a file name to be used for storing / retrieving test results
|
||||
given *self.key*.
|
||||
"""
|
||||
modfile = sys.modules[self.__class__.__module__].__file__
|
||||
path = os.path.dirname(modfile)
|
||||
return os.path.join(path, self.data_dir, self.key + ".pk")
|
||||
|
||||
def load_test_result(self):
|
||||
"""
|
||||
Load prior test results for *self.key*.
|
||||
If there are no prior results, return None.
|
||||
"""
|
||||
fn = self.result_file()
|
||||
if os.path.isfile(fn):
|
||||
return pickle.load(open(fn, "rb"), encoding="latin1")
|
||||
return None
|
||||
|
||||
def save_test_result(self, result):
|
||||
"""
|
||||
Store test results for *self.key*.
|
||||
Th e*result* argument must be picklable.
|
||||
"""
|
||||
fn = self.result_file()
|
||||
dirname = os.path.dirname(fn)
|
||||
if not os.path.isdir(dirname):
|
||||
os.mkdir(dirname)
|
||||
pickle.dump(result, open(fn, "wb"))
|
||||
Reference in New Issue
Block a user