""" test_bushy_variation.py Test inputs to bushy cells as we co-vary the KLT and IH conductances Two tests: IV and spikes. Simulation results are first written to disk (.p files); plotting is done separately Usage: python test_bushy_variation.py [a, b] a runs IV curvers with variations of gKlt/gh b runs PSTHs to AN input (CF tones) across the same variations. """ import sys import numpy as np import pyqtgraph as pg import pyqtgraph.multiprocess as mp import pickle import time from neuron import h from cnmodel.protocols import Protocol from cnmodel.protocols import iv_curve from cnmodel import cells from cnmodel.util import sound from cnmodel.util import custom_init from cnmodel.util import make_pulse import cnmodel.util.pynrnutilities as PU from cnmodel import data import matplotlib.pyplot as mpl import cnmodel.util.PlotHelpers as PH import timeit synapseType = "multisite" # 'simple' species = "mouse" # tables for other species do not yet exist class RunTrial: def __init__(self, post_cell, info): """ info is a dict """ pre_cells = [] synapses = [] j = 0 xmtr = {} for nsgc, sgc in enumerate(range(info["n_sic0"])): pre_cells.append(cells.DummySGC(cf=info["cf"], sr=info["sr"])) if synapseType == "simple": synapses.append(pre_cells[-1].connect(post_cell, type=synapseType)) synapses[-1].terminal.netcon.weight[0] = info["gmax"] * 2.0 elif synapseType == "multisite": synapses.append( pre_cells[-1].connect( post_cell, post_opts={"AMPAScale": 2.0, "NMDAScale": 2.0}, type=synapseType, ) ) for i in range(synapses[-1].terminal.n_rzones): xmtr["xmtr%04d" % j] = h.Vector() xmtr["xmtr%04d" % j].record( synapses[-1].terminal.relsite._ref_XMTR[i] ) j = j + 1 synapses[ -1 ].terminal.relsite.Dep_Flag = ( False ) # no depression in these simulations pre_cells[-1].set_sound_stim( info["stim"], seed=info["seed"] + nsgc, simulator=str(info["simulator"]) ) Vm = h.Vector() Vm.record(post_cell.soma(0.5)._ref_v) rtime = h.Vector() rtime.record(h._ref_t) h.tstop = 1e3 * info["run_duration"] # 
duration of a run h.celsius = info["temp"] h.dt = info["dt"] post_cell.cell_initialize() info["init"]() h.t = 0.0 h.run() return { "time": np.array(rtime), "vm": Vm.to_python(), "xmtr": xmtr, "pre_cells": pre_cells, "post_cell": post_cell, "synapses": synapses, } class SGCInputTestPSTH(Protocol): def set_cell(self, cell="bushy"): self.cell = cell self.parallelize = False def run( self, temp=34.0, dt=0.025, seed=575982035, reps=10, stimulus="tone", simulator="cochlea", ): assert stimulus in ["tone", "SAM", "clicks"] # cases available assert self.cell in ["bushy", "tstellate", "octopus", "dstellate"] self.nrep = reps self.stimulus = stimulus self.run_duration = 0.20 # in seconds self.pip_duration = 0.05 # in seconds self.pip_start = [0.1] # in seconds self.Fs = 100e3 # in Hz self.f0 = 4000.0 # stimulus in Hz self.cf = 4000.0 # SGCs in Hz self.fMod = 100.0 # mod freq, Hz self.dMod = 0.0 # % mod depth, Hz self.dbspl = 50.0 self.simulator = str(simulator) self.sr = 1 # set SR group if self.stimulus == "SAM": self.stim = sound.SAMTone( rate=self.Fs, duration=self.run_duration, f0=self.f0, fmod=self.fMod, dmod=self.dMod, dbspl=self.dbspl, ramp_duration=2.5e-3, pip_duration=self.pip_duration, pip_start=self.pip_start, ) if self.stimulus == "tone": self.f0 = 4000.0 self.cf = 4000.0 self.stim = sound.TonePip( rate=self.Fs, duration=self.run_duration, f0=self.f0, dbspl=self.dbspl, ramp_duration=2.5e-3, pip_duration=self.pip_duration, pip_start=self.pip_start, ) if self.stimulus == "clicks": self.click_rate = 0.020 # msec self.stim = sound.ClickTrain( rate=self.Fs, duration=self.run_duration, f0=self.f0, dbspl=self.dbspl, click_start=0.010, click_duration=100.0e-6, click_interval=self.click_rate, nclicks=int((self.run_duration - 0.01) / self.click_rate), ramp_duration=2.5e-3, ) n_sgc = data.get( "convergence", species=species, post_type=self.cell, pre_type="sgc" )[0] self.n_sgc = int(np.round(n_sgc)) # for simple synapses, need this value: self.AMPA_gmax = ( data.get( 
"sgc_synapse", species=species, post_type=self.cell, field="AMPA_gmax" )[0] / 1e3 ) # convert nS to uS for NEURON self.vms = [None for n in range(self.nrep)] self.synapses = [None for n in range(self.nrep)] self.xmtrs = [None for n in range(self.nrep)] self.pre_cells = [None for n in range(self.nrep)] self.time = [None for n in range(self.nrep)] info = { "n_sgc": self.n_sgc, "gmax": self.AMPA_gmax, "stim": self.stim, "simulator": self.simulator, "cf": self.cf, "sr": self.sr, "seed": seed, "run_duration": self.run_duration, "temp": temp, "dt": dt, "init": custom_init, } if not self.parallelize: for nr in range(self.nrep): info["seed"] = seed + 3 * self.n_sgc * nr res = RunTrial(self.cell, info) # res contains: {'time': time, 'vm': Vm, 'xmtr': xmtr, 'pre_cells': pre_cells, 'post_cell': post_cell} self.pre_cells[nr] = res["pre_cells"] self.time[nr] = res["time"] self.xmtr = {k: v.to_python() for k, v in list(res["xmtr"].items())} self.vms[nr] = res["vm"] self.synapses[nr] = res["synapses"] self.xmtrs[nr] = self.xmtr if self.parallelize: ### Use parallelize with multiple workers tasks = list(range(len(self.nrep))) results3 = results[:] start = time.time() # with mp.Parallelize(enumerate(tasks), results=results, progressDialog='processing in parallel..') as tasker: with mp.Parallelize(enumerate(tasks), results=results) as tasker: for i, x in tasker: tot = 0 for j in range(size): tot += j * x tasker.results[i] = tot print( ( "\nParallel time, %d workers: %0.2f" % (mp.Parallelize.suggestedWorkerCount(), time.time() - start) ) ) print(("Results match serial: %s" % str(results3 == results))) def show(self): self.win = pg.GraphicsWindow() self.win.setBackground("w") Fs = self.Fs p1 = self.win.addPlot( title="Stimulus", row=0, col=0, labels={"bottom": "T (ms)", "left": "V"} ) p1.plot(self.stim.time * 1000, self.stim.sound, pen=pg.mkPen("k", width=0.75)) p1.setXLink(p1) p2 = self.win.addPlot( title="AN spikes", row=1, col=0, labels={"bottom": "T (ms)", "left": "AN spikes 
(first trial)"}, ) for nr in range(self.nrep): xan = [] yan = [] for k in range(len(self.pre_cells[nr])): r = self.pre_cells[nr][k]._spiketrain xan.extend(r) yr = k + np.zeros_like(r) + 0.2 yan.extend(yr) c = pg.PlotCurveItem() xp = np.repeat(np.array(xan), 2) yp = np.repeat(np.array(yan), 2) yp[1::2] = yp[::2] + 0.6 c.setData( xp.flatten(), yp.flatten(), connect="pairs", pen=pg.mkPen(pg.intColor(nr, self.nrep), hues=self.nrep, width=1.0), ) p2.addItem(c) p2.setXLink(p1) p3 = self.win.addPlot( title="%s Spikes" % self.cell, row=2, col=0, labels={"bottom": "T (ms)", "left": "Trial #"}, ) xcn = [] ycn = [] xspks = [] for k in range(self.nrep): bspk = PU.findspikes(self.time[k], self.vms[k], -35.0) xcn.extend(bspk) yr = k + np.zeros_like(bspk) + 0.2 ycn.extend(yr) d = pg.PlotCurveItem() xp = np.repeat(np.array(xcn), 2) yp = np.repeat(np.array(ycn), 2) yp[1::2] = yp[::2] + 0.6 d.setData( xp.flatten(), yp.flatten(), connect="pairs", pen=pg.mkPen("k", width=1.5) ) p3.addItem(d) p3.setXLink(p1) p4 = self.win.addPlot( title="%s Vm" % self.cell, row=3, col=0, labels={"bottom": "T (ms)", "left": "Vm (mV)"}, ) for nr in range(self.nrep): p4.plot( self.time[nr], self.vms[nr], pen=pg.mkPen(pg.intColor(nr, self.nrep), hues=self.nrep, width=1.0), ) p4.setXLink(p1) p5 = self.win.addPlot( title="xmtr", row=0, col=1, labels={"bottom": "T (ms)", "left": "gSyn"} ) if synapseType == "multisite": for nr in [0]: syn = self.synapses[nr] j = 0 for k in range(self.n_sgc): synapse = syn[k] for i in range(synapse.terminal.n_rzones): p5.plot( self.time[nr], self.xmtrs[nr]["xmtr%04d" % j], pen=pg.mkPen( pg.intColor(nr, self.nrep), hues=self.nrep, width=1.0 ), ) j = j + 1 p5.setXLink(p1) p6 = self.win.addPlot( title="AN PSTH", row=1, col=1, labels={"bottom": "T (ms)", "left": "Sp/ms/trial"}, ) bins = np.arange(0, 200, 1) (hist, binedges) = np.histogram(xan, bins) curve6 = p6.plot( binedges, hist, stepMode=True, fillBrush=(0, 0, 0, 255), brush=pg.mkBrush("k"), fillLevel=0, ) p7 = 
self.win.addPlot( title="%s PSTH" % self.cell, row=2, col=1, labels={"bottom": "T (ms)", "left": "Sp/ms/trial"}, ) bins = np.arange(0, 200, 1) (hist, binedges) = np.histogram(xcn, bins) curve7 = p7.plot( binedges, hist, stepMode=True, fillBrush=(0, 0, 0, 255), brush=pg.mkBrush("k"), fillLevel=0, ) self.win.show() class Variations(Protocol): def __init__(self, runtype, runname, simulator): self.runtype = runtype self.runname = runname self.simulator = str(simulator) self.npost = 5 # number of post cells to test self.npre = 3 # number of presynaptic cells self.reset() def reset(self): super(Variations, self).reset() def make_cells(self, cf=16e3, temp=34.0, dt=0.025): self.pre_cells = [] for n in range(self.npre): self.pre_cells.append(cells.DummySGC(cf=cf, sr=3)) self.post_cells = [] for n in range(self.npost): self.post_cell = cells.Bushy.create(species=species) self.post_cells.append(self.post_cell) for n in range(self.npre): for m in range(self.npost): synapse = self.pre_cells[n].connect(self.post_cells[m]) self.synapse = synapse synapse.terminal.relsite.Dep_Flag = False # make variations in the postsynaptic cells varsg = [0.5, 0.75, 1.0, 1.5, 2.0] for i, m in enumerate(range(self.npost)): refgbar_klt = self.post_cells[m].soma().klt.gbar refgbar_ih = self.post_cells[m].soma().ihvcn.gbar self.post_cells[m].soma().klt.gbar = refgbar_klt * varsg[i] self.post_cells[m].soma().ihvcn.gbar = refgbar_ih * varsg[i] # self.stim = sound.TonePip(rate=100e3, duration=0.1, f0=4000, dbspl=80, # ramp_duration=2.5e-3, pip_duration=0.04, # pip_start=[0.02]) # # preCell.set_sound_stim(self.stim, seed=seed) # # self['vm'] = postCell.soma(0.5)._ref_v # #self['prevm'] = preCell.soma(0.5)._ref_v # for i in range(30): # self['xmtr%d'%i] = synapse.terminal.relsite._ref_XMTR[i] # synapse.terminal.relsite.Dep_Flag = False def make_stimulus( self, stimulus="tone", cf=16000.0, f0=16000.0, simulator=None, rundur=0.2, pipdur=0.05, dbspl=50.0, fmod=100.0, dmod=0.0, ): self.stimulus = stimulus 
self.run_duration = rundur # in seconds self.pip_duration = pipdur # in seconds self.pip_start = [0.1] # in seconds self.Fs = 100e3 # in Hz self.f0 = f0 # stimulus in Hz self.cf = cf # SGCs in Hz self.fMod = fmod # mod freq, Hz self.dMod = dmod # % mod depth, Hz self.dbspl = dbspl # self.simulator = str(simulator) self.sr = 1 # set SR group if self.stimulus == "SAM": self.stim = sound.SAMTone( rate=self.Fs, duration=self.run_duration, f0=self.f0, fmod=self.fMod, dmod=self.dMod, dbspl=self.dbspl, ramp_duration=2.5e-3, pip_duration=self.pip_duration, pip_start=self.pip_start, ) if self.stimulus == "tone": self.stim = sound.TonePip( rate=self.Fs, duration=self.run_duration, f0=self.f0, dbspl=self.dbspl, ramp_duration=2.5e-3, pip_duration=self.pip_duration, pip_start=self.pip_start, ) if self.stimulus == "clicks": self.click_rate = 0.020 # msec self.stim = sound.ClickTrain( rate=self.Fs, duration=self.run_duration, f0=self.f0, dbspl=self.dbspl, click_start=0.010, click_duration=100.0e-6, click_interval=self.click_rate, nclicks=int((self.run_duration - 0.01) / self.click_rate), ramp_duration=2.5e-3, ) def run(self, mode="IV", cf=16e3, temp=34.0, dt=0.025, stimamp=0, iinj=[1.0]): self.dt = dt self.temp = temp self.make_cells(cf, temp, dt) print(dir(self.pre_cells[0])) seed = 0 j = 0 if mode == "sound": self.make_stimulus(stimulus="tone") for np in range(len(self.pre_cells)): self.pre_cells[np].set_sound_stim(self.stim, seed=seed) seed += 1 synapses.append( pre_cells[-1].connect( post_cell, post_opts={"AMPAScale": 2.0, "NMDAScale": 2.0}, type=synapseType, ) ) for i in range(synapses[-1].terminal.n_rzones): xmtr["xmtr%04d" % j] = h.Vector() xmtr["xmtr%04d" % j].record( synapses[-1].terminal.relsite._ref_XMTR[i] ) j = j + 1 synapses[ -1 ].terminal.relsite.Dep_Flag = ( False ) # no depression in these simulations print("setup to run") self.stim_params = [] self.istim = [] self.istims = [] if mode == "pulses": for i, pre_cell in enumerate(self.pre_cells): stim = {} stim["NP"] 
= 10 stim["Sfreq"] = 100.0 # stimulus frequency stim["delay"] = 10.0 stim["dur"] = 0.5 stim["amp"] = stimamp stim["PT"] = 0.0 stim["dt"] = dt (secmd, maxt, tstims) = make_pulse(stim) self.stim_params.append(stim) stim["amp"] = 0.0 (secmd2, maxt2, tstims2) = make_pulse(stim) # istim current pulse train istim = h.iStim(0.5, sec=pre_cell.soma) i_stim_vec = h.Vector(secmd) self.istim.append((istim, i_stim_vec)) self["istim%02d" % i] = istim._ref_i self.postpars = [] self.poststims = [] if mode == "IV": for i, post_cell in enumerate(self.post_cells): pstim = {} pstim["NP"] = 1 pstim["Sfreq"] = 100.0 # stimulus frequency pstim["delay"] = 10.0 pstim["dur"] = 100 pstim["amp"] = iinj[0] pstim["PT"] = 0.0 pstim["dt"] = dt (secmd, maxt, tstims) = make_pulse(pstim) self.postpars.append(pstim) # istim current pulse train pistim = h.iStim(0.5, sec=post_cell.soma) pi_stim_vec = h.Vector(secmd) self.poststims.append((pistim, pi_stim_vec)) self["poststim%02d" % i] = pistim._ref_i # # Run simulation # h.celsius = temp self.temp = temp # first find rmp for each cell for m in range(self.npost): self.post_cells[m].vm0 = None self.post_cells[m].cell_initialize() # set starting voltage... self.post_cells[m].soma(0.5).v = self.post_cells[m].vm0 h.dt = 0.02 h.t = 0 # run a bit to find true stable rmp h.tstop = 20.0 h.batch_save() h.batch_run(h.tstop, h.dt, "v.dat") # order matters: don't set these up until we need to self["t"] = h._ref_t # set up recordings if mode == "pulses": for i in range(self.npre): self.istim[i][1].play(self.istim[i][0]._ref_i, dt, 0) self["v_pre%02d" % i] = self.pre_cells[i].soma(0.5)._ref_v for m in range(self.npost): if mode == "IV": self.poststims[m][1].play(self.poststims[m][0]._ref_i, dt, 0) self["v_post%02d" % m] = self.post_cells[m].soma(0.5)._ref_v h.finitialize() # init and instantiate recordings print("running") h.t = 0.0 h.tstop = 200.0 h.batch_run(h.tstop, h.dt, "v.dat") # while h.t < h.tstop: # get data (do not use h.run() - try it and see why!) 
# h.fadvance() def runIV(self, parallelize): self.civ = {} self.iiv = [] varsg = np.linspace( 0.25, 2.0, int((2.0 - 0.25) / 0.25) + 1 ) # [0.5, 0.75, 1.0, 1.5, 2.0] # covary Ih and gklt in constant ratio self.gklts = np.zeros(len(varsg)) self.ghs = np.zeros(len(varsg)) if not parallelize: for n in range(self.npost): self.civ[n] = [] # self.iiv[c] = [] start = time.time() for inj in np.arange(-1.0, 1.51, 0.5): self.run(mode="IV", temp=34.0, dt=0.025, stimamp=10, iinj=[inj]) print("ran for current = ", inj) for c in range(self.npost): self.civ[c].append(self["v_post%02d" % c]) if c == 0: # just the first self.iiv.append(self["poststim%02d" % c]) print(("\nSerial time, %0.2f" % (time.time() - start))) if runname is not None: f = open(runname, "w") pickle.dump({"t": self["t"], "v": self.civ, "i": self.iiv}, f) f.close() else: # mp.parallelizer.multiprocessing.cpu_count() nworker = 16 self.npost = len(varsg) tasks = list(range(self.npost)) results = [None] * len(tasks) ivc = [None] * len(tasks) start = time.time() # with mp.Parallelize(enumerate(tasks), results=results, workers=nworker, progressDialog='processing in parallel..') as tasker: with mp.Parallelize( enumerate(tasks), results=results, workers=nworker ) as tasker: for i, x in tasker: post_cell = cells.Bushy.create(species=species) refgbar_klt = post_cell.soma().klt.gbar refgbar_ih = post_cell.soma().ihvcn.gbar gklts = refgbar_klt * varsg[i] ghs = refgbar_ih * varsg[i] post_cell.soma().klt.gbar = gklts post_cell.soma().ihvcn.gbar = ghs post_cell.initial_mechanisms = ( None ) # forget the mechanisms we set up initially post_cell.save_all_mechs() # and save new ones because we are explicitely varying them ivc[i] = iv_curve.IVCurve() ivc[i].run({"pulse": [(-1.0, 1.5, 0.25)]}, post_cell) tasker.results[i] = { "v": ivc[i].voltage_traces, "i": ivc[i].current_traces, "t": ivc[i].time_values, "gklt": gklts, "gh": ghs, } print( ( "\nParallel time: %d workers, %0.2f sec" % (nworker, time.time() - start) ) ) cell_info = 
{"varrange": varsg} print(cell_info) res = {"cells": cell_info, "results": results} if runname is not None: f = open(runname, "wb") pickle.dump(res, f, -1) f.close() def runSound(self, parallelize=False): self.civ = {} self.iiv = [] if not parallelize: pass if parallelize: nworker = 16 varsg = np.linspace( 0.25, 2.0, int((2.0 - 0.25) / 0.25) + 1 ) # [0.5, 0.75, 1.0, 1.5, 2.0] # covary Ih and gklt in constant ratio self.npost = len(varsg) nrep = 25 tasks = list(range(self.npost)) results = [None] * len(tasks) ivc = [None] * len(tasks) gklts = np.zeros(len(varsg)) ghs = np.zeros(len(varsg)) start = time.time() seed = 0 cf = 16000.0 f0 = 16000.0 rundur = 0.25 # seconds pipdur = 0.1 # seconds dbspl = 50.0 fmod = 40.0 dmod = 0.0 stimulus = "tone" # with mp.Parallelize(enumerate(tasks), results=results, workers=nworker, progressDialog='processing in parallel..') as tasker: with mp.Parallelize( enumerate(tasks), results=results, workers=nworker ) as tasker: for i, x in tasker: post_cell = cells.Bushy.create(species=species) h.celsius = 34 self.temp = h.celsius refgbar_klt = post_cell.soma().klt.gbar refgbar_ih = post_cell.soma().ihvcn.gbar gklts[i] = refgbar_klt * varsg[i] ghs[i] = refgbar_ih * varsg[i] post_cell.soma().klt.gbar = gklts[i] post_cell.soma().ihvcn.gbar = ghs[i] post_cell.initial_mechanisms = ( None ) # forget the mechanisms we set up initially post_cell.save_all_mechs() # and save new ones because we are explicitely varying them self.make_stimulus( stimulus=stimulus, cf=cf, f0=f0, rundur=rundur, pipdur=pipdur, dbspl=50.0, simulator=self.simulator, fmod=fmod, dmod=dmod, ) pre_cells = [] synapses = [] for n in range(self.npre): pre_cells.append(cells.DummySGC(cf=cf, sr=2)) synapses.append( pre_cells[n].connect(post_cell, type=synapseType) ) v_reps = [] i_reps = [] p_reps = [] # pre spike on 0'th sgc for j in range(nrep): for prec in range(len(pre_cells)): pre_cells[prec].set_sound_stim( self.stim, seed=seed, simulator=self.simulator ) seed += 1 # for i in 
range(synapses[-1].terminal.n_rzones): # xmtr['xmtr%04d'%j] = h.Vector() # xmtr['xmtr%04d'%j].record(synapses[-1].terminal.relsite._ref_XMTR[i]) # j = j + 1 # synapses[-1].terminal.relsite.Dep_Flag = False # no depression in these simulations # # Run simulation # post_cell.vm0 = None post_cell.cell_initialize() # set starting voltage... post_cell.soma(0.5).v = post_cell.vm0 h.dt = 0.02 h.t = 0 # run a bit to find true stable rmp h.tstop = 20.0 h.batch_save() h.batch_run(h.tstop, h.dt, "v.dat") self["t"] = h._ref_t # set up recordings self["v_post%02d" % j] = post_cell.soma(0.5)._ref_v h.finitialize() # init and instantiate recordings print("running %d" % i) h.t = 0.0 h.tstop = rundur * 1000.0 # rundur is in seconds. post_cell.check_all_mechs() # make sure no further changes were introduced before run. h.batch_run(h.tstop, h.dt, "v.dat") v_reps.append(self["v_post%02d" % j]) i_reps.append(0.0 * self["v_post%02d" % j]) p_reps.append(pre_cells[0]._stvec.to_python()) tasker.results[i] = { "v": v_reps, "i": i_reps, "t": self["t"], "pre": pre_cells[0]._stvec.to_python(), } print( ( "\nParallel time: %d workers, %0.2f sec" % (nworker, time.time() - start) ) ) cell_info = {"gklt": gklts, "gh": ghs} stim_info = { "nreps": nrep, "cf": cf, "f0": f0, "rundur": rundur, "pipdur": pipdur, "dbspl": dbspl, "fmod": fmod, "dmod": dmod, } res = {"cells": cell_info, "stim": stim_info, "results": results} if runname is not None: f = open(runname, "wb") pickle.dump(res, f, -1) f.close() # # def show(self): # # self.win = pg.GraphicsWindow() # self.win.resize(1000, 1000) # # cmd_plot = self.win.addPlot(title='Stim') # for i in range(len(self.pre_cells)): # cmd_plot.plot(self['t'], self['istim%02d' % i], pen=pg.mkPen(pg.intColor(i, len(self.pre_cells)), hues=len(self.pre_cells), width=1.0)) # # self.win.nextRow() # pre_plot = self.win.addPlot(title='SGC Vm') # for i in range(len(self.pre_cells)): # pre_plot.plot(self['t'], self['v_pre%02d'%i], pen=pg.mkPen(pg.intColor(i, 
def showpicklediv(name):
    """
    Plot IV results saved by ``Variations.runIV(parallelize=True)``.

    The pickle holds {"cells": {"varrange": ...}, "results": [...]}; each
    result has "t", "v" (voltage traces) and "i" (current traces). One row of
    voltage traces per cell, annotated with its scale factor, plus a bottom
    row with the first cell's current traces.

    Note: the file is unpickled, so only open files you trust.
    """
    with open(name, "rb") as f:
        result = pickle.load(f)
    d = result["results"]
    ncells = len(d)
    vr = result["cells"]["varrange"]
    fig, ax = mpl.subplots(ncells + 1, 2, figsize=(8.5, 11.0))
    for itrace in d[0]["i"]:
        ax[-1, 0].plot(d[0]["t"], itrace, "k", linewidth=0.5)
        ax[-1, 0].set_ylim([-2.0, 2.0])
    for nc in range(ncells):
        for ni, vtrace in enumerate(d[nc]["v"]):
            ax[nc, 0].plot(d[nc]["t"], vtrace, "k", linewidth=0.5)
            ax[nc, 0].set_ylim([-180.0, 40.0])
            if ni == 0:
                ax[nc, 0].annotate("%.2f" % vr[nc], (180.0, 20.0))

    PH.nice_plot(ax.ravel().tolist())
    PH.noaxes(ax.ravel()[:-1].tolist())
    PH.calbar(
        ax[0, 0],
        calbar=[120.0, -130.0, 25.0, 50.0],
        axesoff=True,
        orient="left",
        unitNames={"x": "ms", "y": "mV"},
        fontsize=9,
        weight="normal",
        font="Arial",
    )
    PH.calbar(
        ax[-1, 0],
        calbar=[120.0, 0.5, 25.0, 1.0],
        axesoff=True,
        orient="left",
        unitNames={"x": "ms", "y": "nA"},
        fontsize=9,
        weight="normal",
        font="Arial",
    )
    mpl.show()


def vector_plot(f, r, l, yp=None):
    """
    Draw a small polar inset of spokes from the origin.

    ``f`` is a matplotlib figure, ``r`` the angles, ``l`` the lengths, and
    ``yp`` a bounding box (``.x1``/``.y1``) that places the inset; despite the
    default, ``yp`` is required.
    """
    ax2 = f.add_axes([yp.x1 - 0.06, yp.y1 - 0.06, 0.05, 0.05], polar=True)
    r = np.repeat(r, 3)
    l = np.repeat(l, 3)
    # each spoke is (0 -> l -> 0) so the polyline returns to the origin
    for i in range(2, len(l), 3):
        l[i] = 0.0
        l[i - 2] = 0.0
    ax2.plot(r, l, lw=0.5)


def phase_hist(f, spkphase, yp=None):
    """
    Draw a small inset histogram of spike phases (radians, 90 bins on 0..2*pi).

    ``yp`` is a bounding box (``.x1``/``.y1``) that places the inset; despite
    the default, it is required.
    """
    ax2 = f.add_axes([yp.x1 - 0.06, yp.y1 - 0.06, 0.05, 0.05])
    # num must be an int: numpy rejects a float sample count
    n, bins = np.histogram(spkphase, np.linspace(0.0, 2 * np.pi, 91), density=False)
    ax2.bar(bins[:-1], n, width=bins[1], facecolor="k", alpha=0.75)


def clean_spiketimes(spikeTimes, mindT=0.7):
    """
    Remove spikes that follow the preceding spike by ``mindT`` or less.

    spikeTimes is a 1-D list or array of spike times; mindT is the minimum
    interval, in the same units as spikeTimes. The first spike is always kept.
    Each later spike is kept only if it occurs more than ``mindT`` after its
    predecessor in the original train. With 0 or 1 spikes the input is
    returned unchanged; otherwise a numpy array is returned.
    """
    if len(spikeTimes) > 1:
        st = np.asarray(spikeTimes)
        keep = np.concatenate(([True], np.diff(st) > mindT))
        spikeTimes = st[keep]
    return spikeTimes


def showplots(name):
    """
    Show traces from sound stimulation (without current injection), as saved
    by ``Variations.runSound``.

    One row per conductance variant:
      - Left: the third repetition's voltage trace (if present).
      - Right: the PSTH of postsynaptic spikes pooled over repetitions
        (200 bins over the run).
    The bottom row shows the stimulus waveform and a calibration bar. The
    first 50 ms are cut from every plot.

    Note: the file is unpickled, so only open files you trust.
    """
    with open(name, "rb") as f:
        d = pickle.load(f)
    ncells = len(d["results"])
    stiminfo = d["stim"]
    dur = stiminfo["rundur"] * 1000.0
    print("dur: ", dur)
    print("stim info: ")
    print(" fmod: ", stiminfo["fmod"])
    print(" dmod: ", stiminfo["dmod"])
    print(" f0: ", stiminfo["f0"])
    print(" cf: ", stiminfo["cf"])
    # scale factors were not stored in the file; must match runSound()
    varsg = np.linspace(0.25, 2.0, int((2.0 - 0.25) / 0.25) + 1)
    fig, ax = mpl.subplots(ncells + 1, 2, figsize=(8.5, 11.0))
    # one independent list per cell ([[]] * n would alias a single list,
    # leaking every cell's spikes into every other cell's PSTH)
    spikelists = [[] for _ in range(ncells)]
    xmin = 50.0  # ms cut from the start of every plot
    psth_bins = np.linspace(0.0, dur, 201)
    for i in range(ncells):
        vdat = d["results"][i]["v"]
        tdat = d["results"][i]["t"]
        PH.noaxes(ax[i, 0])
        if len(vdat) > 2:
            ax[i, 0].plot(tdat - xmin, vdat[2], "k", linewidth=0.5)
        if len(vdat) > 0:
            ax[i, 0].annotate("%.2f" % varsg[i], (180.0, 20.0))
        ax[i, 0].set_xlim([0, dur - xmin])
        ax[i, 0].set_ylim([-75, 50])
        PH.referenceline(
            ax[i, 0],
            reference=-62.0,
            limits=None,
            color="0.33",
            linestyle="--",
            linewidth=0.5,
            dashes=[3, 3],
        )
        for vtrace in vdat:
            detected = clean_spiketimes(PU.findspikes(tdat, vtrace, -20.0))
            spikelists[i].extend(detected)
        # Postsynaptic PSTH only. The old loop overwrote this histogram with
        # the presynaptic one and summed the two together.
        n, bins = np.histogram(spikelists[i], psth_bins)
        ax[i, 1].bar(bins[:-1] - xmin, n, width=bins[1], facecolor="k", alpha=0.75)
        ax[i, 1].set_xlim([0, dur - xmin])
        ax[i, 1].set_ylim([0, 30])

    # Rebuild the stimulus from the stored parameters to plot its waveform;
    # only make_stimulus() is used, so the runtype argument is informational.
    prot = Variations("plots", name, "cochlea")
    stimulus = "SAM" if stiminfo["dmod"] > 0 else "tone"
    prot.make_stimulus(
        stimulus=stimulus,
        cf=stiminfo["cf"],
        f0=stiminfo["f0"],
        simulator=None,
        rundur=stiminfo["rundur"],
        pipdur=stiminfo["pipdur"],
        dbspl=stiminfo["dbspl"],
        fmod=stiminfo["fmod"],
        dmod=stiminfo["dmod"],
    )
    ax[-1, 1].plot(
        prot.stim.time * 1000.0 - xmin, prot.stim.sound, "k-", linewidth=0.75
    )
    ax[-1, 1].set_xlim([0, (dur - xmin)])

    PH.noaxes(ax[-1, 0])
    ax[-1, 0].set_xlim([0, dur - xmin])
    ax[-1, 0].set_ylim([-75, 50])
    PH.calbar(
        ax[-1, 0],
        calbar=[20.0, 0.0, 25.0, 20.0],
        axesoff=True,
        orient="left",
        unitNames={"x": "ms", "y": "mV"},
        fontsize=9,
        weight="normal",
        font="Arial",
    )
    PH.cleanAxes(ax.ravel().tolist())
    mpl.show()


if __name__ == "__main__":
    runname = None
    panel = None
    if len(sys.argv) == 2:
        panel = sys.argv[1]
    if panel is None:
        raise ValueError("Must specify figure panel to generate: 'a', 'b'")
    if panel == "a":
        runtype = "IV"
        runname = "Figure6_IV"
    elif panel in ("b", "d"):  # 'b' is the documented panel letter
        runtype = "sound"
        runname = "Figure6_AN"
    else:
        runtype = panel
        # replotting needs a file name: use the files written by 'a' and 'b'
        if runtype == "showiv":
            runname = "Figure6_IV"
        elif runtype == "plots":
            runname = "Figure6_AN"

    if runtype in ["sound", "IV"]:
        prot = Variations(runtype, runname, "cochlea")
        if runtype == "IV":
            start_time = timeit.default_timer()
            prot.runIV(parallelize=True)
            elapsed = timeit.default_timer() - start_time
            print("Elapsed time for IV simulations: %f" % (elapsed))
            showpicklediv(runname)
        if runtype == "sound":
            start_time = timeit.default_timer()
            prot.runSound(parallelize=True)
            elapsed = timeit.default_timer() - start_time
            print("Elapsed time for AN simulations: %f" % (elapsed))
            showplots(runname)
    elif runtype in ["showiv"]:
        showpicklediv(runname)
    elif runtype in ["plots"]:
        showplots(runname)
    else:
        print("run type should be one of sound, IV, showiv, plots")