model of DCN pyramidal neuron
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

1047 lines
38 KiB

"""
test_bushy_variation.py
Test inputs to bushy cells as we co-vary the KLT and IH conductances
Two tests: IV and spikes.
Simulation results are first written to disk (.p files); plotting is done separately
Usage:
python test_bushy_variation.py [a, b]
a runs IV curves with variations of gKlt/gh
b runs PSTHs to AN input (CF tones) across the same variations.
"""
import sys
import numpy as np
import pyqtgraph as pg
import pyqtgraph.multiprocess as mp
import pickle
import time
from neuron import h
from cnmodel.protocols import Protocol
from cnmodel.protocols import iv_curve
from cnmodel import cells
from cnmodel.util import sound
from cnmodel.util import custom_init
from cnmodel.util import make_pulse
import cnmodel.util.pynrnutilities as PU
from cnmodel import data
import matplotlib.pyplot as mpl
import cnmodel.util.PlotHelpers as PH
import timeit
# SGC -> target synapse model used by RunTrial and runSound; alternative: "simple".
synapseType = "multisite"
# Parameter tables currently exist only for mouse.
species = "mouse"
class RunTrial:
    """
    Build, run and collect one trial of SGC input onto a single target cell.

    ``info["n_sgc"]`` DummySGC fibers are created. Each is connected to the
    target cell with the module-level ``synapseType`` and driven by
    ``info["stim"]`` with its own seed (``info["seed"] + fiber index``).
    The NEURON simulation then runs for ``info["run_duration"]`` seconds.

    Results are read with dict-style indexing, e.g. ``res["vm"]``:
    ``"time"`` (np.ndarray of ``h.t`` samples), ``"vm"`` (list of somatic
    voltage samples), ``"xmtr"`` (dict of h.Vector transmitter recordings,
    multisite synapses only), ``"pre_cells"``, ``"post_cell"``, ``"synapses"``.

    Parameters
    ----------
    post_cell : cell object or str
        The target cell, or one of "bushy", "tstellate", "octopus",
        "dstellate". For a name, a new cell of that type is created for this
        trial, so repeated trials do not pile synapses onto one cell.
    info : dict
        Keys read: "n_sgc" (legacy alias "n_sic0"), "cf", "sr", "gmax"
        (NetCon weight base for simple synapses), "stim", "seed",
        "simulator", "run_duration" (s), "temp", "dt" and "init" (a callable
        invoked after ``cell_initialize``).

    Raises
    ------
    ValueError
        If ``post_cell`` is an unknown cell name or ``synapseType`` is unknown.
    """

    # Accepted cell names -> attribute name in ``cnmodel.cells``.
    _CELL_CLASSES = {
        "bushy": "Bushy",
        "tstellate": "TStellate",
        "octopus": "Octopus",
        "dstellate": "DStellate",
    }

    def __init__(self, post_cell, info):
        if isinstance(post_cell, str):
            post_cell = self._make_post_cell(post_cell)
        n_sgc = info["n_sgc"] if "n_sgc" in info else info["n_sic0"]

        pre_cells = []
        synapses = []
        xmtr = {}
        nrec = 0  # running index over the release zones of all synapses
        for nsgc in range(int(n_sgc)):
            pre_cell = cells.DummySGC(cf=info["cf"], sr=info["sr"])
            pre_cells.append(pre_cell)
            if synapseType == "simple":
                synapse = pre_cell.connect(post_cell, type=synapseType)
                synapse.terminal.netcon.weight[0] = info["gmax"] * 2.0
            elif synapseType == "multisite":
                synapse = pre_cell.connect(
                    post_cell,
                    post_opts={"AMPAScale": 2.0, "NMDAScale": 2.0},
                    type=synapseType,
                )
                # Record transmitter release of every release zone.
                for i in range(synapse.terminal.n_rzones):
                    vec = h.Vector()
                    vec.record(synapse.terminal.relsite._ref_XMTR[i])
                    xmtr["xmtr%04d" % nrec] = vec
                    nrec += 1
                # No depression in these simulations.
                synapse.terminal.relsite.Dep_Flag = False
            else:
                raise ValueError("unknown synapseType %r" % (synapseType,))
            synapses.append(synapse)
            pre_cell.set_sound_stim(
                info["stim"], seed=info["seed"] + nsgc, simulator=str(info["simulator"])
            )

        Vm = h.Vector()
        Vm.record(post_cell.soma(0.5)._ref_v)
        rtime = h.Vector()
        rtime.record(h._ref_t)
        h.tstop = 1e3 * info["run_duration"]  # run_duration is in seconds
        h.celsius = info["temp"]
        h.dt = info["dt"]
        post_cell.cell_initialize()
        info["init"]()
        h.t = 0.0
        h.run()
        # __init__ must return None; results are exposed through __getitem__.
        self.results = {
            "time": np.array(rtime),
            "vm": Vm.to_python(),
            "xmtr": xmtr,
            "pre_cells": pre_cells,
            "post_cell": post_cell,
            "synapses": synapses,
        }

    def __getitem__(self, key):
        """Return the trial result stored under ``key``."""
        return self.results[key]

    def keys(self):
        """Return the available result keys."""
        return self.results.keys()

    @classmethod
    def _make_post_cell(cls, name):
        """Create a new target cell from its short name (see ``_CELL_CLASSES``)."""
        try:
            class_name = cls._CELL_CLASSES[name]
        except KeyError:
            raise ValueError(
                "unknown cell type %r; expected one of %s"
                % (name, sorted(cls._CELL_CLASSES))
            ) from None
        return getattr(cells, class_name).create(species=species)
class SGCInputTestPSTH(Protocol):
    """
    Drive a cochlear-nucleus cell with its converging SGC inputs and collect
    per-trial data for raster and PSTH display.

    Usage: ``set_cell(...)``, then ``run(...)``, then ``show()``.
    """

    # Target cell types and stimulus names accepted by ``run``.
    _CELL_TYPES = ("bushy", "tstellate", "octopus", "dstellate")
    _STIMULI = ("tone", "SAM", "clicks")

    def set_cell(self, cell="bushy"):
        """Select the target cell type and use serial execution."""
        self.cell = cell
        self.parallelize = False

    def _make_stimulus(self):
        """Build ``self.stim`` for ``self.stimulus`` from the instance settings."""
        if self.stimulus == "SAM":
            self.stim = sound.SAMTone(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                fmod=self.fMod,
                dmod=self.dMod,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        elif self.stimulus == "tone":
            self.stim = sound.TonePip(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        else:  # "clicks"
            self.click_rate = 0.020  # click interval, s (durations here are in s)
            self.stim = sound.ClickTrain(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                click_start=0.010,
                click_duration=100.0e-6,
                click_interval=self.click_rate,
                nclicks=int((self.run_duration - 0.01) / self.click_rate),
                ramp_duration=2.5e-3,
            )

    def run(
        self,
        temp=34.0,
        dt=0.025,
        seed=575982035,
        reps=10,
        stimulus="tone",
        simulator="cochlea",
    ):
        """
        Simulate ``reps`` trials of SGC input to a new ``self.cell`` each trial.

        Parameters
        ----------
        temp : float
            Passed to RunTrial as ``info["temp"]``.
        dt : float
            Passed to RunTrial as ``info["dt"]``.
        seed : int
            Base random seed; trial ``nr`` uses ``seed + 3 * n_sgc * nr``.
        reps : int
            Number of trials.
        stimulus : {"tone", "SAM", "clicks"}
        simulator : str
            Auditory-nerve simulator name given to the SGC sound stimulus.

        Raises
        ------
        ValueError
            For an unknown stimulus or cell type.
        NotImplementedError
            If ``self.parallelize`` is true; only serial runs are implemented.
        """
        if stimulus not in self._STIMULI:
            raise ValueError(
                "stimulus must be one of %s, got %r" % (self._STIMULI, stimulus)
            )
        if self.cell not in self._CELL_TYPES:
            raise ValueError(
                "cell must be one of %s, got %r" % (self._CELL_TYPES, self.cell)
            )
        if self.parallelize:
            raise NotImplementedError(
                "parallel trials are not implemented; set parallelize = False"
            )
        self.nrep = reps
        self.stimulus = stimulus
        self.run_duration = 0.20  # s
        self.pip_duration = 0.05  # s
        self.pip_start = [0.1]  # s
        self.Fs = 100e3  # Hz
        self.f0 = 4000.0  # stimulus frequency, Hz
        self.cf = 4000.0  # SGC characteristic frequency, Hz
        self.fMod = 100.0  # modulation frequency, Hz
        self.dMod = 0.0  # modulation depth, %
        self.dbspl = 50.0
        self.simulator = str(simulator)
        self.sr = 1  # SR group of the SGC fibers
        self._make_stimulus()

        n_sgc = data.get(
            "convergence", species=species, post_type=self.cell, pre_type="sgc"
        )[0]
        self.n_sgc = int(np.round(n_sgc))
        # Needed for simple synapses; table value is nS, NEURON wants uS.
        self.AMPA_gmax = (
            data.get(
                "sgc_synapse", species=species, post_type=self.cell, field="AMPA_gmax"
            )[0]
            / 1e3
        )
        self.vms = [None] * self.nrep
        self.synapses = [None] * self.nrep
        self.xmtrs = [None] * self.nrep
        self.pre_cells = [None] * self.nrep
        self.time = [None] * self.nrep
        info = {
            "n_sgc": self.n_sgc,
            "gmax": self.AMPA_gmax,
            "stim": self.stim,
            "simulator": self.simulator,
            "cf": self.cf,
            "sr": self.sr,
            "seed": seed,
            "run_duration": self.run_duration,
            "temp": temp,
            "dt": dt,
            "init": custom_init,
        }
        for nr in range(self.nrep):
            info["seed"] = seed + 3 * self.n_sgc * nr
            res = RunTrial(self.cell, info)
            self.pre_cells[nr] = res["pre_cells"]
            self.time[nr] = res["time"]
            self.xmtr = {k: v.to_python() for k, v in res["xmtr"].items()}
            self.vms[nr] = res["vm"]
            self.synapses[nr] = res["synapses"]
            self.xmtrs[nr] = self.xmtr

    @staticmethod
    def _raster_item(x, y, pen):
        """Return a PlotCurveItem drawing a 0.6-high tick at each (x, y) point."""
        xp = np.repeat(np.asarray(x, dtype=float), 2)
        yp = np.repeat(np.asarray(y, dtype=float), 2)
        yp[1::2] = yp[::2] + 0.6
        item = pg.PlotCurveItem()
        item.setData(xp, yp, connect="pairs", pen=pen)
        return item

    def show(self):
        """Plot the data of the last ``run`` in a pyqtgraph window."""
        self.win = pg.GraphicsWindow()
        self.win.setBackground("w")
        # 1-ms bins covering the whole run (run_duration is in seconds).
        bins = np.arange(0.0, self.run_duration * 1000.0 + 1.0, 1.0)

        p1 = self.win.addPlot(
            title="Stimulus", row=0, col=0, labels={"bottom": "T (ms)", "left": "V"}
        )
        p1.plot(self.stim.time * 1000, self.stim.sound, pen=pg.mkPen("k", width=0.75))

        p2 = self.win.addPlot(
            title="AN spikes",
            row=1,
            col=0,
            labels={"bottom": "T (ms)", "left": "AN fiber # (all trials)"},
        )
        an_all = []  # AN spike times of every fiber in every trial
        for nr in range(self.nrep):
            xan = []
            yan = []
            for k, pre_cell in enumerate(self.pre_cells[nr]):
                r = pre_cell._spiketrain
                xan.extend(r)
                yan.extend(k + np.zeros_like(r) + 0.2)
            an_all.extend(xan)
            p2.addItem(
                self._raster_item(
                    xan, yan, pg.mkPen(pg.intColor(nr, self.nrep), width=1.0)
                )
            )
        p2.setXLink(p1)

        p3 = self.win.addPlot(
            title="%s Spikes" % self.cell,
            row=2,
            col=0,
            labels={"bottom": "T (ms)", "left": "Trial #"},
        )
        xcn = []
        ycn = []
        for k in range(self.nrep):
            bspk = PU.findspikes(self.time[k], self.vms[k], -35.0)
            xcn.extend(bspk)
            ycn.extend(k + np.zeros_like(bspk) + 0.2)
        p3.addItem(self._raster_item(xcn, ycn, pg.mkPen("k", width=1.5)))
        p3.setXLink(p1)

        p4 = self.win.addPlot(
            title="%s Vm" % self.cell,
            row=3,
            col=0,
            labels={"bottom": "T (ms)", "left": "Vm (mV)"},
        )
        for nr in range(self.nrep):
            p4.plot(
                self.time[nr],
                self.vms[nr],
                pen=pg.mkPen(pg.intColor(nr, self.nrep), width=1.0),
            )
        p4.setXLink(p1)

        p5 = self.win.addPlot(
            title="xmtr", row=0, col=1, labels={"bottom": "T (ms)", "left": "gSyn"}
        )
        if synapseType == "multisite":
            nr = 0  # transmitter traces of the first trial only
            j = 0
            for synapse in self.synapses[nr]:
                for _ in range(synapse.terminal.n_rzones):
                    p5.plot(
                        self.time[nr],
                        self.xmtrs[nr]["xmtr%04d" % j],
                        pen=pg.mkPen(pg.intColor(nr, self.nrep), width=1.0),
                    )
                    j += 1
        p5.setXLink(p1)

        p6 = self.win.addPlot(
            title="AN PSTH",
            row=1,
            col=1,
            labels={"bottom": "T (ms)", "left": "Sp/ms/trial"},
        )
        hist, binedges = np.histogram(an_all, bins)
        p6.plot(
            binedges,
            hist / float(self.nrep),  # 1-ms bins -> spikes/ms/trial
            stepMode=True,
            fillBrush=(0, 0, 0, 255),
            brush=pg.mkBrush("k"),
            fillLevel=0,
        )

        p7 = self.win.addPlot(
            title="%s PSTH" % self.cell,
            row=2,
            col=1,
            labels={"bottom": "T (ms)", "left": "Sp/ms/trial"},
        )
        hist, binedges = np.histogram(xcn, bins)
        p7.plot(
            binedges,
            hist / float(self.nrep),
            stepMode=True,
            fillBrush=(0, 0, 0, 255),
            brush=pg.mkBrush("k"),
            fillLevel=0,
        )
        self.win.show()
class Variations(Protocol):
    """
    Co-vary the somatic KLT (``klt``) and Ih (``ihvcn``) conductances of
    bushy cells and simulate IV curves or responses to sound.

    Parameters
    ----------
    runtype : str
        Stored as ``self.runtype`` (e.g. "IV" or "sound").
    runname : str or None
        Pickle file written by ``runIV`` / ``runSound``; nothing is written
        when None.
    simulator : str
        Auditory-nerve simulator name passed to ``set_sound_stim``.
    """

    # Scale factors applied to klt/ihvcn gbar of the cells built by make_cells.
    DEFAULT_VARIATIONS = (0.5, 0.75, 1.0, 1.5, 2.0)

    def __init__(self, runtype, runname, simulator):
        self.runtype = runtype
        self.runname = runname
        self.simulator = str(simulator)
        self.npost = 5  # number of post cells to test
        self.npre = 3  # number of presynaptic cells
        self.reset()

    def reset(self):
        super(Variations, self).reset()

    @staticmethod
    def _scale_grid():
        """Scale factors 0.25, 0.5, ..., 2.0 used to co-vary klt and ihvcn gbar."""
        return np.linspace(0.25, 2.0, int((2.0 - 0.25) / 0.25) + 1)

    @staticmethod
    def _make_varied_bushy(scale):
        """
        Create a bushy cell with somatic klt and ihvcn gbar multiplied by
        ``scale``. Return ``(cell, gklt, gh)``.
        """
        post_cell = cells.Bushy.create(species=species)
        soma = post_cell.soma()
        gklt = soma.klt.gbar * scale
        gh = soma.ihvcn.gbar * scale
        soma.klt.gbar = gklt
        soma.ihvcn.gbar = gh
        # Forget the mechanisms saved at creation and save the varied ones.
        post_cell.initial_mechanisms = None
        post_cell.save_all_mechs()
        return post_cell, gklt, gh

    def make_cells(self, cf=16e3, temp=34.0, dt=0.025, varsg=None):
        """
        Create ``self.npre`` DummySGC fibers and ``self.npost`` bushy cells.

        Every fiber is connected to every bushy cell with depression disabled.
        The klt and ihvcn gbar of bushy cell m are then multiplied by
        ``varsg[m]`` (default ``DEFAULT_VARIATIONS``). ``temp`` and ``dt`` are
        accepted for interface compatibility and are not used here.

        Raises
        ------
        ValueError
            If fewer scale factors than ``self.npost`` are given.
        """
        scales = self.DEFAULT_VARIATIONS if varsg is None else varsg
        if len(scales) < self.npost:
            raise ValueError(
                "need %d scale factors for npost=%d, got %d"
                % (self.npost, self.npost, len(scales))
            )
        self.pre_cells = [cells.DummySGC(cf=cf, sr=3) for _ in range(self.npre)]
        self.post_cells = []
        for _ in range(self.npost):
            self.post_cell = cells.Bushy.create(species=species)
            self.post_cells.append(self.post_cell)
        # Keep every connection referenced, not only the last one.
        self.synapses = []
        for pre_cell in self.pre_cells:
            for post_cell in self.post_cells:
                synapse = pre_cell.connect(post_cell)
                self.synapse = synapse
                synapse.terminal.relsite.Dep_Flag = False
                self.synapses.append(synapse)
        for post_cell, scale in zip(self.post_cells, scales):
            soma = post_cell.soma()
            soma.klt.gbar = soma.klt.gbar * scale
            soma.ihvcn.gbar = soma.ihvcn.gbar * scale

    def make_stimulus(
        self,
        stimulus="tone",
        cf=16000.0,
        f0=16000.0,
        simulator=None,
        rundur=0.2,
        pipdur=0.05,
        dbspl=50.0,
        fmod=100.0,
        dmod=0.0,
    ):
        """
        Build ``self.stim`` ("tone", "SAM" or "clicks").

        Durations are in seconds and frequencies in Hz. ``simulator`` is
        accepted for interface compatibility and is not used.

        Raises
        ------
        ValueError
            For an unknown stimulus name.
        """
        self.stimulus = stimulus
        self.run_duration = rundur  # s
        self.pip_duration = pipdur  # s
        self.pip_start = [0.1]  # s
        self.Fs = 100e3  # Hz
        self.f0 = f0  # stimulus frequency, Hz
        self.cf = cf  # SGC characteristic frequency, Hz
        self.fMod = fmod  # modulation frequency, Hz
        self.dMod = dmod  # modulation depth, %
        self.dbspl = dbspl
        self.sr = 1  # SR group
        if stimulus == "SAM":
            self.stim = sound.SAMTone(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                fmod=self.fMod,
                dmod=self.dMod,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        elif stimulus == "tone":
            self.stim = sound.TonePip(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                ramp_duration=2.5e-3,
                pip_duration=self.pip_duration,
                pip_start=self.pip_start,
            )
        elif stimulus == "clicks":
            self.click_rate = 0.020  # click interval, s
            self.stim = sound.ClickTrain(
                rate=self.Fs,
                duration=self.run_duration,
                f0=self.f0,
                dbspl=self.dbspl,
                click_start=0.010,
                click_duration=100.0e-6,
                click_interval=self.click_rate,
                nclicks=int((self.run_duration - 0.01) / self.click_rate),
                ramp_duration=2.5e-3,
            )
        else:
            raise ValueError("unknown stimulus %r" % (stimulus,))

    def run(self, mode="IV", cf=16e3, temp=34.0, dt=0.025, stimamp=0, iinj=(1.0,)):
        """
        Build a fresh cell set with ``make_cells`` and run one 200-ms simulation.

        Modes:
        ``"IV"``: one current step of amplitude ``iinj[0]`` (make_pulse units,
        presumably nA) into every bushy cell.
        ``"pulses"``: 10 pulses at 100 Hz of amplitude ``stimamp`` into every
        SGC soma.
        ``"sound"``: a tone pip delivered to every SGC.

        Recordings are stored on the protocol as "t", "v_postNN", and per mode
        "poststimNN" / "istimNN" / "v_preNN".
        """
        self.dt = dt
        self.temp = temp
        self.make_cells(cf, temp, dt)
        if mode == "sound":
            # Synapses already exist (make_cells); only the sound drive is set.
            self.make_stimulus(stimulus="tone")
            for seed, pre_cell in enumerate(self.pre_cells):
                pre_cell.set_sound_stim(self.stim, seed=seed, simulator=self.simulator)
        print("setup to run")
        self.stim_params = []
        self.istim = []
        self.istims = []
        if mode == "pulses":
            for i, pre_cell in enumerate(self.pre_cells):
                stim = {
                    "NP": 10,
                    "Sfreq": 100.0,  # stimulus frequency
                    "delay": 10.0,
                    "dur": 0.5,
                    "amp": stimamp,
                    "PT": 0.0,
                    "dt": dt,
                }
                secmd, _maxt, _tstims = make_pulse(stim)
                self.stim_params.append(stim)
                istim = h.iStim(0.5, sec=pre_cell.soma)
                self.istim.append((istim, h.Vector(secmd)))
                self["istim%02d" % i] = istim._ref_i
        self.postpars = []
        self.poststims = []
        if mode == "IV":
            for i, post_cell in enumerate(self.post_cells):
                pstim = {
                    "NP": 1,
                    "Sfreq": 100.0,  # stimulus frequency
                    "delay": 10.0,
                    "dur": 100,
                    "amp": iinj[0],
                    "PT": 0.0,
                    "dt": dt,
                }
                secmd, _maxt, _tstims = make_pulse(pstim)
                self.postpars.append(pstim)
                pistim = h.iStim(0.5, sec=post_cell.soma)
                self.poststims.append((pistim, h.Vector(secmd)))
                self["poststim%02d" % i] = pistim._ref_i
        #
        # Run simulation
        #
        h.celsius = temp
        # First find the resting potential of each cell.
        for post_cell in self.post_cells:
            post_cell.vm0 = None
            post_cell.cell_initialize()
            post_cell.soma(0.5).v = post_cell.vm0
        h.dt = dt  # honour the requested integration step
        h.t = 0  # run a bit to find the true stable rmp
        h.tstop = 20.0
        h.batch_save()
        h.batch_run(h.tstop, h.dt, "v.dat")
        # Order matters: don't set these recordings up until they are needed.
        self["t"] = h._ref_t
        if mode == "pulses":
            for i, (istim, vec) in enumerate(self.istim):
                vec.play(istim._ref_i, dt, 0)
                self["v_pre%02d" % i] = self.pre_cells[i].soma(0.5)._ref_v
        for m, post_cell in enumerate(self.post_cells):
            if mode == "IV":
                pistim, pvec = self.poststims[m]
                pvec.play(pistim._ref_i, dt, 0)
            self["v_post%02d" % m] = post_cell.soma(0.5)._ref_v
        h.finitialize()  # init and instantiate recordings
        print("running")
        h.t = 0.0
        h.tstop = 200.0
        h.batch_run(h.tstop, h.dt, "v.dat")

    def runIV(self, parallelize):
        """
        Compute IV curves for bushy cells whose klt and ihvcn gbar are scaled
        together by each factor of ``_scale_grid()``.

        With ``parallelize`` the cells run in 16 worker processes. Otherwise
        the same tasks run with ``workers=1``, i.e. in this process. Both
        modes pickle ``{"cells": {"varrange"}, "results": [...]}`` to
        ``self.runname`` (if not None), the layout ``showpicklediv`` reads.
        The scaled conductances are also left in ``self.gklts`` / ``self.ghs``.
        """
        self.civ = {}
        self.iiv = []
        varsg = self._scale_grid()  # covary Ih and gklt in constant ratio
        ncells = len(varsg)
        self.gklts = np.zeros(ncells)
        self.ghs = np.zeros(ncells)
        nworker = 16 if parallelize else 1
        tasks = list(range(ncells))
        results = [None] * ncells
        start = time.time()
        with mp.Parallelize(enumerate(tasks), results=results, workers=nworker) as tasker:
            for i, _task in tasker:
                post_cell, gklt, gh = self._make_varied_bushy(varsg[i])
                ivc = iv_curve.IVCurve()
                ivc.run({"pulse": [(-1.0, 1.5, 0.25)]}, post_cell)
                tasker.results[i] = {
                    "v": ivc.voltage_traces,
                    "i": ivc.current_traces,
                    "t": ivc.time_values,
                    "gklt": gklt,
                    "gh": gh,
                }
        print("\nParallel time: %d workers, %0.2f sec" % (nworker, time.time() - start))
        # Worker processes cannot write parent arrays; copy from the results.
        for i, result in enumerate(results):
            if result is not None:
                self.gklts[i] = result["gklt"]
                self.ghs[i] = result["gh"]
        cell_info = {"varrange": varsg}
        print(cell_info)
        res = {"cells": cell_info, "results": results}
        if self.runname is not None:
            with open(self.runname, "wb") as f:
                pickle.dump(res, f, -1)

    def runSound(self, parallelize=False):
        """
        Simulate tone responses of bushy cells whose klt and ihvcn gbar are
        co-scaled over ``_scale_grid()``. Each cell gets ``self.npre`` SGC
        inputs and 25 repetitions.

        With ``parallelize`` the cells run in 16 worker processes, otherwise
        serially (``workers=1``). Input seeds depend only on (cell, rep,
        fiber), so results do not depend on worker scheduling. The pickle
        ``{"cells": {"gklt", "gh"}, "stim": {...}, "results": [...]}`` is
        written to ``self.runname`` (if not None).
        """
        self.civ = {}
        self.iiv = []
        nworker = 16 if parallelize else 1
        varsg = self._scale_grid()  # covary Ih and gklt in constant ratio
        ncells = len(varsg)
        nrep = 25
        tasks = list(range(ncells))
        results = [None] * ncells
        gklts = np.zeros(ncells)
        ghs = np.zeros(ncells)
        cf = 16000.0
        f0 = 16000.0
        rundur = 0.25  # seconds
        pipdur = 0.1  # seconds
        dbspl = 50.0
        fmod = 40.0
        dmod = 0.0
        stimulus = "tone"
        # The stimulus is the same for every task; build it once.
        self.make_stimulus(
            stimulus=stimulus,
            cf=cf,
            f0=f0,
            rundur=rundur,
            pipdur=pipdur,
            dbspl=dbspl,
            simulator=self.simulator,
            fmod=fmod,
            dmod=dmod,
        )
        start = time.time()
        with mp.Parallelize(enumerate(tasks), results=results, workers=nworker) as tasker:
            for i, _task in tasker:
                post_cell, gklt, gh = self._make_varied_bushy(varsg[i])
                h.celsius = 34
                self.temp = h.celsius
                pre_cells = []
                synapses = []  # keeps the connections referenced during the task
                for _ in range(self.npre):
                    pre_cell = cells.DummySGC(cf=cf, sr=2)
                    pre_cells.append(pre_cell)
                    synapses.append(pre_cell.connect(post_cell, type=synapseType))
                v_reps = []
                i_reps = []
                p_reps = []  # spike times of the 0'th SGC, per repetition
                for j in range(nrep):
                    for prec, pre_cell in enumerate(pre_cells):
                        seed = (i * nrep + j) * len(pre_cells) + prec
                        pre_cell.set_sound_stim(
                            self.stim, seed=seed, simulator=self.simulator
                        )
                    post_cell.vm0 = None
                    post_cell.cell_initialize()
                    post_cell.soma(0.5).v = post_cell.vm0
                    h.dt = 0.02
                    h.t = 0  # run a bit to find the true stable rmp
                    h.tstop = 20.0
                    h.batch_save()
                    h.batch_run(h.tstop, h.dt, "v.dat")
                    self["t"] = h._ref_t
                    key = "v_post%02d" % j
                    self[key] = post_cell.soma(0.5)._ref_v
                    h.finitialize()  # init and instantiate recordings
                    print("running %d" % i)
                    h.t = 0.0
                    h.tstop = rundur * 1000.0  # rundur is in seconds
                    post_cell.check_all_mechs()  # no changes introduced before run
                    h.batch_run(h.tstop, h.dt, "v.dat")
                    v = self[key]
                    v_reps.append(v)
                    i_reps.append(0.0 * v)
                    p_reps.append(pre_cells[0]._stvec.to_python())
                tasker.results[i] = {
                    "v": v_reps,
                    "i": i_reps,
                    "t": self["t"],
                    "pre": pre_cells[0]._stvec.to_python(),  # last repetition
                    "pre_reps": p_reps,
                    "gklt": gklt,
                    "gh": gh,
                }
        print("\nParallel time: %d workers, %0.2f sec" % (nworker, time.time() - start))
        # Worker processes cannot write parent arrays; copy from the results.
        for i, result in enumerate(results):
            if result is not None:
                gklts[i] = result["gklt"]
                ghs[i] = result["gh"]
        cell_info = {"gklt": gklts, "gh": ghs}
        stim_info = {
            "nreps": nrep,
            "cf": cf,
            "f0": f0,
            "rundur": rundur,
            "pipdur": pipdur,
            "dbspl": dbspl,
            "fmod": fmod,
            "dmod": dmod,
        }
        res = {"cells": cell_info, "stim": stim_info, "results": results}
        if self.runname is not None:
            with open(self.runname, "wb") as f:
                pickle.dump(res, f, -1)
#
# def show(self):
#
# self.win = pg.GraphicsWindow()
# self.win.resize(1000, 1000)
#
# cmd_plot = self.win.addPlot(title='Stim')
# for i in range(len(self.pre_cells)):
# cmd_plot.plot(self['t'], self['istim%02d' % i], pen=pg.mkPen(pg.intColor(i, len(self.pre_cells)), hues=len(self.pre_cells), width=1.0))
#
# self.win.nextRow()
# pre_plot = self.win.addPlot(title='SGC Vm')
# for i in range(len(self.pre_cells)):
# pre_plot.plot(self['t'], self['v_pre%02d'%i], pen=pg.mkPen(pg.intColor(i, len(self.pre_cells)), hues=len(self.pre_cells), width=1.0))
#
# self.win.nextRow()
# post_plot = self.win.addPlot(title='Post Cell: %s' % self.post_cell.type)
# for m in range(self.npost):
# post_plot.plot(self['t'], self['v_post%02d' % m],
# pen=pg.mkPen(pg.intColor(m, len(self.post_cells)), hues=len(self.post_cells), width=1.0))
def showpicklediv(name):
    """
    Plot IV-curve results saved by ``Variations.runIV``.

    There is one row of voltage traces per conductance variation, each
    annotated with its scale factor. A bottom row shows the current traces
    of the first cell.

    Parameters
    ----------
    name : str
        Pickle file written by ``runIV``. It is read with ``pickle.load``, so
        only open files you created yourself.

    Raises
    ------
    ValueError
        If the file contains no results.
    """
    with open(name, "rb") as f:
        result = pickle.load(f)
    d = result["results"]
    if not d:
        raise ValueError("no results stored in %r" % (name,))
    ncells = len(d)
    vr = result["cells"]["varrange"]
    _fig, ax = mpl.subplots(ncells + 1, 2, figsize=(8.5, 11.0))
    for trace in d[0]["i"]:
        ax[-1, 0].plot(d[0]["t"], trace, "k", linewidth=0.5)
    ax[-1, 0].set_ylim([-2.0, 2.0])
    for nc, cell_result in enumerate(d):
        for trace in cell_result["v"]:
            ax[nc, 0].plot(cell_result["t"], trace, "k", linewidth=0.5)
        ax[nc, 0].set_ylim([-180.0, 40.0])
        if len(cell_result["v"]) > 0:
            ax[nc, 0].annotate("%.2f" % vr[nc], (180.0, 20.0))
    PH.nice_plot(ax.ravel().tolist())
    PH.noaxes(ax.ravel()[:-1].tolist())
    PH.calbar(
        ax[0, 0],
        calbar=[120.0, -130.0, 25.0, 50.0],
        axesoff=True,
        orient="left",
        unitNames={"x": "ms", "y": "mV"},
        fontsize=9,
        weight="normal",
        font="Arial",
    )
    PH.calbar(
        ax[-1, 0],
        calbar=[120.0, 0.5, 25.0, 1.0],
        axesoff=True,
        orient="left",
        unitNames={"x": "ms", "y": "nA"},
        fontsize=9,
        weight="normal",
        font="Arial",
    )
    mpl.show()
def vector_plot(f, r, l, yp=None):
    """
    Draw phase vectors as spokes on a small polar inset of figure ``f``.

    Each (angle, length) pair from ``r`` / ``l`` is tripled to the points
    (angle, 0), (angle, length), (angle, 0), giving a spoke from the origin.
    The inset sits just inside the top-right corner given by ``yp``.

    Parameters
    ----------
    f : figure-like object providing ``add_axes(rect, polar=True)``.
    r : array-like of angles, used as the polar theta values.
    l : array-like of vector lengths, paired with ``r``.
    yp : object with ``x1`` and ``y1`` attributes (e.g. an axes position).
        It is required despite the ``None`` default.
    """
    inset = f.add_axes([yp.x1 - 0.06, yp.y1 - 0.06, 0.05, 0.05], polar=True)
    theta = np.repeat(r, 3)
    radius = np.repeat(l, 3)
    # First and last point of every triple go back to the origin.
    radius[0::3] = 0.0
    radius[2::3] = 0.0
    inset.plot(theta, radius, lw=0.5)
def phase_hist(f, spkphase, yp=None):
    """
    Draw a histogram of spike phases (90 bins over 0..2*pi) as a small inset
    of figure ``f``.

    Parameters
    ----------
    f : figure-like object providing ``add_axes(rect)``.
    spkphase : array-like of phases in radians.
    yp : object with ``x1`` and ``y1`` attributes (e.g. an axes position).
        The inset sits just inside that top-right corner. It is required
        despite the ``None`` default.
    """
    inset = f.add_axes([yp.x1 - 0.06, yp.y1 - 0.06, 0.05, 0.05])
    # ``num`` must be an int: NumPy rejects a float sample count.
    counts, edges = np.histogram(
        spkphase, np.linspace(0.0, 2 * np.pi, 91), density=False
    )
    # edges[0] == 0, so edges[1] is the bin width.
    inset.bar(edges[:-1], counts, width=edges[1], facecolor="k", alpha=0.75)
def clean_spiketimes(spikeTimes, mindT=0.7):
    """
    Remove spikes that follow the preceding spike by ``mindT`` or less.

    The first spike is always kept. Each later spike is kept only if its
    interval from the spike just before it in the input (kept or not) is
    strictly greater than ``mindT``.

    Parameters
    ----------
    spikeTimes : 1-D list or array of spike times.
    mindT : float
        Minimum interval, in the same units as ``spikeTimes``.

    Returns
    -------
    The input unchanged when it holds 0 or 1 spikes, otherwise a NumPy array
    of the retained spike times.
    """
    if len(spikeTimes) <= 1:
        return spikeTimes
    # asarray so boolean masking works for plain lists as well as arrays.
    times = np.asarray(spikeTimes)
    keep = np.diff(times) > mindT
    return np.concatenate((times[:1], times[1:][keep]))
def showplots(name):
    """
    Show traces from sound stimulation (without current injection) saved by
    ``Variations.runSound``.

    For each conductance variation:
    - the left column shows one voltage trace (repetition 2, or the last
      repetition if there are fewer);
    - the right column shows the postsynaptic PSTH summed over repetitions.

    The bottom row shows the stimulus waveform and a calibration bar.

    Parameters
    ----------
    name : str
        Pickle file written by ``runSound``. It is read with ``pickle.load``,
        so only open files you created yourself.

    Returns
    -------
    dict
        ``{"post": [...], "pre": [...]}``: one ``PU.vector_strength`` result
        per cell at the stored modulation frequency, for the postsynaptic
        spikes and for the stored presynaptic spikes.
    """
    with open(name, "rb") as f:
        d = pickle.load(f)
    results = d["results"]
    ncells = len(results)
    stiminfo = d["stim"]
    dur = stiminfo["rundur"] * 1000.0
    print("dur: ", dur)
    print("stim info: ")
    print(" fmod: ", stiminfo["fmod"])
    print(" dmod: ", stiminfo["dmod"])
    print(" f0: ", stiminfo["f0"])
    print(" cf: ", stiminfo["cf"])
    varsg = np.linspace(0.25, 2.0, int((2.0 - 0.25) / 0.25) + 1)  # was not stored...
    _fig, ax = mpl.subplots(ncells + 1, 2, figsize=(8.5, 11.0))
    # One independent list per cell ([[]] * n would alias a single list).
    spikelists = [[] for _ in range(ncells)]
    prespikes = [[] for _ in range(ncells)]
    post_vs = []
    pre_vs = []
    xmin = 50.0
    bin_edges = np.linspace(0.0, dur, 201)
    for i, cell_result in enumerate(results):
        vdat = cell_result["v"]
        tdat = cell_result["t"]
        pdat = cell_result["pre"]
        PH.noaxes(ax[i, 0])
        if len(vdat) > 0:
            trace_index = min(2, len(vdat) - 1)
            ax[i, 0].plot(tdat - xmin, vdat[trace_index], "k", linewidth=0.5)
            ax[i, 0].annotate("%.2f" % varsg[i], (180.0, 20.0))
        ax[i, 0].set_xlim([0, dur - xmin])
        ax[i, 0].set_ylim([-75, 50])
        PH.referenceline(
            ax[i, 0],
            reference=-62.0,
            limits=None,
            color="0.33",
            linestyle="--",
            linewidth=0.5,
            dashes=[3, 3],
        )
        # Postsynaptic PSTH accumulated over all repetitions.
        counts = np.zeros(len(bin_edges) - 1)
        for trace in vdat:
            detected = clean_spiketimes(PU.findspikes(tdat, trace, -20.0))
            spikelists[i].extend(detected)
            counts += np.histogram(detected, bin_edges, density=False)[0]
        prespikes[i].extend(pdat)
        ax[i, 1].bar(
            bin_edges[:-1] - xmin, counts, width=bin_edges[1], facecolor="k", alpha=0.75
        )
        ax[i, 1].set_xlim([0, dur - xmin])
        ax[i, 1].set_ylim([0, 30])
        post_vs.append(PU.vector_strength(spikelists[i], stiminfo["fmod"]))
        pre_vs.append(PU.vector_strength(prespikes[i], stiminfo["fmod"]))
    # A Variations instance is used only to rebuild the stimulus waveform.
    prot = Variations("plots", name, "cochlea")
    stimulus = "SAM" if stiminfo["dmod"] > 0 else "tone"
    prot.make_stimulus(
        stimulus=stimulus,
        cf=stiminfo["cf"],
        f0=stiminfo["f0"],
        simulator=None,
        rundur=stiminfo["rundur"],
        pipdur=stiminfo["pipdur"],
        dbspl=stiminfo["dbspl"],
        fmod=stiminfo["fmod"],
        dmod=stiminfo["dmod"],
    )
    ax[-1, 1].plot(
        prot.stim.time * 1000.0 - xmin, prot.stim.sound, "k-", linewidth=0.75
    )
    ax[-1, 1].set_xlim([0, (dur - xmin)])
    PH.noaxes(ax[-1, 0])
    ax[-1, 0].set_xlim([0, dur - xmin])
    ax[-1, 0].set_ylim([-75, 50])
    PH.calbar(
        ax[-1, 0],
        calbar=[20.0, 0.0, 25.0, 20.0],
        axesoff=True,
        orient="left",
        unitNames={"x": "ms", "y": "mV"},
        fontsize=9,
        weight="normal",
        font="Arial",
    )
    PH.cleanAxes(ax.ravel().tolist())
    mpl.show()
    return {"post": post_vs, "pre": pre_vs}
if __name__ == "__main__":
    # Usage: python test_bushy_variation.py <panel|runtype> [pickle_file]
    #   a          -> IV curves   (default file Figure6_IV)
    #   b (or d)   -> AN / sound  (default file Figure6_AN)
    #   showiv     -> replot an IV pickle
    #   plots      -> replot a sound pickle
    panels = {
        "a": ("IV", "Figure6_IV"),
        "b": ("sound", "Figure6_AN"),
        "d": ("sound", "Figure6_AN"),  # older alias for "b"
    }
    plot_files = {"showiv": "Figure6_IV", "plots": "Figure6_AN"}

    runname = None
    panel = None
    if len(sys.argv) >= 2:
        panel = sys.argv[1]
    if panel is None:
        raise ValueError("Must specify figure panel to generate: 'a', 'b'")
    if panel in panels:
        runtype, runname = panels[panel]
    else:
        runtype = panel
        runname = plot_files.get(panel)
    if len(sys.argv) > 2:  # optional explicit pickle file name
        runname = sys.argv[2]

    if runtype in ["sound", "IV"]:
        prot = Variations(runtype, runname, "cochlea")
        if runtype == "IV":
            start_time = timeit.default_timer()
            prot.runIV(parallelize=True)
            elapsed = timeit.default_timer() - start_time
            print(("Elapsed time for IV simulations: %f" % (elapsed)))
            showpicklediv(runname)
        if runtype == "sound":
            start_time = timeit.default_timer()
            prot.runSound(parallelize=True)
            elapsed = timeit.default_timer() - start_time
            print(("Elapsed time for AN simulations: %f" % (elapsed)))
            showplots(runname)
    elif runtype in ["showiv"]:
        showpicklediv(runname)
    elif runtype in ["plots"]:
        showplots(runname)
    else:
        print("run type should be one of sound, IV, showiv, plots")