You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
451 lines
14 KiB
451 lines
14 KiB
""" |
|
Layout: |
|
|
|
if __name__==__main__ handles cmd args, instantiates test, runs it, displays results |
|
|
|
run_trial(): defines a model and then runs the test using preset params in hoc and return the info to the class |
|
SGCInputTest: |
|
__init__ : defines many static variables |
|
run(): calls delivers information to the run_trial() function and the recieved information of a single run |
|
back and stores it as an exstensible list in the class |
|
show(): displays graphs depending on the graph options selected below it and displayed in a printout |
|
|
|
""" |
|
import argparse |
|
import numpy as np |
|
import pyqtgraph as pg |
|
from neuron import h |
|
from cnmodel.protocols import Protocol |
|
from cnmodel import cells |
|
from cnmodel.util import sound |
|
from cnmodel.util import custom_init |
|
import cnmodel.util.pynrnutilities as PU |
|
from cnmodel import data |
|
|
|
try: |
|
from tqdm import trange |
|
except ImportError: |
|
raise ImportError("Please 'pip install tqdm' to allow for progress bar") |
|
|
|
species = "rat" # tables for other species do not yet exist |
|
|
|
|
|
def run_trial(cell, info):
    """
    Build the pyramidal-cell network, run one NEURON trial, and return the recordings.

    Network built here:
      * 48 DummySGC fibers, each connected to the pyramidal cell;
      * 16 DummySGC fibers, each connected to BOTH of two tuberculoventral (TV) cells;
      * 21 "simple" connections from each TV cell onto the pyramidal cell.
    Short-term depression (``Dep_Flag``) is disabled on every SGC synapse.

    :param cell: (str) postsynaptic cell type; only "pyramidal" is supported.
    :param info: (dict) run parameters. Keys read here: "cf", "sr", "stim", "seed",
        "simulator", "synapse_type", "run_duration" (seconds), "temp", "dt", and
        "init" (a callable invoked after ``post_cell.cell_initialize()``).
        An "n_sgc" key may be present but is NOT used: the fiber counts are fixed.
    :return: dict with
        "time"      -- np.ndarray of recorded h.t values,
        "vm"        -- np.ndarray, pyramidal soma(0.5) membrane potential,
        "pre_cells" -- list of each SGC's ``_spiketrain`` (pyramidal inputs first),
        "vmtb"      -- np.ndarray, soma(0.5) membrane potential of the first TV cell.
    """
    assert cell == "pyramidal"

    n_pyr_sgc = 48  # SGC fibers onto the pyramidal cell
    n_tv_sgc = 16  # SGC fibers, each onto both TV cells
    n_tv_to_pyr = 21  # connections from each TV cell onto the pyramidal cell

    post_cell = cells.Pyramidal.create(species=species)
    inhib_cell = cells.Tuberculoventral.create()
    inhib_cell2 = cells.Tuberculoventral.create()
    pre_cells = []
    synapses = []
    inhib_synapses = []

    # SGC -> pyramidal. Seeds: seed + 0 .. seed + 47.
    for nsgc in range(n_pyr_sgc):
        pre_cells.append(cells.DummySGC(cf=info["cf"], sr=info["sr"]))
        synapses.append(pre_cells[-1].connect(post_cell, type=info["synapse_type"]))
        pre_cells[-1].set_sound_stim(
            info["stim"], seed=info["seed"] + nsgc, simulator=info["simulator"]
        )
        synapses[-1].terminal.relsite.Dep_Flag = False  # no depression in these simulations

    # SGC -> both TV cells. Seeds continue after the pyramidal inputs.
    for nsgc in range(n_tv_sgc):
        pre_cells.append(cells.DummySGC(cf=info["cf"], sr=info["sr"]))
        new_syns = [
            pre_cells[-1].connect(tv_cell, type=info["synapse_type"])
            for tv_cell in (inhib_cell, inhib_cell2)
        ]
        inhib_synapses.extend(new_syns)
        pre_cells[-1].set_sound_stim(
            info["stim"],
            seed=info["seed"] + nsgc + n_pyr_sgc,
            simulator=info["simulator"],
        )
        # No depression in these simulations: these synapses live in
        # inhib_synapses, so the flag must be set on them (not on synapses[-1]).
        for syn in new_syns:
            syn.terminal.relsite.Dep_Flag = False

    # NOTE: a D-stellate path (20 more SGC inputs) was sketched here but is
    # disabled; SGCTestPSTH.run() still reserves 20 seeds per trial for it.

    # TV -> pyramidal.
    for _ in range(n_tv_to_pyr):
        inhib_synapses.append(inhib_cell.connect(post_cell, type="simple"))
        inhib_synapses.append(inhib_cell2.connect(post_cell, type="simple"))

    Vm = h.Vector()
    Vm.record(post_cell.soma(0.5)._ref_v)
    Vmtb = h.Vector()
    Vmtb.record(inhib_cell.soma(0.5)._ref_v)
    rtime = h.Vector()
    rtime.record(h._ref_t)

    h.tstop = 1e3 * info["run_duration"]  # run_duration is in s; NEURON time is in ms
    h.celsius = info["temp"]
    h.dt = info["dt"]
    post_cell.cell_initialize()
    info["init"]()
    h.t = 0.0
    h.run()

    # Copy recordings into plain numpy arrays / lists before releasing NEURON objects.
    result = {
        "time": np.array(rtime),
        "vm": np.array(Vm),
        "pre_cells": [x._spiketrain for x in pre_cells],
        "vmtb": np.array(Vmtb),
    }

    # Explicitly drop references to the NEURON objects built for this trial.
    del (
        pre_cells,
        synapses,
        inhib_cell,
        inhib_synapses,
        Vm,
        Vmtb,
        rtime,
        post_cell,
        inhib_cell2,
    )
    return result
|
|
|
|
|
class SGCTestPSTH(Protocol):
    """
    Test a single cell driven by input from simulated SGC fibers.

    __init__(): sets the stimulus and run parameters; almost all can be modified.
    run(): loops over run_trial() and stores the per-trial results on the instance.
    show(): builds a single summary figure from the graph methods defined below it.
    """

    def __init__(
        self,
        temp=34.0,
        seed=2918,
        nrep=10,
        stimulus="tone",
        simulator="cochlea",
        n_sgc=12,
        debug=True,
        cell="pyramidal",
    ):
        """
        :param temp: (float) temperature assigned to h.celsius; must be 34 for default
            pyramidal cells
        :param seed: (int) base random seed; change it to see different results
            (author's note: reduce this number if you keep getting a timeout error)
        :param nrep: (int) number of stimulus presentations (trials)
        :param stimulus: (str) must be 'tone'
        :param simulator: (str) auditory-periphery simulator passed to the SGCs
            (default 'cochlea')
        :param n_sgc: (int) number of SGC fibers; if falsy it is looked up from the
            mouse SGC convergence table. NOTE: run_trial() uses fixed fiber counts;
            this value only offsets the per-trial seeds in run().
        :param debug: (bool) enables the terminal printouts (on by default)
        :param cell: (str) cell type; validated against a list here, but run_trial()
            only supports 'pyramidal'

        The NEURON time step is fixed at self.dt = 0.025.
        """
        super().__init__()
        assert stimulus == "tone"
        assert cell in [
            "bushy",
            "tstellate",
            "octopus",
            "dstellate",
            "tuberculoventral",
            "pyramidal",
        ]
        self.debug = debug
        self.nrep = nrep
        self.stimulus = stimulus
        self.run_duration = 0.30  # in seconds
        self.pip_duration = 0.05  # in seconds
        self.pip_start = [0.1]  # in seconds
        self.Fs = 100e3  # in Hz
        self.f0 = 13000.0  # stimulus in Hz
        self.cf = 13000.0  # SGCs in Hz
        self.fMod = 100.0  # mod freq, Hz (not passed to the tone stimulus below)
        self.dMod = 0.0  # mod depth, % (not passed to the tone stimulus below)
        self.dbspl = 40.0
        self.simulator = simulator
        self.sr = 2  # set SR group
        self.seed = seed
        self.temp = temp
        self.dt = 0.025  # NEURON time step (h.dt)
        self.cell = cell
        self.synapse_type = "multisite"

        if self.stimulus == "tone":
            self.stim = sound.TonePip(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                ramp_duration=5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )

        if not n_sgc:
            n_sgc = data.get(
                "convergence", species="mouse", post_type=self.cell, pre_type="sgc"
            )[0]
        self.n_sgc = int(np.round(n_sgc))

        # Per-trial result slots, filled in by run().
        self.vms = [None] * self.nrep
        self.vmtbs = [None] * self.nrep
        self.synapses = [None] * self.nrep
        self.pre_cells = [None] * self.nrep
        self.time = [None] * self.nrep

        # debug function reports a print out of various information about the run
        if self.debug:
            print("SGCInputTest Created")
            print()
            print("Test parameters")
            print("#" * 70)
            print(f"Running test of {cell} cell synapse with Simulated SGC fibers")
            print()
            print(f"Run Conditions: Run Time: {self.run_duration}s,")
            print(f" Run Temp: {self.temp} ")
            print(f" Sgc Connections {n_sgc}")
            print(f" Number of Presentations: {nrep}")
            print()
            print(f"Stimulus Conditions: Type: {stimulus}")
            print(f" Stim Duration: {self.pip_duration}s")
            print(f" Characteristic F: {self.cf}hz")
            print(f" Stim Start:{str(self.pip_start)}s")

    def run(self):
        """Run self.nrep trials through run_trial() and store each trial's results."""
        super().run()
        info = {
            "n_sgc": self.n_sgc,
            "stim": self.stim,
            "simulator": self.simulator,
            "cf": self.cf,
            "sr": self.sr,
            "seed": self.seed,
            "run_duration": self.run_duration,
            "synapse_type": self.synapse_type,
            "temp": self.temp,
            "dt": self.dt,
            "init": custom_init,
        }
        for nr in trange(self.nrep):
            # Each trial gets its own seed block: 48 pyramidal + 16 TV SGC inputs,
            # plus 20 reserved for a (disabled) D-stellate path in run_trial().
            info["seed"] = self.seed + self.n_sgc + (nr * (48 + 16 + 20))
            res = run_trial(self.cell, info)
            # res contains: {'time', 'vm', 'pre_cells', 'vmtb'}
            self.pre_cells[nr] = res["pre_cells"]
            self.time[nr] = res["time"]
            self.vms[nr] = res["vm"]
            self.vmtbs[nr] = res["vmtb"]

    def show(self):
        """
        Create a single window containing all of the graphs produced by the
        graph methods of this class, with their x axes linked.
        """
        # GraphicsWindow (removed in newer pyqtgraph) created the QApplication and
        # sized itself to 800x600; GraphicsLayoutWidget does neither.
        pg.mkQApp()
        self.win = pg.GraphicsLayoutWidget()
        self.win.resize(800, 600)
        self.win.setBackground("w")
        plots = [
            self.stimulus_graph(),
            self.an_spike_graph(),
            self.pyram_spike_graph(),
            self.voltage_graph(),
            self.tb_cell_spike_graph(),
            self.an_psth_graph(),  # requires an_spike_graph() first (self.all_xan)
            self.cell_psth_graph(),  # requires pyram_spike_graph() first (self.xcn)
        ]
        # Link every x axis to the stimulus plot.
        for plot in plots[1:]:
            plot.setXLink(plots[0])
        self.win.show()
        if self.debug:
            print("finished")

    ############# Helpers shared by the graph methods ###################

    @staticmethod
    def _raster_curve(x, y):
        """Return a PlotCurveItem drawing a 0.6-unit vertical tick at each (x, y) event."""
        xp = np.repeat(np.asarray(x, dtype=float), 2)
        yp = np.repeat(np.asarray(y, dtype=float), 2)
        yp[1::2] = yp[::2] + 0.6  # each pair (bottom, top) is drawn as one segment
        curve = pg.PlotCurveItem()
        curve.setData(xp, yp, connect="pairs", pen=pg.mkPen("k", width=1.5))
        return curve

    def _spike_raster(self, traces, threshold=-35.0):
        """Detect spikes in each trial's trace; return (times, rows) with row = trial + 0.2."""
        times, rows = [], []
        for trial in range(self.nrep):
            spikes = PU.findspikes(self.time[trial], traces[trial], threshold)
            times.extend(spikes)
            rows.extend([trial + 0.2] * len(spikes))
        return times, rows

    def _psth_plot(self, title, row, spike_times):
        """Add a 1-ms-bin PSTH (50-200 ms) of spike_times, normalized per trial."""
        plot = self.win.addPlot(
            title=title,
            row=row,
            col=1,
            labels={"bottom": "T (ms)", "left": "Sp/ms/trial"},
        )
        bins = np.arange(50, 200, 1)  # 1-ms bins, so counts per bin are spikes/ms
        hist, binedges = np.histogram(spike_times, bins)
        # Spike times are pooled over all trials; divide by nrep to match the label.
        rate = hist / max(self.nrep, 1)
        plot.plot(
            binedges,
            rate,
            stepMode=True,
            fillBrush=(0, 0, 0, 255),
            brush=pg.mkBrush("k"),
            fillLevel=0,
        )
        return plot

    ############# Graph options to be included in the show() method ###################

    def stimulus_graph(self):
        p1 = self.win.addPlot(
            title="Stimulus", row=0, col=0, labels={"bottom": "T (ms)", "left": "V"}
        )
        p1.plot(self.stim.time * 1000, self.stim.sound, pen=pg.mkPen("k", width=0.75))
        return p1

    def an_spike_graph(self):
        """Raster of AN spikes for the first trial; also pools all trials into self.all_xan."""
        p2 = self.win.addPlot(
            title="AN spikes",
            row=1,
            col=0,
            labels={"bottom": "T (ms)", "left": "AN spikes (first trial)"},
        )
        # Every AN spike time from every trial and fiber; used by an_psth_graph().
        self.all_xan = [t for trial in self.pre_cells for train in trial for t in train]

        # One raster row per fiber, first trial only (as the axis label states).
        first_trial = self.pre_cells[0] if self.nrep > 0 else []
        xan, yan = [], []
        for fiber, train in enumerate(first_trial):
            xan.extend(train)
            yan.extend([fiber + 0.2] * len(train))
        p2.addItem(self._raster_curve(xan, yan))
        return p2

    def pyram_spike_graph(self):
        """Raster of pyramidal spikes, one row per trial; stores self.xcn / self.ycn."""
        p3 = self.win.addPlot(
            title="Pyramidal Spikes",
            row=2,
            col=0,
            labels={"bottom": "T (ms)", "left": "Trial #"},
        )
        self.xcn, self.ycn = self._spike_raster(self.vms)
        p3.addItem(self._raster_curve(self.xcn, self.ycn))
        return p3

    def voltage_graph(self):
        """Membrane potential of the postsynaptic cell for (at most) the first 3 trials."""
        p4 = self.win.addPlot(
            title="%s Vm" % self.cell,
            row=0,
            col=1,
            labels={"bottom": "T (ms)", "left": "Vm (mV)"},
        )
        for nr in range(min(3, self.nrep)):
            p4.plot(
                self.time[nr],
                self.vms[nr],
                pen=pg.mkPen(pg.intColor(nr, hues=self.nrep), width=1.0),
            )
        return p4

    def tb_cell_spike_graph(self):
        """Raster of spikes of the first tuberculoventral cell, one row per trial."""
        p5 = self.win.addPlot(
            title="Tuberculoventral Spikes",
            row=3,
            col=0,
            labels={"bottom": "T (ms)", "left": "Trial #"},
        )
        xtcn, ytcn = self._spike_raster(self.vmtbs)
        p5.addItem(self._raster_curve(xtcn, ytcn))
        return p5

    def an_psth_graph(self):
        """PSTH of all AN fibers (summed over fibers); needs an_spike_graph() first."""
        return self._psth_plot("AN PSTH", 1, self.all_xan)

    def cell_psth_graph(self):
        """PSTH of the pyramidal cell; needs pyram_spike_graph() first."""
        return self._psth_plot("Pyramidal PSTH", 2, self.xcn)
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser( |
|
description="Compute AN only PSTH in postsynaptic cell" |
|
) |
|
parser.add_argument( |
|
"-n", |
|
"--nrep", |
|
type=int, |
|
dest="nrep", |
|
default=10, |
|
help="Set number of repetitions", |
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
nrep = args.nrep |
|
prot = SGCTestPSTH(nrep=50) |
|
prot.run() |
|
prot.show() |
|
|
|
import sys |
|
|
|
if sys.flags.interactive == 0: |
|
pg.QtGui.QApplication.exec_()
|
|
|