model of DCN pyramidal neuron
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

658 lines
20 KiB

"""
Test synaptic connections between two different cell types.

Usage: python compare_simple_multisynapses.py [options]
    (see --help; e.g. --runone --pre sgc --post bushy)

This script:
1. Creates single pre- and postsynaptic cells.
2. Creates a synaptic connection between the two cells, using the multisite
   synapse method.
3. Stimulates the presynaptic cell by current injection.
4. Records and analyzes the resulting postsynaptic events.
5. Repeats steps 3 and 4 many times (200 iterations; 50 for sgc -> bushy)
   to get an average postsynaptic event.
6. Stores the resulting waveforms in a pickle file (multisite.pkl).

This is used mainly to check that the strength, kinetics, and dynamics of
each synapse type are working as expected.

Requires Python 3.6 or later.
"""
import sys
import argparse
from pathlib import Path
import numpy as np
# import pyqtgraph as pg
import matplotlib.pyplot as mpl
import neuron as h
import cnmodel.util.PlotHelpers as PH
from cnmodel.protocols import SynapseTest
from cnmodel import cells
from cnmodel.synapses import Synapse
import pickle
import lmfit
# Number of presynaptic terminals used for each connection, looked up as
# convergence[presynaptic type][postsynaptic type] by compute_psc.
# Every connection in this comparison uses a single terminal.
convergence = {
    "sgc": dict.fromkeys(
        (
            "bushy",
            "tstellate",
            "dstellate",
            "octopus",
            "pyramidal",
            "tuberculoventral",
        ),
        1,
    ),
    "dstellate": dict.fromkeys(
        ("bushy", "tstellate", "dstellate", "pyramidal", "tuberculoventral"),
        1,
    ),
    "tuberculoventral": dict.fromkeys(
        ("bushy", "tstellate", "pyramidal", "tuberculoventral"),
        1,
    ),
}
class Exp2SynFitting:
    """
    Fit a waveform against the Exp2Syn function (as used by NEURON).

    The function is (from the NEURON documentation):

        i = G * (v - e)        i (nanoamps), G (micromhos)
        G = weight * factor * (exp(-t/tau2) - exp(-t/tau1))

    where ``factor`` is evaluated at initialization (in NEURON; see the
    exp2syn.mod source) as:

        tp = (tau1*tau2)/(tau2 - tau1) * log(tau2/tau1)    (natural log)
        factor = 1 / (-exp(-tp/tau1) + exp(-tp/tau2))

    ``tp`` is the time of the waveform peak, so ``factor`` scales the
    difference of exponentials so that the peak of G equals ``weight``.
    Here tau2 is parameterized as ``tau2 = tauratio * tau1``; bounding
    ``tauratio`` above 1 keeps tau2 > tau1 throughout the fit.

    Parameters
    ----------
    initpars : dict
        Initial parameter values. Must contain the keys 'tau1', 'tauratio',
        'weight', 'erev', 'v' and 'delay', for example:
        {'tau1': 0.1, 'tauratio': 3.0, 'weight': 0.1, 'delay': 0.0,
        'erev': -80., 'v': -65.} (erev and v in mV).
        'erev' and 'v' are held fixed during the fit.
    bounds : dict, optional
        Per-parameter [lower, upper] limits for the fitted parameters
        ('tau1', 'tauratio', 'weight', 'delay'). Entries override the
        defaults in ``_DEFAULT_BOUNDS``.
    functype : str
        Only "exp2syn" is implemented.

    Raises
    ------
    ValueError
        If ``functype`` is unsupported, ``initpars`` is missing or lacks
        required keys, or ``bounds`` names a parameter that is not fitted.
    """

    # Default (min, max) limits for the parameters that are varied in the fit.
    _DEFAULT_BOUNDS = {
        "tau1": (0.05, 25.0),
        "tauratio": (1.0001, 100.0),
        "weight": (1e-6, 1.0),
        "delay": (0.0, 5.0),
    }
    # Keys that ``initpars`` must provide.
    _REQUIRED_KEYS = ("tau1", "tauratio", "weight", "erev", "v", "delay")

    def __init__(self, initpars=None, bounds=None, functype="exp2syn"):
        if functype != "exp2syn":  # future: handle other functions like alpha
            raise ValueError(
                f"Unsupported functype {functype!r}; only 'exp2syn' is implemented"
            )
        if initpars is None:
            raise ValueError("initpars is required (see the class docstring)")
        missing = [key for key in self._REQUIRED_KEYS if key not in initpars]
        if missing:
            raise ValueError(f"initpars is missing required keys: {missing}")

        limits = dict(self._DEFAULT_BOUNDS)
        if bounds is not None:
            unknown = set(bounds) - set(limits)
            if unknown:
                raise ValueError(
                    f"bounds given for parameters that are not fitted: {sorted(unknown)}"
                )
            for name, (lower, upper) in bounds.items():
                limits[name] = (lower, upper)

        def varied(name):
            # (Name, Value, Vary, Min, Max, Expr) tuple for a fitted parameter.
            lower, upper = limits[name]
            return (name, initpars[name], True, lower, upper, None)

        self.fitpars = lmfit.Parameters()
        self.fitpars.add_many(
            varied("tau1"),
            varied("tauratio"),
            varied("weight"),
            ("erev", initpars["erev"], False),  # held fixed: do not adjust!
            ("v", initpars["v"], False),  # held fixed
            varied("delay"),
        )
        self.func = self.exp2syn_err

    def fit(self, x, y, p, verbose=False):
        """
        Perform the curve fit against the specified function.

        Parameters
        ----------
        x : time base for waveform (np array or list)
        y : waveform (1-d np array or list)
        p : parameters in an lmfit Parameters structure
        verbose : boolean (default: False)
            If True, print the fit report.

        Returns
        -------
        The fitted lmfit Parameters (also available as ``self.mim.params``).
        """
        self.mim = lmfit.minimize(
            self.func, p, method="least_squares", args=(x, y)
        )
        if verbose:
            lmfit.printfuncs.report_fit(self.mim.params)
        return self.mim.params

    @staticmethod
    def _peak_normalization(tau1, tau2):
        """
        Return NEURON's Exp2Syn ``factor``: the reciprocal of the difference
        of exponentials at its peak time tp, so that the peak becomes 1.
        Requires tau2 != tau1; tau2 > tau1 gives a positive factor.
        """
        tp = (tau1 * tau2) / (tau2 - tau1) * np.log(tau2 / tau1)
        return 1.0 / (-np.exp(-tp / tau1) + np.exp(-tp / tau2))

    def exp2syn(self, x, tau1, tauratio, weight, erev, v, delay):
        """
        Compute the exp2syn waveform current as it is done in NEURON.

        Units are assumed to be self-consistent: if the time base is in
        msec, tau1 and the delay are in msec; erev and v must be in matching
        units (e.g., mV). Note that the weight is not equal to the
        conductance prefactor G, because of the peak normalization (see
        ``factor``).

        Parameters
        ----------
        x : time base (np array, list or scalar)
        tau1 : rising tau for the waveform
        tauratio : tau2 / tau1 (tau2 is the falling tau; must be > 1)
        weight : peak amplitude of the conductance waveform
        erev : reversal potential for the ionic species (used to compute i)
        v : holding voltage at which the waveform is computed
        delay : delay to the start of the function

        All parameters other than ``x`` are scalars; lmfit Parameter objects
        are accepted and converted with float().

        Returns
        -------
        i : np.ndarray, the calculated current trace for these parameters.
            It is exactly 0 for times before ``delay``.
        """
        tau1 = float(tau1)
        # tau2 > tau1 is enforced by fitting tauratio (bounded > 1), not tau2.
        tau2 = float(tauratio) * tau1
        # Clamp pre-delay times to 0: exp(0) - exp(0) == 0, so G is exactly
        # zero there, and the large positive exponents (overflow) are avoided.
        elapsed = np.maximum(np.asarray(x, dtype=float) - float(delay), 0.0)
        G = (
            float(weight)
            * self._peak_normalization(tau1, tau2)
            * (np.exp(-elapsed / tau2) - np.exp(-elapsed / tau1))
        )
        return G * (float(v) - float(erev))  # i(nanoamps), g(micromhos)

    def exp2syn_err(self, p, x, y):
        """
        Residual (data - model) for lmfit.minimize.

        ``p`` maps parameter names to objects with a ``.value`` attribute
        (an lmfit Parameters instance during fitting). The signed residual
        is returned, as lmfit expects; it squares and sums it internally.
        """
        values = {name: par.value for name, par in p.items()}
        return np.asarray(y, dtype=float) - self.exp2syn(x, **values)

    def factor(self, tau1, tauratio, weight):
        """
        Calculate the tau-scaled weight, ``weight * factor``, i.e. the G
        prefactor NEURON uses for these time constants.
        """
        return weight * self._peak_normalization(tau1, tau1 * tauratio)
def testexp():
    """
    Test the exp2syn fitting function.

    A noise-free Exp2Syn current is generated from known parameters. It is
    then fit starting from different initial values, and the fit report is
    printed next to the target parameters.

    Returns
    -------
    tuple
        (fitted lmfit Parameters, dict of target parameters), so the
        recovery can also be checked programmatically.
    """
    pars = {
        "tau1": 0.1,
        "tauratio": 2.0,
        "weight": 0.1,
        "erev": -70.0,
        "v": -65.0,
        "delay": 1,
    }
    # Start the fit away from the target (tau1 and delay differ).
    F = Exp2SynFitting(
        initpars={
            "tau1": 0.2,
            "tauratio": 2.0,
            "weight": 0.1,
            "erev": -70.0,
            "v": -65.0,
            "delay": 0,
        }
    )
    t = np.arange(0, 10.0, 0.01)
    target = F.exp2syn(t, **pars)
    pars_fit = F.fit(t, target, F.fitpars)
    print("\nTest fit result: ")
    lmfit.printfuncs.report_fit(F.mim.params)
    print("( tau2 = ", pars_fit["tau1"].value * pars_fit["tauratio"].value, ")")
    print("target parameters: ", pars)
    print("\n")
    return pars_fit, pars
def _create_cell(cellType):
    """
    Create a single cnmodel cell of the named type.

    Raises
    ------
    ValueError
        If ``cellType`` is not one of the supported names.
    """
    # Lambdas so that only the requested cell model is constructed.
    factories = {
        "sgc": lambda: cells.SGC.create(),
        "tstellate": lambda: cells.TStellate.create(debug=True, ttx=False),
        # Type I-II Rothman model, similar excitability (Xie/Manis, unpublished)
        "dstellate": lambda: cells.DStellate.create(
            model="RM03", debug=True, ttx=False
        ),
        # 'dstellate_eager' (From Eager et al.) is disabled:
        #     cells.DStellate.create(model='Eager', debug=True, ttx=False)
        "bushy": lambda: cells.Bushy.create(debug=True, ttx=True),
        "octopus": lambda: cells.Octopus.create(debug=True, ttx=True),
        "tuberculoventral": lambda: cells.Tuberculoventral.create(
            debug=True, ttx=True
        ),
        "pyramidal": lambda: cells.Pyramidal.create(debug=True, ttx=True),
        "cartwheel": lambda: cells.Cartwheel.create(debug=True, ttx=True),
    }
    if cellType not in factories:
        raise ValueError("Unknown cell type '%s'" % cellType)
    return factories[cellType]()


def compute_psc(synapsetype="multisite", celltypes=("sgc", "tstellate")):
    """
    Compute the PSC between the specified two cell types.

    Parameters
    ----------
    synapsetype : str
        Type of synapse: 'multisite' or 'simple'.
    celltypes : sequence of two str
        (presynaptic type, postsynaptic type). Each must be a name accepted
        by _create_cell. The number of terminals comes from ``convergence``;
        1 is assumed (with a warning) for unlisted pairs.

    Returns
    -------
    SynapseTest
        The protocol object after ``run``. The caller must keep a reference
        to it to keep the simulation objects alive in memory.

    Raises
    ------
    ValueError
        For an unknown synapse type, a cell-type list that is not of length
        2, or an unknown cell type.
    """
    if synapsetype not in ("multisite", "simple"):
        raise ValueError(
            f"synapsetype must be 'multisite' or 'simple', not {synapsetype!r}"
        )
    celltypes = list(celltypes)  # tuple input compares equal to lists below
    if len(celltypes) != 2:
        raise ValueError(
            f"celltypes must name exactly two cell types, got {celltypes!r}"
        )
    preCell, postCell = [_create_cell(cellType) for cellType in celltypes]
    print(f"Computing psc for connection {celltypes[0]:s} -> {celltypes[1]:s}")
    nTerminals = convergence.get(celltypes[0], {}).get(celltypes[1], None)
    if nTerminals is None:
        nTerminals = 1
        print(
            f"Warning: Unknown convergence for {celltypes[0]:s} -> {celltypes[1]:s}, ASSUMING {nTerminals:d} terminals"
        )
    # Number of repetitions averaged: 1 for simple synapses, fewer for sgc -> bushy.
    if synapsetype == "simple":
        niter = 1
    elif celltypes == ["sgc", "bushy"]:
        niter = 50
    else:
        niter = 200
    st = SynapseTest()
    dt = 0.010
    stim = {
        "NP": 1,
        "Sfreq": 100.0,
        "delay": 0.0,
        "dur": 0.5,
        "amp": 10.0,
        "PT": 0.0,
        "dt": dt,
    }
    st.run(
        preCell.soma,
        postCell.soma,
        nTerminals,
        dt=dt,
        vclamp=-65.0,
        iterations=niter,
        synapsetype=synapsetype,
        tstop=50.0,
        stim_params=stim,
    )
    return st  # need to keep st alive in memory
def fit_one(st, stk):
    """
    Fit the mean of a set of current traces to the exp2syn function.

    Parameters
    ----------
    st : dict (no default)
        A dictionary containing (at least) the following keys:
        't' : the time base for the trace to be fit
        'i' : a list of arrays or a 2d numpy array of the data; the traces
              are averaged across the first axis before fitting
    stk : tuple of str
        (presynaptic cell type, postsynaptic cell type). The presynaptic
        type sets the fixed reversal potential: 0 mV for excitatory sources
        ('sgc', 'tstellate', 'granule') and -70 mV otherwise (the value
        used in gly_psd).

    Returns
    -------
    pars : The fitting parameters
        The fitted parameters are returned as an lmfit Parameters object;
        access individual parameters as pars['parname'].value. A fixed
        parameter 'GWeight' (the tau-scaled weight) is added after the fit.
    fitted : The fitted waveform, of the same length as the source data waveform
    """
    excitatory_sources = ("sgc", "tstellate", "granule")
    erev = 0.0 if stk[0] in excitatory_sources else -70.0
    print("\nstk, erev: ", stk, erev)
    F = Exp2SynFitting(
        initpars={
            "tau1": 1.0,
            "tauratio": 5.0,
            "weight": 0.0001,
            "erev": erev,
            "v": -65.0,
            "delay": 0,
        }
    )
    t = np.asarray(st["t"])
    target = np.mean(np.array(st["i"]), axis=0)
    pars = F.fit(t, target, F.fitpars)
    gw = F.factor(pars["tau1"].value, pars["tauratio"].value, pars["weight"].value)
    pars.add("GWeight", value=gw, vary=False)
    # Evaluate the model with plain float values rather than Parameter objects.
    model_names = ("tau1", "tauratio", "weight", "erev", "v", "delay")
    fitted = F.exp2syn(t, **{name: pars[name].value for name in model_names})
    lmfit.printfuncs.report_fit(F.mim.params)
    return (pars, fitted)
def fit_all():
    """
    Fit exp2syn against the averaged multisite trace of every pre/post pair.

    Reads multisite.pkl and fits each pair with fit_one. The fitted
    waveforms (time base, the single fitted trace, and the parameters) are
    written to simple.pkl.

    Returns
    -------
    dict
        The fitted lmfit Parameters, keyed by (pre, post) tuple.
    """
    multisite = read_pickle("multisite.pkl")
    params_by_pair = {}
    simple = {}
    for pair, traces in multisite.items():
        fit_params, fit_trace = fit_one(traces, pair)
        params_by_pair[pair] = fit_params
        simple[pair] = {"t": traces["t"], "i": [fit_trace], "pars": fit_params}
    with open("simple.pkl", "wb") as fh:
        pickle.dump(simple, fh)
    return params_by_pair
def plot_all(stm, sts=None):
    """
    Plot the mean postsynaptic current (+/- 1 SD) for every pre/post pair,
    then show the figure.

    Parameters
    ----------
    stm : dict
        Keyed by (pre, post) tuples. Each value has "t" (time base) and "i"
        (one row per trace, columns matching "t"). One panel is drawn per
        pair, and the number of traces with an excursion beyond +/- 0.01 is
        printed.
    sts : dict, optional
        Keyed like ``stm`` (e.g. the exp2syn fits from simple.pkl). When
        given, the mean of its "i" traces is overlaid in magenta on the panel
        of the matching pair. Pairs absent from ``sts`` are skipped.
    """
    ncols = 5
    # Keep the original 3 x 5 layout; add rows only if there are more pairs
    # than panels (previously this raised IndexError).
    nrows = max(3, -(-len(stm) // ncols))
    P = PH.Plotter((nrows, ncols), figsize=(11, 6))
    ax = P.axarr.ravel()
    keypairorder = list(stm.keys())
    for i, stk in enumerate(keypairorder):
        data = stm[stk]
        t = np.array(data["t"])
        idat = np.array(data["i"])
        mean_i = np.mean(idat, axis=0)
        sd = np.std(idat, axis=0)
        ax[i].plot(t, mean_i, "c-", linewidth=1.5)
        ax[i].plot(t, mean_i + sd, "c--", linewidth=0.5)
        ax[i].plot(t, mean_i - sd, "c--", linewidth=0.5)
        # Count traces with any excursion beyond +/- 0.01.
        responded = (idat.min(axis=1) < -1e-2) | (idat.max(axis=1) > 1e-2)
        rel = int(np.count_nonzero(responded))
        print(f"{str(stk):s} {rel:d} of {idat.shape[0]:d} {str(idat.shape):s}")
        ax[i].set_title("%s : %s" % (stk[0], stk[1]), fontsize=7)
    if sts is not None:  # overlay the matching exp2syn waveform on each panel
        for i, stk in enumerate(keypairorder):
            if stk not in sts:
                continue
            datas = sts[stk]
            isdat = np.array(datas["i"])
            ax[i].plot(datas["t"], np.mean(isdat, axis=0), "m-", linewidth=3, alpha=0.5)
    mpl.show()
def print_exp2syn_fits(st):
    """
    Print tables of the fitted exp2syn parameters.

    One table is printed for each presynaptic type (sgc, dstellate,
    tuberculoventral) that has at least one fitted pair in ``st``.
    Columns are postsynaptic types; rows are weight, GWeight, tau1, tau2,
    delay and erev. tau2 is reported as tau1 * tauratio. Pairs without a
    fit are shown as "0".

    Parameters
    ----------
    st : dict
        Keyed by (pre, post) tuples. Each value has a "pars" entry holding
        the fit parameters (objects with a ``.value`` attribute, such as the
        lmfit Parameters stored in simple.pkl by fit_all).
    """
    postcells = [
        "bushy",
        "tstellate",
        "dstellate",
        "octopus",
        "pyramidal",
        "tuberculoventral",
        "cartwheel",
    ]
    fmts = {
        "weight": "{0:<18.6f}",
        "GWeight": "{0:<18.6f}",
        "tau1": "{0:<18.3f}",
        "tau2": "{0:<18.3f}",
        "delay": "{0:<18.3f}",
        "erev": "{0:<18.1f}",
    }
    label_fmt = "{0:<18s}"
    for pre in ["sgc", "dstellate", "tuberculoventral"]:
        if not any((pre, post) in st for post in postcells):
            continue  # nothing fitted for this presynaptic type
        # Header first, so "0" placeholders always land under their column.
        print("\n%s" % pre.upper())
        print(label_fmt.format(" ") + "".join(label_fmt.format(p) for p in postcells))
        for name, fmt in fmts.items():
            row = [label_fmt.format(name)]
            for post in postcells:
                if (pre, post) not in st:
                    row.append(label_fmt.format("0"))
                    continue
                fits = st[(pre, post)]["pars"]
                if name == "tau2":  # tau2 is not a fit parameter; derive it
                    value = fits["tau1"].value * fits["tauratio"].value
                else:
                    value = fits[name].value
                row.append(fmt.format(value))
            print("".join(row))
def read_pickle(mode):
    """
    Load one of the pickle files written by this script.

    Parameters
    ----------
    mode : str
        Base name of the file, normally 'simple' or 'multisite'. The
        suffix is forced to ".pkl", so 'multisite' and 'multisite.pkl'
        refer to the same file.

    Returns
    -------
    The unpickled object (a dictionary for the files this script writes).
    """
    pkl_path = Path(mode).with_suffix(".pkl")
    with pkl_path.open("rb") as fh:
        return pickle.load(fh)
def run_all():
    """
    Run the multisite synapse calculation for every cell pair in
    ``convergence``.

    The recorded traces (t, i, v, pre) for each pair are saved to
    multisite.pkl and then plotted.
    """
    sims = {}  # holds every SynapseTest so its objects stay alive
    payload = {}
    for pre, posts in convergence.items():
        for post in posts:
            key = (pre, post)
            sim = compute_psc(synapsetype="multisite", celltypes=[pre, post])
            sims[key] = sim
            # Store only the recorded traces, not the NEURON objects.
            payload[key] = {
                "t": sim["t"],
                "i": sim.isoma,
                "v": sim["v_soma"],
                "pre": sim["v_pre"],
            }
    out_path = Path("multisite").with_suffix(".pkl")
    with open(out_path, "wb") as fh:
        pickle.dump(payload, fh)
    plot_all(payload)
def run_one(pre, post):
    """
    Run the multisite synapse calculation for a single pre -> post pair.

    The new traces are merged into multisite.pkl: an existing entry for the
    pair is replaced, and the file is created if it does not exist yet. All
    stored pairs are then plotted.

    Parameters
    ----------
    pre : str
        Presynaptic cell type: 'sgc', 'tstellate', 'dstellate' or
        'tuberculoventral'.
    post : str
        Postsynaptic cell type: 'bushy', 'tstellate', 'dstellate',
        'tuberculoventral' or 'pyramidal'.

    Raises
    ------
    ValueError
        If ``pre`` or ``post`` is not one of the accepted cell types.
    """
    valid_pre = ("sgc", "tstellate", "dstellate", "tuberculoventral")
    valid_post = ("bushy", "tstellate", "dstellate", "tuberculoventral", "pyramidal")
    if pre not in valid_pre:
        raise ValueError(f"Presynaptic cell {pre!r} must be one of {valid_pre}")
    if post not in valid_post:
        raise ValueError(f"Postsynaptic cell {post!r} must be one of {valid_post}")
    pkl_path = Path("multisite").with_suffix(".pkl")
    try:
        pkst = read_pickle("multisite")
    except FileNotFoundError:
        pkst = {}  # no earlier results: start a new file
    sti = compute_psc(synapsetype="multisite", celltypes=[pre, post])
    # Store only the recorded traces, not the NEURON objects.
    pkst[(pre, post)] = {
        "t": sti["t"],
        "i": sti.isoma,
        "v": sti["v_soma"],
        "pre": sti["v_pre"],
    }
    with open(pkl_path, "wb") as fh:
        pickle.dump(pkst, fh)
    plot_all(pkst)
def main():
    """
    Parse the command-line options and run the requested action.

    The first matching action runs and the process then exits (test, run,
    runone, fit, plot). --list prints the exp2syn fit tables from
    simple.pkl. An invalid --mode for --plot exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Compare simple and multisite synapses"
    )
    parser.add_argument(
        "-m",
        "--mode",
        type=str,
        dest="mode",
        default="None",
        choices=["simple", "multisite", "both"],
        help="Select mode [simple, multisite, both]",
    )
    parser.add_argument(
        "-r",
        "--run",
        action="store_true",
        dest="run",
        help="Run multisite models between all cell types",
    )
    parser.add_argument(
        "-o",
        "--runone",
        action="store_true",
        dest="runone",
        help="Run multisite models between specified cell types",
    )
    parser.add_argument(
        "--pre",
        type=str,
        dest="pre",
        default="None",
        help="Select presynaptic cell for runone",
    )
    parser.add_argument(
        "--post",
        type=str,
        dest="post",
        default="None",
        help="Select postsynaptic cell for runone",
    )
    parser.add_argument(
        "-f",
        "--fit",
        action="store_true",
        dest="fit",
        help="Fit exp2syn waveforms to multisite data",
    )
    parser.add_argument(
        "-t",
        "--test",
        action="store_true",
        dest="test",
        help="Run the test on exp2syn fitting",
    )
    parser.add_argument(
        "-p",
        "--plot",
        action="store_true",
        dest="plot",
        help="Plot the current comparison between simple and multisite",
    )
    parser.add_argument(
        "-l",
        "--list",
        action="store_true",
        dest="list",
        help="List the simple fit parameters",
    )
    args = parser.parse_args()
    # sys.exit rather than the site-provided exit(), which is not always defined.
    if args.test:
        testexp()
        sys.exit()
    if args.run:
        run_all()
        sys.exit()
    if args.runone:
        if args.pre == "None" or args.post == "None":
            parser.error("--runone requires both --pre and --post")
        run_one(args.pre, args.post)
        sys.exit()
    if args.fit:
        fit_all()  # self contained - always fits exp2syn against the current multisite data
        ds = read_pickle("simple")
        dm = read_pickle("multisite")
        plot_all(dm, ds)
        sys.exit()
    if args.plot:
        if args.mode in ["simple", "multisite"]:
            d = read_pickle(args.mode)
            plot_all(d)
            sys.exit()
        elif args.mode == "both":
            ds = read_pickle("simple")
            dm = read_pickle("multisite")
            plot_all(dm, ds)
            sys.exit()
        else:
            print(f"Mode {args.mode:s} is not valid")
            sys.exit(1)
    if args.list:
        ds = read_pickle("simple")
        print_exp2syn_fits(ds)


if __name__ == "__main__":
    main()