""" Test synaptic connections between two different cell types. Usage: python compare_simple_multisynapses.py This script: 1. Creates single pre- and postsynaptic cells 2. Creates a single synaptic terminal between the two cells, using the multisite synapse method. 3. Stimulates the presynaptic cell by current injection. 4. Records and analyzes the resulting post-synaptic events. 5. Repeats 3, 4 100 times to get an average postsynaptic event. 6. stores the resulting waveform in a pandas database This is used mainly to check that the strength, kinetics, and dynamics of each synapse type is working as expected. Requires Python 3.6 """ import sys import argparse from pathlib import Path import numpy as np # import pyqtgraph as pg import matplotlib.pyplot as mpl import neuron as h import cnmodel.util.PlotHelpers as PH from cnmodel.protocols import SynapseTest from cnmodel import cells from cnmodel.synapses import Synapse import pickle import lmfit convergence = { "sgc": { "bushy": 1, "tstellate": 1, "dstellate": 1, "octopus": 1, "pyramidal": 1, "tuberculoventral": 1, }, "dstellate": { "bushy": 1, "tstellate": 1, "dstellate": 1, "pyramidal": 1, "tuberculoventral": 1, }, "tuberculoventral": { "bushy": 1, "tstellate": 1, "pyramidal": 1, "tuberculoventral": 1, }, } class Exp2SynFitting: """ Fit waveform against Exp2SYN function (used by Neuron) THe function is (from the documentation): i = G * (v - e) i(nanoamps), g(micromhos); G = weight * factor * (exp(-t/tau2) - exp(-t/tau1)) WHere factor is evaluated when initializing(in neuron; see exp2syn.mod source) as: tp = (tau1*tau2)/(tau2 - tau1) * log(tau2/tau1) (log10) factor = -exp(-tp/tau1) + exp(-tp/tau2) factor = 1/factor Parameters ---------- initpars : dict dict of initial parameters. For example: {'tau1': 0.1, 'tau2': 0.3, 'weight': 0.1, 'delay' : 0.0, 'erev': -80.} (erev in mV) bounds : dict dictionary of bounds for each parameter, with a list of lower and upper values. """ def __init__(self, initpars=None, bounds=None, functype="exp2syn"): self.fitpars = lmfit.Parameters() if functype == "exp2syn": # future handle other functions like alpha # (Name, Value, Vary, Min, Max, Expr) self.fitpars.add_many( ("tau1", initpars["tau1"], True, 0.05, 25.0, None), # ('tau2', initpars['tau2'], True, 0.1, 50., None), ("tauratio", initpars["tauratio"], True, 1.0001, 100.0, None), ("weight", initpars["weight"], True, 1e-6, 1, None), ("erev", initpars["erev"], False), # do not adjust! ("v", initpars["v"], False), ("delay", initpars["delay"], True, 0.0, 5.0, None), ) self.func = self.exp2syn_err else: raise ValueError def fit(self, x, y, p, verbose=False): """ Perform the curve fit against the specified function Parameters ---------- x : time base for waveform (np array or list) y : waveform (1-d np array or list) p : parameters in lmfit Parameters structure verbose : boolean (default: False) If true, print the parameters in a nice format """ kws = {"maxfev": 5000} # print('p: ', p) self.mim = lmfit.minimize( self.func, p, method="least_squares", args=(x, y) ) # , kws=kws) if verbose: lmfit.printfuncs.report_fit(self.mim.params) fitpars = self.mim.params return fitpars # @staticmethod def exp2syn(self, x, tau1, tauratio, weight, erev, v, delay): """ Compute the exp2syn waveform current as it is done in Neuron Note on units: The units are assumed to be "self consistent" Thus, if the time base is in msec, tau1, tau2 and the delay are in msec. erev and v should be in matching units (e.g., mV) Note thtat the weight is not equal to the conducdtance G, because of the scaling of the waveform Parameters ---------- x : time base tau1 : rising tau for waveform tau2 : falling tau for waveform weight : amplitude of the waveform (conductance) erev : reversal potential for ionic species (used to compute i) v : holding voltage at which waveform is computed delay : delay to start of function Returns ------- i, the calucated current trace for these parameters. """ # we handle the requirement that tau2 > tau1 by setting the # expression in lmfit, and using tauratio rather than a direct # tau2. # if tau1/tau2 > 1.0: # make sure tau1 is less than tau2 # tau1 = 0.999*tau2 tau2 = tauratio * tau1 tp = (tau1 * tau2) / (tau2 - tau1) * np.log(tau2 / tau1) factor = -np.exp(-tp / tau1) + np.exp(-tp / tau2) factor = 1.0 / factor G = ( weight * factor * (np.exp(-(x - delay) / tau2) - np.exp(-(x - delay) / tau1)) ) G[x - delay < 0] = 0.0 i = G * (v - erev) # i(nanoamps), g(micromhos); return i def exp2syn_err(self, p, x, y): return np.fabs( y - self.exp2syn(x, **dict([(k, p.value) for k, p in p.items()])) ) def factor(self, tau1, tauratio, weight): """ calculate tau-scaled weight """ tau2 = tau1 * tauratio tp = (tau1 * tau2) / (tau2 - tau1) * np.log(tau2 / tau1) factor = -np.exp(-tp / tau1) + np.exp(-tp / tau2) factor = 1.0 / factor G = weight * factor return G def testexp(): """ Test the exp2syn fitting function """ pars = { "tau1": 0.1, "tauratio": 2.0, "weight": 0.1, "erev": -70.0, "v": -65.0, "delay": 1, } F = Exp2SynFitting( initpars={ "tau1": 0.2, "tauratio": 2.0, "weight": 0.1, "erev": -70.0, "v": -65.0, "delay": 0, } ) t = np.arange(0, 10.0, 0.01) p = F.fitpars target = F.exp2syn( t, pars["tau1"], pars["tauratio"], pars["weight"], pars["erev"], pars["v"], pars["delay"], ) # print(F.fitpars) pars_fit = F.fit(t, target, F.fitpars) print("\nTest fit result: ") lmfit.printfuncs.report_fit(F.mim.params) print("( tau2 = ", pars_fit["tau1"].value * pars_fit["tauratio"].value, ")") print("target parameters: ", pars) print("\n") def compute_psc(synapsetype="multisite", celltypes=["sgc", "tstellate"]): """ Compute the PSC between the specified two cell tpes The type of PSC is set by synspase type and must be 'multisite' or 'simple' """ assert synapsetype in ["multisite", "simple"] c = [] for cellType in celltypes: if cellType == "sgc": cell = cells.SGC.create() elif cellType == "tstellate": cell = cells.TStellate.create(debug=True, ttx=False) elif ( cellType == "dstellate" ): # Type I-II Rothman model, similiar excitability (Xie/Manis, unpublished) cell = cells.DStellate.create(model="RM03", debug=True, ttx=False) # elif cellType == 'dstellate_eager': # From Eager et al. # cell = cells.DStellate.create(model='Eager', debug=True, ttx=False) elif cellType == "bushy": cell = cells.Bushy.create(debug=True, ttx=True) elif cellType == "octopus": cell = cells.Octopus.create(debug=True, ttx=True) elif cellType == "tuberculoventral": cell = cells.Tuberculoventral.create(debug=True, ttx=True) elif cellType == "pyramidal": cell = cells.Pyramidal.create(debug=True, ttx=True) elif cellType == "cartwheel": cell = cells.Cartwheel.create(debug=True, ttx=True) else: raise ValueError("Unknown cell type '%s'" % cellType) c.append(cell) preCell, postCell = c print(f"Computing psc for connection {celltypes[0]:s} -> {celltypes[1]:s}") nTerminals = convergence.get(celltypes[0], {}).get(celltypes[1], None) if nTerminals is None: nTerminals = 1 print( f"Warning: Unknown convergence for {celltypes[0]:s} -> {celltypes[1]:s}, ASSUMING {nTerminals:d} terminals" ) if celltypes == ["sgc", "bushy"]: niter = 50 else: niter = 200 assert synapsetype in ["simple", "multisite"] if synapsetype == "simple": niter = 1 st = SynapseTest() dt = 0.010 stim = { "NP": 1, "Sfreq": 100.0, "delay": 0.0, "dur": 0.5, "amp": 10.0, "PT": 0.0, "dt": dt, } st.run( preCell.soma, postCell.soma, nTerminals, dt=dt, vclamp=-65.0, iterations=niter, synapsetype=synapsetype, tstop=50.0, stim_params=stim, ) # st.show_result() # pyqtgraph plotting - # st.plots['VPre'].setYRange(-70., 10.) # st.plots['EPSC'].setYRange(-2.0, 0.5) # st.plots['latency2080'].setYRange(0., 1.0) # st.plots['halfwidth'].setYRange(0., 1.0) # st.plots['RT'].setYRange(0., 0.2) # st.plots['latency'].setYRange(0., 1.0) # st.plots['latency_distribution'].setYRange(0., 1.0) return st # need to keep st alive in memory def fit_one(st, stk): """ Fit one trace to the exp2syn function Parameters ---------- st : dict (no default) A dictionary containing (at least) the following keys: 't' : the time base for the trace to be fit 'i' : a list of arrays or a 2d numpy array of the data to be fit Returns ------- pars : The fitting parameters The fitted parameters are returned as an lmfit Parameters object, so access individual parameters a pars['parname'].value fitted : The fitted waveform, of the same length as the source data waveform """ if stk[0] in ["sgc", "tstellate", "granule"]: erev = ( 0.0 ) # set erev according to the source cell (excitatory: 0, inhibitory: -80) else: erev = -70.0 # value used in gly_psd print("\nstk, erev: ", stk, erev) F = Exp2SynFitting( initpars={ "tau1": 1.0, "tauratio": 5.0, "weight": 0.0001, "erev": erev, "v": -65.0, "delay": 0, } ) t = st["t"] p = F.fitpars target = np.mean(np.array(st["i"]), axis=0) # print(F.fitpars) pars = F.fit(t, target, F.fitpars) gw = F.factor(pars["tau1"].value, pars["tauratio"].value, pars["weight"].value) pars.add("GWeight", value=gw, vary=False) fitted = F.exp2syn( st["t"], pars["tau1"], pars["tauratio"], pars["weight"], pars["erev"], pars["v"], pars["delay"], ) lmfit.printfuncs.report_fit(F.mim.params) return (pars, fitted) def fit_all(): """ Fit exp2syn against the traces in stm This fits all pre-post cell pairs, and returns the fit """ stm = read_pickle("multisite.pkl") fits = {} fitted = {} # print('stm keys: ', stm.keys()) # exit() for i, stk in enumerate(stm.keys()): # for each pre-post cell pair fitp, fit = fit_one(stm[stk], stk) fits[stk] = fitp fitted[stk] = {"t": stm[stk]["t"], "i": [fit], "pars": fitp} with (open("simple.pkl", "wb")) as fh: pickle.dump(fitted, fh) return fits def plot_all(stm, sts=None): P = PH.Plotter((3, 5), figsize=(11, 6)) ax = P.axarr.ravel() keypairorder = [] for i, stk in enumerate(stm.keys()): keypairorder.append(stk) data = stm[stk] idat = np.array(data["i"]) ax[i].plot(np.array(data["t"]), np.mean(idat, axis=0), "c-", linewidth=1.5) sd = np.std(idat, axis=0) ax[i].plot( np.array(data["t"]), np.mean(idat, axis=0) + sd, "c--", linewidth=0.5 ) # for j in range(idat.shape[0]): ax[i].plot( np.array(data["t"]), np.mean(idat, axis=0) - sd, "c--", linewidth=0.5 ) # for j in range(idat.shape[0]): rel = 0 for j in range(idat.shape[0]): if np.min(idat[j, :]) < -1e-2 or np.max(idat[j, :]) > 1e-2: rel += 1 print(f"{str(stk):s} {rel:d} of {idat.shape[0]:d} {str(idat.shape):s}") # ax[i].plot(data['t'], idat[j], 'k-', linewidth=0.5, alpha=0.25) ax[i].set_title("%s : %s" % (stk[0], stk[1]), fontsize=7) if sts is not None: # plot the matching exp2syn on this for i, stks in enumerate(keypairorder): datas = sts[stks] isdat = np.array(datas["i"]) ax[i].plot(datas["t"], np.mean(isdat, axis=0), "m-", linewidth=3, alpha=0.5) mpl.show() def print_exp2syn_fits(st): stkeys = list(st.keys()) postcells = [ "bushy", "tstellate", "dstellate", "octopus", "pyramidal", "tuberculoventral", "cartwheel", ] fmts = { "weight": "{0:<18.6f}", "GWeight": "{0:<18.6f}", "tau1": "{0:<18.3f}", "tau2": "{0:<18.3f}", "delay": "{0:<18.3f}", "erev": "{0:<18.1f}", } pars = list(fmts.keys()) for pre in ["sgc", "dstellate", "tuberculoventral"]: firstline = False vrow = dict.fromkeys(pars, False) for v in pars: for post in postcells: if (pre, post) not in stkeys: print("{0:<18s}".format("0"), end="") continue if not firstline: print("\n%s" % pre.upper()) print("{0:<18s}".format(" "), end="") for i in range(len(postcells)): print("{0:<18s}".format(postcells[i]), end="") print() firstline = True if firstline: if not vrow[v]: if v == "tauratio": print("{0:<18s}".format("tau2"), end="") else: print("{0:<18s}".format(v), end="") vrow[v] = True # print(fits[(pre, post)].keys()) fits = st[(pre, post)]["pars"] # print(v, fits) if v in ["tauratio", "tau2"]: print( fmts[v].format(fits["tau1"].value * fits["tauratio"].value), end="", ) else: print(fmts[v].format(fits[v].value), end="") print() def read_pickle(mode): """ Read the pickled file generated by runall Parameters ---------- mode : str either 'simple' or 'multisite' Returns: the resulting data, which is a dictionary """ with (open(Path(mode).with_suffix(".pkl"), "rb")) as fh: st = pickle.load(fh) return st def run_all(): """ Run all of the multisite synapse calculations for each cell pair Save the results in the multisite.pkl file """ st = {} pkst = {} for pre in convergence.keys(): for post in convergence[pre]: # if pre == 'sgc' and post in ['bushy', 'tstellate']: sti = compute_psc(synapsetype="multisite", celltypes=[pre, post]) st[(pre, post)] = sti # remove neuron objects before pickling pkst[(pre, post)] = { "t": sti["t"], "i": sti.isoma, "v": sti["v_soma"], "pre": sti["v_pre"], } with (open(Path("multisite").with_suffix(".pkl"), "wb")) as fh: pickle.dump(pkst, fh) plot_all(pkst) def run_one(pre, post): """ Run all of the multisite synapse calculations for each cell pair Save the results in the multisite.pkl file """ assert pre in ["sgc", "tstellate", "dstellate", "tuberculoventral"] assert post in ["bushy", "tstellate", "dstellate", "tuberculoventral", "pyramidal"] st = {} pkst = read_pickle("multisite") # if pre == 'sgc' and post in ['bushy', 'tstellate']: sti = compute_psc(synapsetype="multisite", celltypes=[pre, post]) st[(pre, post)] = sti # remove neuron objects before pickling pkst[(pre, post)] = { "t": sti["t"], "i": sti.isoma, "v": sti["v_soma"], "pre": sti["v_pre"], } with (open(Path("multisite").with_suffix(".pkl"), "wb")) as fh: pickle.dump(pkst, fh) plot_all(pkst) def main(): parser = argparse.ArgumentParser( description="Compare simple and multisite synapses" ) parser.add_argument( "-m", "--mode", type=str, dest="mode", default="None", choices=["simple", "multisite", "both"], help="Select mode [simple, multisite, compare]", ) parser.add_argument( "-r", "--run", action="store_true", dest="run", help="Run multisite models between all cell types", ) parser.add_argument( "-o", "--runone", action="store_true", dest="runone", help="Run multisite models between specified cell types", ) parser.add_argument( "--pre", type=str, dest="pre", default="None", help="Select presynaptic cell for runone", ) parser.add_argument( "--post", type=str, dest="post", default="None", help="Select postsynaptic cell for runone", ) parser.add_argument( "-f", "--fit", action="store_true", dest="fit", help="Fit exp2syn waveforms to multisite data", ) parser.add_argument( "-t", "--test", action="store_true", dest="test", help="Run the test on exp2syn fitting", ) parser.add_argument( "-p", "--plot", action="store_true", dest="plot", help="Plot the current comparison between simple and multisite", ) parser.add_argument( "-l", "--list", action="store_true", dest="list", help="List the simple fit parameters", ) args = parser.parse_args() if args.test: testexp() exit() if args.run: run_all() exit() if args.runone: run_one(args.pre, args.post) exit() if args.fit: fit_all() # self contained - always fits exp2syn against the current multistie data ds = read_pickle("simple") dm = read_pickle("multisite") plot_all(dm, ds) exit() if args.plot: if args.mode in ["simple", "multisite"]: d = read_pickle(args.mode) plot_all(d) exit() elif args.mode in ["both"]: ds = read_pickle("simple") dm = read_pickle("multisite") plot_all(dm, ds) exit() else: print(f"Mode {args.mode:s} is not valid") exit() if args.list: ds = read_pickle("simple") print_exp2syn_fits(ds) # if sys.flags.interactive == 0: # pg.QtGui.QApplication.exec_() if __name__ == "__main__": main()