"""
Drive a cochlear-nucleus cell model with auditory-nerve (SGC) input and plot
the resulting spike rasters, membrane potential and PSTHs.

Run as a script; see the argparse options at the bottom of the file.
"""
import sys
import argparse

import numpy as np
import pyqtgraph as pg
from neuron import h

from cnmodel.protocols import Protocol
from cnmodel import cells
from cnmodel.util import sound
from cnmodel.util import custom_init
import cnmodel.util.pynrnutilities as PU
from cnmodel import data

species = "mouse"  # tables for other species do not yet exist

# Synapse model used by runTrial() and SGCInputTestPSTH.show(); either
# "simple" or "multisite". The command-line entry point overwrites this
# value. The default here keeps the module usable when it is imported
# rather than run as a script (previously the name did not exist in that
# case and the first use raised NameError).
synapseType = "simple"


def runTrial(cell, info):
    """
    Build one postsynaptic cell with its SGC inputs and run a single trial.

    Parameters
    ----------
    cell : str
        Name of the postsynaptic cell type. One of "bushy", "tstellate",
        "octopus", "dstellate", "tuberculoventral" or "pyramidal".
    info : dict
        Trial parameters. The keys read here are "n_sgc", "cf", "sr",
        "gmax", "stim", "seed", "simulator", "run_duration", "temp", "dt"
        and "init" (a callable invoked just before the run starts).

    Returns
    -------
    dict
        "time" (numpy array of recorded h.t values), "vm" (list of somatic
        voltages), "xmtr" (dict of h.Vector recordings, one per release
        site, only filled for multisite synapses), "pre_cells",
        "post_cell" and "synapses".

    Raises
    ------
    ValueError
        If ``cell`` is not one of the supported names.
    """
    cell_classes = {
        "bushy": cells.Bushy,
        "tstellate": cells.TStellate,
        "octopus": cells.Octopus,
        "dstellate": cells.DStellate,
        # NOTE(review): the next two entries build DStellate cells, exactly
        # as the original code did. This looks like a placeholder; confirm
        # whether dedicated cell classes should be used instead.
        "tuberculoventral": cells.DStellate,
        "pyramidal": cells.DStellate,
    }
    if cell not in cell_classes:
        # The original formatted this message with `self.cell`, which does
        # not exist in a module-level function and raised NameError.
        raise ValueError("cell %s is not yet implemented for PSTH testing" % cell)
    post_cell = cell_classes[cell].create(species=species)

    pre_cells = []
    synapses = []
    xmtr = {}
    site_index = 0  # running release-site counter across all SGC terminals
    for nsgc in range(info["n_sgc"]):
        pre_cell = cells.DummySGC(cf=info["cf"], sr=info["sr"])
        pre_cells.append(pre_cell)
        if synapseType == "simple":
            synapse = pre_cell.connect(post_cell, type=synapseType)
            synapses.append(synapse)
            synapse.terminal.netcon.weight[0] = info["gmax"]
        elif synapseType == "multisite":
            synapse = pre_cell.connect(
                post_cell,
                post_opts={"AMPAScale": 1.0, "NMDAScale": 1.0},
                type=synapseType,
            )
            synapses.append(synapse)
            # Record the XMTR variable of every release zone of this terminal.
            for i in range(synapse.terminal.n_rzones):
                key = "xmtr%04d" % site_index
                xmtr[key] = h.Vector()
                xmtr[key].record(synapse.terminal.relsite._ref_XMTR[i])
                site_index += 1
            # no depression in these simulations
            synapse.terminal.relsite.Dep_Flag = False
        # Each SGC gets its own seed, offset from the trial's base seed.
        pre_cell.set_sound_stim(
            info["stim"], seed=info["seed"] + nsgc, simulator=info["simulator"]
        )

    Vm = h.Vector()
    Vm.record(post_cell.soma(0.5)._ref_v)
    rtime = h.Vector()
    rtime.record(h._ref_t)

    h.tstop = 1e3 * info["run_duration"]  # duration of a run
    h.celsius = info["temp"]
    h.dt = info["dt"]
    post_cell.cell_initialize()
    info["init"]()
    h.t = 0.0
    h.run()

    return {
        "time": np.array(rtime),
        "vm": Vm.to_python(),
        "xmtr": xmtr,
        "pre_cells": pre_cells,
        "post_cell": post_cell,
        "synapses": synapses,
    }


class SGCInputTestPSTH(Protocol):
    """
    Protocol that runs repeated trials of SGC input onto one cell type
    (via runTrial) and displays rasters, Vm traces and PSTHs.
    """

    def set_cell(self, cell="bushy"):
        """Select the postsynaptic cell type (by name) used by run()."""
        self.cell = cell

    def run(
        self,
        temp=34.0,
        dt=0.025,
        seed=575982035,
        reps=10,
        stimulus="tone",
        simulator="cochlea",
        parallelize=False,
    ):
        """
        Build the stimulus and run ``reps`` trials, storing the results on self.

        Parameters
        ----------
        temp : float
            Value assigned to h.celsius for each trial.
        dt : float
            Value assigned to h.dt for each trial.
        seed : int
            Base random seed. Trial ``nr`` uses ``seed + 3 * n_sgc * nr``,
            and runTrial adds the SGC index to that value.
        reps : int
            Number of trials to run.
        stimulus : str
            One of "tone", "SAM" or "clicks".
        simulator : str
            Passed through to each SGC's set_sound_stim().
        parallelize : bool
            NOTE(review): when True, no trials are run at all (the branch
            is an unimplemented stub), so the result lists stay filled
            with None.

        Results are stored per trial in self.vms, self.time, self.synapses,
        self.xmtrs and self.pre_cells.
        """
        assert stimulus in ["tone", "SAM", "clicks"]  # cases available
        assert self.cell in [
            "bushy",
            "tstellate",
            "octopus",
            "dstellate",
            "tuberculoventral",
            "pyramidal",
        ]
        self.nrep = reps
        self.stimulus = stimulus
        self.run_duration = 0.20  # in seconds
        self.pip_duration = 0.05  # in seconds
        self.pip_start = [0.1]  # in seconds
        self.Fs = 100e3  # in Hz
        self.f0 = 4000.0  # stimulus in Hz
        self.cf = 4000.0  # SGCs in Hz
        self.fMod = 100.0  # mod freq, Hz
        self.dMod = 0.0  # mod depth, %
        self.dbspl = 50.0
        self.simulator = simulator
        self.sr = 2  # set SR group

        if self.stimulus == "SAM":
            self.dMod = 100.0
            self.stim = sound.SAMTone(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                fmod=self.fMod,
                dmod=self.dMod,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        if self.stimulus == "tone":
            self.f0 = 4000.0
            self.cf = 4000.0
            self.stim = sound.TonePip(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        if self.stimulus == "clicks":
            # Click interval, in the same units as run_duration (seconds);
            # it divides run_duration below to get the click count.
            self.click_rate = 0.020
            self.stim = sound.ClickTrain(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                click_start=0.010,
                click_duration=100.0e-6,
                click_interval=self.click_rate,
                nclicks=int((self.run_duration - 0.01) / self.click_rate),
                ramp_duration=2.5e-3,
            )

        n_sgc = data.get(
            "convergence", species=species, post_type=self.cell, pre_type="sgc"
        )[0]
        self.n_sgc = int(np.round(n_sgc))

        # for simple synapses, need this value:
        self.AMPA_gmax = (
            data.get(
                "sgc_synapse", species=species, post_type=self.cell, field="AMPA_gmax"
            )[0]
            / 1e3
        )  # convert nS to uS for NEURON

        # One slot per trial for every recorded quantity.
        self.vms = [None] * self.nrep
        self.synapses = [None] * self.nrep
        self.xmtrs = [None] * self.nrep
        self.pre_cells = [None] * self.nrep
        self.time = [None] * self.nrep

        info = {
            "n_sgc": self.n_sgc,
            "gmax": self.AMPA_gmax,
            "stim": self.stim,
            "simulator": self.simulator,
            "cf": self.cf,
            "sr": self.sr,
            "seed": seed,
            "run_duration": self.run_duration,
            "temp": temp,
            "dt": dt,
            "init": custom_init,
        }

        if not parallelize:
            for nr in range(self.nrep):
                # Space the per-trial seeds so that the per-SGC offsets
                # added inside runTrial do not overlap between trials.
                info["seed"] = seed + 3 * self.n_sgc * nr
                res = runTrial(self.cell, info)
                # res contains: {'time': time, 'vm': Vm, 'xmtr': xmtr,
                # 'pre_cells': pre_cells, 'post_cell': post_cell,
                # 'synapses': synapses}
                self.pre_cells[nr] = res["pre_cells"]
                self.time[nr] = res["time"]
                # Convert the h.Vector recordings to plain Python lists.
                self.xmtr = {k: v.to_python() for k, v in res["xmtr"].items()}
                self.vms[nr] = res["vm"]
                self.synapses[nr] = res["synapses"]
                self.xmtrs[nr] = self.xmtr
        if parallelize:
            # Not implemented: nothing is run in this mode.
            pass

    def show(self):
        """
        Plot the results of run() in a pyqtgraph window: stimulus, AN spike
        raster, postsynaptic spike raster, Vm traces, transmitter traces
        (multisite synapses only) and the AN / postsynaptic PSTHs.
        """
        self.win = pg.GraphicsWindow()
        self.win.setBackground("w")

        # --- stimulus waveform ---
        p1 = self.win.addPlot(
            title="Stimulus", row=0, col=0, labels={"bottom": "T (ms)", "left": "V"}
        )
        p1.plot(self.stim.time * 1000, self.stim.sound, pen=pg.mkPen("k", width=0.75))
        p1.setXLink(p1)

        # --- auditory nerve raster, one colour per trial ---
        p2 = self.win.addPlot(
            title="AN spikes",
            row=1,
            col=0,
            labels={"bottom": "T (ms)", "left": "AN spikes (first trial)"},
        )
        # Initialised here so the PSTH below still works when nrep == 0.
        xan = []
        for nr in range(self.nrep):
            xan = []
            yan = []
            for k, pre_cell in enumerate(self.pre_cells[nr]):
                r = pre_cell._spiketrain
                xan.extend(r)
                yan.extend(k + np.zeros_like(r) + 0.2)
            raster = pg.PlotCurveItem()
            # Each spike becomes a short vertical tick: two points sharing
            # the same x, drawn with connect="pairs".
            xp = np.repeat(np.array(xan), 2)
            yp = np.repeat(np.array(yan), 2)
            yp[1::2] = yp[::2] + 0.6
            raster.setData(
                xp.flatten(),
                yp.flatten(),
                connect="pairs",
                pen=pg.mkPen(pg.intColor(nr, self.nrep), hues=self.nrep, width=1.0),
            )
            p2.addItem(raster)
        p2.setXLink(p1)

        # --- postsynaptic cell raster, one row per trial ---
        p3 = self.win.addPlot(
            title="%s Spikes" % self.cell,
            row=2,
            col=0,
            labels={"bottom": "T (ms)", "left": "Trial #"},
        )
        xcn = []
        ycn = []
        for k in range(self.nrep):
            bspk = PU.findspikes(self.time[k], self.vms[k], -35.0)
            xcn.extend(bspk)
            ycn.extend(k + np.zeros_like(bspk) + 0.2)
        cell_raster = pg.PlotCurveItem()
        xp = np.repeat(np.array(xcn), 2)
        yp = np.repeat(np.array(ycn), 2)
        yp[1::2] = yp[::2] + 0.6
        cell_raster.setData(
            xp.flatten(), yp.flatten(), connect="pairs", pen=pg.mkPen("k", width=1.5)
        )
        p3.addItem(cell_raster)
        p3.setXLink(p1)

        # --- membrane potential, all trials overlaid ---
        p4 = self.win.addPlot(
            title="%s Vm" % self.cell,
            row=3,
            col=0,
            labels={"bottom": "T (ms)", "left": "Vm (mV)"},
        )
        for nr in range(self.nrep):
            p4.plot(
                self.time[nr],
                self.vms[nr],
                pen=pg.mkPen(pg.intColor(nr, self.nrep), hues=self.nrep, width=1.0),
            )
        p4.setXLink(p1)

        # --- transmitter traces (first trial only, multisite synapses) ---
        p5 = self.win.addPlot(
            title="xmtr", row=0, col=1, labels={"bottom": "T (ms)", "left": "gSyn"}
        )
        if synapseType == "multisite":
            for nr in [0]:
                syn = self.synapses[nr]
                j = 0  # must follow the key numbering used in runTrial
                for k in range(self.n_sgc):
                    synapse = syn[k]
                    for i in range(synapse.terminal.n_rzones):
                        p5.plot(
                            self.time[nr],
                            self.xmtrs[nr]["xmtr%04d" % j],
                            pen=pg.mkPen(
                                pg.intColor(nr, self.nrep), hues=self.nrep, width=1.0
                            ),
                        )
                        j = j + 1
        p5.setXLink(p1)

        # --- AN PSTH ---
        # NOTE(review): xan holds the spikes of the LAST trial only (it is
        # reset for every trial above), and the counts are not divided by
        # the number of trials despite the axis label.
        p6 = self.win.addPlot(
            title="AN PSTH",
            row=1,
            col=1,
            labels={"bottom": "T (ms)", "left": "Sp/ms/trial"},
        )
        bins = np.arange(0, 200, 1)
        (hist, binedges) = np.histogram(xan, bins)
        p6.plot(
            binedges,
            hist,
            stepMode=True,
            fillBrush=(0, 0, 0, 255),
            brush=pg.mkBrush("k"),
            fillLevel=0,
        )

        # --- postsynaptic PSTH (spikes pooled over all trials) ---
        p7 = self.win.addPlot(
            title="%s PSTH" % self.cell,
            row=2,
            col=1,
            labels={"bottom": "T (ms)", "left": "Sp/ms/trial"},
        )
        bins = np.arange(0, 200, 1)
        (hist, binedges) = np.histogram(xcn, bins)
        p7.plot(
            binedges,
            hist,
            stepMode=True,
            fillBrush=(0, 0, 0, 255),
            brush=pg.mkBrush("k"),
            fillLevel=0,
        )

        # Disabled phase-histogram / vector-strength analysis, kept for
        # reference (written with Python 2 print statements):
        #
        # p6 = self.win.addPlot(title='AN phase', row=1, col=1)
        # phasewin = [self.pip_start[0] + 0.25*self.pip_duration, self.pip_start[0] + self.pip_duration]
        # prespk = self.pre_cells[0]._spiketrain  # just sample one...
        # spkin = prespk[np.where(prespk > phasewin[0]*1e3)]
        # spikesinwin = spkin[np.where(spkin <= phasewin[1]*1e3)]
        # print "\nCell type: %s" % self.cell
        # print "Stimulus: "
        #
        # # set freq for VS calculation
        # if self.stimulus == 'tone':
        #     f0 = self.f0
        #     print "Tone: f0=%.3f at %3.1f dbSPL, cell CF=%.3f" % (self.f0, self.dbspl, self.cf)
        # if self.stimulus == 'SAM':
        #     f0 = self.fMod
        #     print ("SAM Tone: f0=%.3f at %3.1f dbSPL, fMod=%3.1f dMod=%5.2f, cell CF=%.3f" %
        #            (self.f0, self.dbspl, self.fMod, self.dMod, self.cf))
        # if self.stimulus == 'clicks':
        #     f0 = 1./self.click_rate
        #     print "Clicks: interval %.3f at %3.1f dbSPL, cell CF=%.3f " % (self.click_rate, self.dbspl, self.cf)
        # vs = PU.vector_strength(spikesinwin, f0)
        #
        # print 'AN Vector Strength: %7.3f, d=%.2f (us) Rayleigh: %7.3f p = %.3e n = %d' % (vs['r'], vs['d']*1e6, vs['R'], vs['p'], vs['n'])
        # (hist, binedges) = np.histogram(vs['ph'])
        # curve = p6.plot(binedges, hist, stepMode=True, fillBrush=(100, 100, 255, 150), fillLevel=0)
        # p6.setXRange(0., 2*np.pi)
        #
        # p7 = self.win.addPlot(title='%s phase' % self.cell, row=2, col=1)
        # spkin = bspk[np.where(bspk > phasewin[0]*1e3)]
        # spikesinwin = spkin[np.where(spkin <= phasewin[1]*1e3)]
        # vs = PU.vector_strength(spikesinwin, f0)
        # print '%s Vector Strength: %7.3f, d=%.2f (us) Rayleigh: %7.3f p = %.3e n = %d' % (self.cell, vs['r'], vs['d']*1e6, vs['R'], vs['p'], vs['n'])
        # (hist, binedges) = np.histogram(vs['ph'])
        # curve = p7.plot(binedges, hist, stepMode=True, fillBrush=(100, 100, 255, 150), fillLevel=0)
        # p7.setXRange(0., 2*np.pi)
        # p7.setXLink(p6)

        self.win.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compute AN only PSTH in postsynaptic cell"
    )
    parser.add_argument(
        type=str,
        dest="cell",
        # nargs="?" makes the positional optional so that the declared
        # default is actually used when no cell is given.
        nargs="?",
        default="bushy",
        choices=[
            "bushy",
            "tstellate",
            "dstellate",
            "octopus",
            "tuberculoventral",
            "pyramidal",  # was misspelled "pyramida", which run() rejects
        ],
        help="Select target cell",
    )
    parser.add_argument(
        "-s",
        "--stimulus",
        type=str,
        dest="stimulus",
        default="tone",
        choices=["tone", "SAM", "clicks"],
        help="Select stimulus from ['tone', 'SAM', 'clicks']",
    )
    parser.add_argument(
        "-t",
        "--type",
        type=str,
        dest="syntype",
        default="simple",
        choices=["simple", "multisite"],
        help="Set synapse type (simple, multisite)",
    )
    parser.add_argument(
        "-n",
        "--nrep",
        type=int,
        dest="nrep",
        default=10,
        help="Set number of repetitions",
    )

    args = parser.parse_args()

    cell = args.cell
    stimulus = args.stimulus
    nrep = args.nrep
    # Module-level global read by runTrial() and SGCInputTestPSTH.show().
    synapseType = args.syntype

    print("cell type: ", cell)
    prot = SGCInputTestPSTH()
    prot.set_cell(cell)
    prot.run(stimulus=stimulus, reps=nrep)
    prot.show()

    # Start the Qt event loop unless running in an interactive interpreter.
    if sys.flags.interactive == 0:
        pg.QtGui.QApplication.exec_()