You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
431 lines
14 KiB
431 lines
14 KiB
import sys |
|
import argparse |
|
import numpy as np |
|
import pyqtgraph as pg |
|
from neuron import h |
|
from cnmodel.protocols import Protocol |
|
from cnmodel import cells |
|
from cnmodel.util import sound |
|
from cnmodel.util import custom_init |
|
import cnmodel.util.pynrnutilities as PU |
|
from cnmodel import data |
|
|
|
species = "mouse"  # passed to every cells.*.create() and data.get() call below; tables for other species do not yet exist
|
|
|
|
|
def runTrial(cell, info):
    """Run one simulated trial of SGC input onto a single postsynaptic cell.

    Parameters
    ----------
    cell : str
        Postsynaptic cell type: "bushy", "tstellate", "octopus", "dstellate",
        "tuberculoventral" or "pyramidal".
    info : dict
        Trial settings.  Keys read here: "n_sgc" (number of presynaptic
        fibers), "cf", "sr", "gmax" (weight for "simple" synapses), "stim",
        "seed", "simulator", "run_duration" (multiplied by 1e3 to set
        ``h.tstop``), "temp", "dt" and "init" (a callable invoked with no
        arguments after the cell is initialized).

    Returns
    -------
    dict
        "time" (numpy array), "vm" (list), "xmtr" (dict of NEURON vectors
        keyed "xmtrNNNN"; empty unless multisite synapses are used),
        "pre_cells", "post_cell" and "synapses".

    Raises
    ------
    ValueError
        If ``cell`` is not one of the supported types.

    Notes
    -----
    The synapse type is taken from the module-level name ``synapseType``
    (set by the command-line entry point).  When that name has not been
    defined, "simple" is used.
    """
    if cell == "bushy":
        post_cell = cells.Bushy.create(species=species)
    elif cell == "tstellate":
        post_cell = cells.TStellate.create(species=species)
    elif cell == "octopus":
        post_cell = cells.Octopus.create(species=species)
    elif cell == "dstellate":
        post_cell = cells.DStellate.create(species=species)
    elif cell in ("tuberculoventral", "pyramidal"):
        # NOTE(review): both of these types are built as DStellate cells here;
        # confirm whether dedicated cell classes should be used instead.
        post_cell = cells.DStellate.create(species=species)
    else:
        # This is a plain function, so the offending value is ``cell``
        # (there is no ``self`` in scope).
        raise ValueError("cell %s is not yet implemented for PSTH testing" % cell)

    # Fall back to the simple synapse when the script entry point has not
    # defined the global (e.g. when this module is imported).
    syn_type = globals().get("synapseType", "simple")

    pre_cells = []
    synapses = []
    xmtr = {}
    site_index = 0  # running index over all release zones of all synapses
    for nsgc in range(info["n_sgc"]):
        sgc = cells.DummySGC(cf=info["cf"], sr=info["sr"])
        pre_cells.append(sgc)
        if syn_type == "simple":
            synapses.append(sgc.connect(post_cell, type=syn_type))
            synapses[-1].terminal.netcon.weight[0] = info["gmax"]
        elif syn_type == "multisite":
            synapses.append(
                sgc.connect(
                    post_cell,
                    post_opts={"AMPAScale": 1.0, "NMDAScale": 1.0},
                    type=syn_type,
                )
            )
            terminal = synapses[-1].terminal
            for i in range(terminal.n_rzones):
                key = "xmtr%04d" % site_index
                xmtr[key] = h.Vector()
                xmtr[key].record(terminal.relsite._ref_XMTR[i])
                site_index += 1
            terminal.relsite.Dep_Flag = False  # no depression in these simulations
        # Each fiber gets its own seed, offset by its index within the trial.
        sgc.set_sound_stim(
            info["stim"], seed=info["seed"] + nsgc, simulator=info["simulator"]
        )

    Vm = h.Vector()
    Vm.record(post_cell.soma(0.5)._ref_v)
    rtime = h.Vector()
    rtime.record(h._ref_t)

    h.tstop = 1e3 * info["run_duration"]  # duration of a run
    h.celsius = info["temp"]
    h.dt = info["dt"]
    post_cell.cell_initialize()
    info["init"]()
    h.t = 0.0
    h.run()

    return {
        "time": np.array(rtime),
        "vm": Vm.to_python(),
        "xmtr": xmtr,
        "pre_cells": pre_cells,
        "post_cell": post_cell,
        "synapses": synapses,
    }
|
|
|
|
|
class SGCInputTestPSTH(Protocol):
    """Drive one postsynaptic cell model with SGC input and plot rasters/PSTHs.

    Intended call order is ``set_cell()`` -> ``run()`` -> ``show()``.
    ``run()`` stores one entry per trial in ``time``, ``vms``, ``pre_cells``,
    ``synapses`` and ``xmtrs``; ``show()`` plots those results.
    """

    def set_cell(self, cell="bushy"):
        """Select the postsynaptic cell type used by :meth:`run`."""
        self.cell = cell

    def run(
        self,
        temp=34.0,
        dt=0.025,
        seed=575982035,
        reps=10,
        stimulus="tone",
        simulator="cochlea",
        parallelize=False,
    ):
        """Build the stimulus and run ``reps`` trials through ``runTrial``.

        Parameters
        ----------
        temp : float
            Assigned to ``h.celsius`` for each trial.
        dt : float
            Assigned to ``h.dt`` for each trial.
        seed : int
            Base random seed; each trial (and each fiber within it) gets a
            distinct offset from this value.
        reps : int
            Number of trials.
        stimulus : str
            One of "tone", "SAM" or "clicks".
        simulator : str
            Passed through to each presynaptic cell's ``set_sound_stim``.
        parallelize : bool
            Parallel execution is not implemented: when true, a
            ``RuntimeWarning`` is issued and no trials are run.
        """
        assert stimulus in ["tone", "SAM", "clicks"], (
            "unknown stimulus %r" % (stimulus,)
        )  # cases available
        assert self.cell in [
            "bushy",
            "tstellate",
            "octopus",
            "dstellate",
            "tuberculoventral",
            "pyramidal",
        ], "unsupported cell type %r" % (self.cell,)

        self.nrep = reps
        self.stimulus = stimulus
        self.run_duration = 0.20  # in seconds
        self.pip_duration = 0.05  # in seconds
        self.pip_start = [0.1]  # in seconds
        self.Fs = 100e3  # in Hz
        self.f0 = 4000.0  # stimulus in Hz
        self.cf = 4000.0  # SGCs in Hz
        self.fMod = 100.0  # mod freq, Hz
        self.dMod = 0.0  # mod depth, percent
        self.dbspl = 50.0
        self.simulator = simulator
        self.sr = 2  # set SR group

        if self.stimulus == "SAM":
            self.dMod = 100.0
            self.stim = sound.SAMTone(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                fmod=self.fMod,
                dmod=self.dMod,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        elif self.stimulus == "tone":
            self.stim = sound.TonePip(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        elif self.stimulus == "clicks":
            # Interval between clicks; it is divided into run_duration below,
            # so it is in the same units (seconds).
            self.click_rate = 0.020
            self.stim = sound.ClickTrain(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                click_start=0.010,
                click_duration=100.0e-6,
                click_interval=self.click_rate,
                nclicks=int((self.run_duration - 0.01) / self.click_rate),
                ramp_duration=2.5e-3,
            )

        n_sgc = data.get(
            "convergence", species=species, post_type=self.cell, pre_type="sgc"
        )[0]
        self.n_sgc = int(np.round(n_sgc))
        # for simple synapses, need this value:
        self.AMPA_gmax = (
            data.get(
                "sgc_synapse", species=species, post_type=self.cell, field="AMPA_gmax"
            )[0]
            / 1e3
        )  # convert nS to uS for NEURON

        # One slot per trial; filled in as the trials complete.
        self.vms = [None] * self.nrep
        self.synapses = [None] * self.nrep
        self.xmtrs = [None] * self.nrep
        self.pre_cells = [None] * self.nrep
        self.time = [None] * self.nrep

        info = {
            "n_sgc": self.n_sgc,
            "gmax": self.AMPA_gmax,
            "stim": self.stim,
            "simulator": self.simulator,
            "cf": self.cf,
            "sr": self.sr,
            "seed": seed,
            "run_duration": self.run_duration,
            "temp": temp,
            "dt": dt,
            "init": custom_init,
        }

        if parallelize:
            # There is no parallel implementation; say so instead of silently
            # leaving every result slot empty.
            import warnings

            warnings.warn(
                "parallelize=True is not implemented; no trials were run",
                RuntimeWarning,
            )
            return

        for nr in range(self.nrep):
            # runTrial adds the fiber index to this seed, so spacing trials
            # by a multiple of n_sgc keeps the per-fiber seeds distinct.
            info["seed"] = seed + 3 * self.n_sgc * nr
            res = runTrial(self.cell, info)
            self.pre_cells[nr] = res["pre_cells"]
            self.time[nr] = res["time"]
            self.vms[nr] = res["vm"]
            self.synapses[nr] = res["synapses"]
            # Convert the recorded NEURON vectors to plain Python lists.
            self.xmtr = {k: v.to_python() for k, v in res["xmtr"].items()}
            self.xmtrs[nr] = self.xmtr

    def _trial_pen(self, nr):
        """Return the pen used for trial ``nr`` (one hue per trial)."""
        return pg.mkPen(pg.intColor(nr, self.nrep), width=1.0)

    @staticmethod
    def _tick_raster(x, y, pen):
        """Return a curve item drawing a vertical tick of height 0.6 at each (x, y)."""
        xp = np.repeat(np.array(x), 2)
        yp = np.repeat(np.array(y), 2)
        yp[1::2] = yp[::2] + 0.6  # second point of each pair is the tick top
        item = pg.PlotCurveItem()
        item.setData(xp.flatten(), yp.flatten(), connect="pairs", pen=pen)
        return item

    def _plot_psth(self, plot, spike_times):
        """Histogram ``spike_times`` in 1 ms bins and draw it as spikes/ms/trial."""
        bins = np.arange(0, int(round(self.run_duration * 1e3)), 1)
        hist, binedges = np.histogram(spike_times, bins)
        if self.nrep > 0:
            # Spike times are pooled over trials; divide to match the axis label.
            hist = hist / float(self.nrep)
        plot.plot(
            binedges,
            hist,
            stepMode=True,
            fillBrush=(0, 0, 0, 255),
            brush=pg.mkBrush("k"),
            fillLevel=0,
        )

    def show(self):
        """Plot the results stored by :meth:`run` in a pyqtgraph window.

        Left column: stimulus, AN spike raster, postsynaptic spike raster and
        postsynaptic Vm.  Right column: recorded release traces for the first
        trial (present only when multisite synapses were used), AN PSTH and
        postsynaptic PSTH.
        """
        self.win = pg.GraphicsWindow()
        self.win.setBackground("w")

        p1 = self.win.addPlot(
            title="Stimulus", row=0, col=0, labels={"bottom": "T (ms)", "left": "V"}
        )
        p1.plot(self.stim.time * 1000, self.stim.sound, pen=pg.mkPen("k", width=0.75))

        p2 = self.win.addPlot(
            title="AN spikes",
            row=1,
            col=0,
            labels={"bottom": "T (ms)", "left": "AN fiber # (all trials)"},
        )
        an_spike_times = []  # pooled over every fiber and trial, for the PSTH
        for nr in range(self.nrep):
            xan = []
            yan = []
            for k, sgc in enumerate(self.pre_cells[nr]):
                r = sgc._spiketrain
                xan.extend(r)
                yan.extend(k + np.zeros_like(r) + 0.2)
            an_spike_times.extend(xan)
            p2.addItem(self._tick_raster(xan, yan, self._trial_pen(nr)))
        p2.setXLink(p1)

        p3 = self.win.addPlot(
            title="%s Spikes" % self.cell,
            row=2,
            col=0,
            labels={"bottom": "T (ms)", "left": "Trial #"},
        )
        xcn = []
        ycn = []
        for k in range(self.nrep):
            bspk = PU.findspikes(self.time[k], self.vms[k], -35.0)
            xcn.extend(bspk)
            ycn.extend(k + np.zeros_like(bspk) + 0.2)
        p3.addItem(self._tick_raster(xcn, ycn, pg.mkPen("k", width=1.5)))
        p3.setXLink(p1)

        p4 = self.win.addPlot(
            title="%s Vm" % self.cell,
            row=3,
            col=0,
            labels={"bottom": "T (ms)", "left": "Vm (mV)"},
        )
        for nr in range(self.nrep):
            p4.plot(self.time[nr], self.vms[nr], pen=self._trial_pen(nr))
        p4.setXLink(p1)

        p5 = self.win.addPlot(
            title="xmtr", row=0, col=1, labels={"bottom": "T (ms)", "left": "gSyn"}
        )
        # Release traces are only recorded for multisite synapses, so the
        # stored dict is empty otherwise.  Only the first trial is shown; the
        # dict preserves the order in which the sites were recorded.
        if self.xmtrs and self.xmtrs[0]:
            for trace in self.xmtrs[0].values():
                p5.plot(self.time[0], trace, pen=self._trial_pen(0))
        p5.setXLink(p1)

        p6 = self.win.addPlot(
            title="AN PSTH",
            row=1,
            col=1,
            labels={"bottom": "T (ms)", "left": "Sp/ms/trial"},
        )
        self._plot_psth(p6, an_spike_times)

        p7 = self.win.addPlot(
            title="%s PSTH" % self.cell,
            row=2,
            col=1,
            labels={"bottom": "T (ms)", "left": "Sp/ms/trial"},
        )
        self._plot_psth(p7, xcn)

        # TODO: add AN / postsynaptic phase histograms and vector strength
        # (PU.vector_strength) over the late part of the stimulus window.

        self.win.show()
|
|
|
|
|
def build_parser():
    """Return the command-line parser for this example script.

    The positional ``cell`` argument is optional and defaults to "bushy".
    """
    parser = argparse.ArgumentParser(
        description="Compute AN only PSTH in postsynaptic cell"
    )
    parser.add_argument(
        "cell",
        type=str,
        nargs="?",  # without this the positional is required and the default is dead
        default="bushy",
        choices=[
            "bushy",
            "tstellate",
            "dstellate",
            "octopus",
            "tuberculoventral",
            "pyramidal",
        ],
        help="Select target cell",
    )
    parser.add_argument(
        "-s",
        "--stimulus",
        type=str,
        dest="stimulus",
        default="tone",
        choices=["tone", "SAM", "clicks"],
        help="Select stimulus from ['tone', 'SAM', 'clicks']",
    )
    parser.add_argument(
        "-t",
        "--type",
        type=str,
        dest="syntype",
        default="simple",
        choices=["simple", "multisite"],
        help="Set synapse type (simple, multisite)",
    )
    parser.add_argument(
        "-n",
        "--nrep",
        type=int,
        dest="nrep",
        default=10,
        help="Set number of repetitions",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    cell = args.cell
    stimulus = args.stimulus
    nrep = args.nrep
    # Must stay a module-level name: runTrial() reads it as a global.
    synapseType = args.syntype

    print("cell type: ", cell)
    prot = SGCInputTestPSTH()
    prot.set_cell(cell)
    prot.run(stimulus=stimulus, reps=nrep)
    prot.show()

    # Start the Qt event loop unless running under ``python -i``.
    if sys.flags.interactive == 0:
        pg.QtGui.QApplication.exec_()
|
|
|