model of DCN pyramidal neuron
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

705 lines
24 KiB

"""
Test principal cell responses to tone pips of varying frequency and intensity.
This is an example of model construction from a very high level--we specify
only which populations of cells are present and which ones should be connected.
The population and cell classes take care of all the details of generating the
network needed to support a small number of output cells.
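Note: run time for this example can be very long. To speed things up, reduce
n_levels or nreps, or increase octavespacing (which lowers n_frequencies), in
the __main__ section. The recorded output cells are selected in
CNSoundStim.__init__ (currently a single bushy cell); selecting more cells also
lengthens the run.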
"""
import os, sys, time
from collections import OrderedDict
import numpy as np
import scipy.stats
from neuron import h
import pyqtgraph as pg
import pyqtgraph.multiprocess as mp
from pyqtgraph.Qt import QtGui, QtCore
from cnmodel import populations
from cnmodel.util import sound, random_seed
from cnmodel.protocols import Protocol
import timeit
class CNSoundStim(Protocol):
    """Cochlear-nucleus network driven by a sound stimulus.

    SGC, D-stellate, tuberculoventral, T-stellate and bushy populations are
    created and connected. Then a single bushy cell near 16 kHz is instantiated,
    together with the presynaptic circuitry needed to drive it.
    """

    def __init__(self, seed, temp=34.0, dt=0.025, synapsetype="simple"):
        """Build the network.

        Parameters
        ----------
        seed : int
            Master seed. It is applied immediately so that network generation
            is reproducible, and it is combined with the per-run seed in
            :meth:`run`.
        temp : float
            Value assigned to ``h.celsius`` in :meth:`run`.
        dt : float
            Value assigned to ``h.dt`` in :meth:`run`.
        synapsetype : str
            Synapse model used by the SGC population. "simple" is fast;
            "multisite" is slower.
        """
        Protocol.__init__(self)
        self.seed = seed
        self.temp = temp
        self.dt = dt

        # Seed now to ensure network generation is stable.
        random_seed.set_seed(seed)

        # Create cell populations.
        # This creates a complete set of _virtual_ cells for each population.
        # No cells are instantiated at this point.
        self.sgc = populations.SGC(model="dummy")
        self.bushy = populations.Bushy()
        self.dstellate = populations.DStellate()
        self.tstellate = populations.TStellate()
        self.tuberculoventral = populations.Tuberculoventral()
        pops = [
            self.sgc,
            self.dstellate,
            self.tuberculoventral,
            self.tstellate,
            self.bushy,
        ]
        self.populations = OrderedDict([(pop.type, pop) for pop in pops])

        # Synapse type used in the SGC population: "simple" is fast, "multisite"
        # is slower. (Eventually this could be done for all synapse types.)
        self.sgc._synapsetype = synapsetype

        # Connect populations.
        # This only defines the connections between populations; no synapses
        # are created at this stage.
        self.sgc.connect(
            self.bushy, self.dstellate, self.tuberculoventral, self.tstellate
        )
        # TODO: should dstellate connect to dstellate as well?
        self.dstellate.connect(self.bushy, self.tstellate)
        self.tuberculoventral.connect(self.bushy, self.tstellate)
        self.tstellate.connect(self.bushy)

        # Select cells to record from.
        # At this time, we actually instantiate the selected cells.
        # Pick a single bushy cell near 16 kHz, with medium-SR inputs.
        bc = self.bushy.cells
        msr_cells = bc[bc["sgc_sr"] == 1]  # filter for msr cells
        ind = np.argmin(np.abs(msr_cells["cf"] - 16e3))  # CF closest to 16 kHz
        cell_id = msr_cells[ind]["id"]
        self.bushy.create_cells([cell_id])  # instantiate just one cell

        # Now create the supporting circuitry needed to drive the selected cells.
        # At this time, cells are created in all populations and automatically
        # connected with synapses.
        # depth=2 is the recursion level used when resolving inputs.
        # Resolving inputs for the bushy cell population (level 1) creates its
        # presynaptic cells (e.g. in the dstellate population). Resolving inputs
        # for those populations (level 2) creates their presynaptic cells (e.g.
        # in the sgc population).
        self.bushy.resolve_inputs(depth=2)

    def run(self, stim, seed):
        """Run the network simulation with *stim* as the sound source.

        *seed* is XOR-ed with the master seed to derive fresh seeds for the SGC
        spike generator and for the NEURON simulation.

        Returns
        -------
        dict
            - ``"t"``: copy of the recorded time vector.
            - ``(pop_type, cell_index)``: ``[vm, spike_times]`` for each
              instantiated bushy, D-stellate, T-stellate and tuberculoventral
              cell. Spike times are upward crossings of -20 in the recorded
              soma voltage.
            - ``("sgc", index)``: ``[None, spiketrain]`` for each instantiated
              SGC.
        """
        self.reset()

        # Generate 2 new seeds for the SGC spike generator and for the NEURON
        # simulation. dtype=np.int64 keeps the 2**32 upper bound valid on
        # platforms whose default integer is 32 bits. Where the default is
        # already 64-bit, the drawn values are unchanged.
        rs = np.random.RandomState()
        rs.seed(self.seed ^ seed)
        seed1, seed2 = rs.randint(0, 2 ** 32, 2, dtype=np.int64)
        random_seed.set_seed(seed1)
        self.sgc.set_seed(seed2)
        self.sgc.set_sound_stim(stim, parallel=False)

        # Set up recording vectors (soma voltage of every instantiated cell).
        for pop in self.bushy, self.dstellate, self.tstellate, self.tuberculoventral:
            for ind in pop.real_cells():
                cell = pop.get_cell(ind)
                self[cell] = cell.soma(0.5)._ref_v
        self["t"] = h._ref_t

        h.tstop = stim.duration * 1000
        h.celsius = self.temp
        h.dt = self.dt

        print("init..")
        self.custom_init()
        print("start..")
        last_update = time.time()
        while h.t < h.tstop:
            h.fadvance()
            now = time.time()
            # Progress report at most once per second.
            if now - last_update > 1.0:
                print("%0.2f / %0.2f" % (h.t, h.tstop))
                last_update = now

        # Record vsoma and spike times for all cells.
        # The time vector is the same for every cell, so fetch it only once.
        t = self["t"]
        vec = {}
        for k in self._vectors:
            v = self[k].copy()
            if k == "t":
                vec[k] = v
                continue
            spike_inds = np.argwhere((v[1:] > -20) & (v[:-1] <= -20))[:, 0]
            spikes = t[spike_inds]
            pop = k.type
            cell_ind = getattr(self, pop).get_cell_index(k)
            vec[(pop, cell_ind)] = [v, spikes]

        # Record SGC spike trains.
        for ind in self.sgc.real_cells():
            cell = self.sgc.get_cell(ind)
            vec[("sgc", ind)] = [None, cell._spiketrain]
        return vec
class NetworkSimDisplay(pg.QtGui.QSplitter):
    """Interactive browser for the results of a tone-pip network simulation.

    The left pane holds:
      - a network visualizer,
      - stimulus and trial selectors,
      - a frequency x level response map for the selected cell.

    The right pane holds:
      - the stimulus waveform,
      - the membrane potential of the selected cell,
      - a raster of the selected cell's presynaptic spike times.
    """

    def __init__(self, prot, results, baseline, response):
        """
        Parameters
        ----------
        prot : protocol object
            Must provide ``populations`` and the ``sgc``, ``dstellate``,
            ``tuberculoventral`` and ``tstellate`` populations (as CNSoundStim
            does).
        results : dict
            Maps ``(f0, dbspl, iteration)`` to ``(stim, result)``, where
            *result* is the dict returned by ``CNSoundStim.run``.
        baseline : (start, stop)
            Window used to estimate the spontaneous spike rate, in the same
            time units as the recorded spike times.
        response : (start, stop)
            Window in which driven spikes are counted.
        """
        pg.QtGui.QSplitter.__init__(self, QtCore.Qt.Horizontal)
        self.selected_cell = None
        self.prot = prot
        self.baseline = baseline  # (start, stop)
        self.response = response  # (start, stop)

        # ---- left pane: controls ----
        self.ctrl = QtGui.QWidget()
        self.layout = pg.QtGui.QVBoxLayout()
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.ctrl.setLayout(self.layout)
        self.addWidget(self.ctrl)

        self.nv = NetworkVisualizer(prot.populations)
        self.layout.addWidget(self.nv)
        self.nv.cell_selected.connect(self.nv_cell_selected)

        self.stim_combo = pg.QtGui.QComboBox()
        self.layout.addWidget(self.stim_combo)
        self.trial_combo = pg.QtGui.QComboBox()
        self.layout.addWidget(self.trial_combo)

        # Group results by stimulus: key -> [stim, {iteration: result}].
        self.results = OrderedDict()
        self.stim_order = []  # (f0, dbspl) per combo entry, in combo order
        freqs = set()
        levels = set()
        max_iter = 0
        for (f0, dbspl, iteration), (stim, result) in list(results.items()):
            max_iter = max(max_iter, iteration)
            key = "f0: %0.0f dBspl: %0.0f" % (f0, dbspl)
            if key not in self.results:
                # One combo entry per stimulus, not one per repetition. The
                # stim_order list is kept index-aligned with the combo box.
                self.results[key] = [stim, {}]
                self.stim_order.append((f0, dbspl))
                self.stim_combo.addItem(key)
            self.results[key][1][iteration] = result
            freqs.add(f0)
            levels.add(dbspl)
        self.freqs = sorted(freqs)
        self.levels = sorted(levels)
        self.iterations = max_iter + 1

        self.trial_combo.addItem("all trials")
        for i in range(self.iterations):
            self.trial_combo.addItem(str(i))
        self.stim_combo.currentIndexChanged.connect(self.stim_selected)
        self.trial_combo.currentIndexChanged.connect(self.trial_selected)

        self.tuning_plot = pg.PlotWidget()
        self.tuning_plot.setLogMode(x=True, y=False)
        self.tuning_plot.scene().sigMouseClicked.connect(self.tuning_plot_clicked)
        self.layout.addWidget(self.tuning_plot)
        self.tuning_img = pg.ImageItem()
        self.tuning_plot.addItem(self.tuning_img)

        # Outline marking the currently selected stimulus on the tuning map.
        # Its size is one grid step in log-frequency and in level.
        df = self._grid_step(np.log10(self.freqs))
        dl = self._grid_step(self.levels)
        self.stim_rect = QtGui.QGraphicsRectItem(QtCore.QRectF(0, 0, df, dl))
        self.stim_rect.setPen(pg.mkPen("c"))
        self.stim_rect.setZValue(20)
        self.tuning_plot.addItem(self.stim_rect)

        # ---- right pane: traces ----
        self.pw = pg.GraphicsLayoutWidget()
        self.addWidget(self.pw)
        self.stim_plot = self.pw.addPlot()
        self.pw.ci.layout.setRowFixedHeight(0, 100)
        self.pw.nextRow()
        self.cell_plot = self.pw.addPlot(labels={"left": "Vm"})
        self.pw.nextRow()
        self.input_plot = self.pw.addPlot(
            labels={"left": "input #", "bottom": "time"}, title="Input spike times"
        )
        self.input_plot.setXLink(self.cell_plot)
        self.stim_plot.setXLink(self.cell_plot)

        self.stim_selected()

    @staticmethod
    def _grid_step(values):
        """Return the spacing between the first two grid values, or 1.0 for a 1-point grid."""
        return values[1] - values[0] if len(values) > 1 else 1.0

    def update_stim_plot(self):
        """Plot the selected stimulus waveform.

        Time is scaled by 1000 to match the simulation time axis, since
        ``h.tstop = stim.duration * 1000``.
        """
        stim = self.selected_stim
        self.stim_plot.plot(stim.time * 1000, stim.sound, clear=True, antialias=True)

    def update_raster_plot(self):
        """Plot the spike times of every presynaptic input to the selected cell.

        Each input gets one row. Within a row, trials are offset slightly so
        that they do not overlap.
        """
        self.input_plot.clear()
        if self.selected_cell is None:
            return
        pop, ind = self.selected_cell
        rec = pop._cells[ind]
        # A value of 0 (instead of a dict) means there are no connections to show.
        if rec["connections"] == 0:
            return

        pop_colors = {
            "dstellate": "y",
            "tuberculoventral": "r",
            "sgc": "g",
            "tstellate": "b",
        }
        pop_symbols = {
            "dstellate": "x",
            "tuberculoventral": "+",
            "sgc": "t",
            "tstellate": "o",
        }
        # Every population that projects to another one in CNSoundStim. This
        # includes T-stellate cells, which project to bushy cells.
        pop_order = [
            self.prot.sgc,
            self.prot.dstellate,
            self.prot.tuberculoventral,
            self.prot.tstellate,
        ]
        trials = self.selected_trials()
        n_results = len(self.selected_results)

        labels = []
        row = 0
        for prepop in pop_order:
            for preind in rec["connections"].get(prepop, []):
                for j in trials:
                    spikes = self.selected_results[j][(prepop.type, preind)][1]
                    y = np.full(len(spikes), row + j / (2.0 * n_results))
                    self.input_plot.plot(
                        spikes,
                        y,
                        pen=None,
                        symbolBrush=pop_colors[prepop.type],
                        symbol=pop_symbols[prepop.type],
                        symbolPen=None,
                    )
                row += 1
                labels.append(prepop.type + " " + str(preind))
        self.input_plot.getAxis("left").setTicks([list(enumerate(labels))])

    def update_cell_plot(self):
        """Plot the membrane potential of the selected cell for the selected trials."""
        self.cell_plot.clear()
        if self.selected_cell is None:
            return
        pop, cell_ind = self.selected_cell
        self.cell_plot.setTitle(
            "%s %d %s" % (pop.type, cell_ind, str(self.stim_combo.currentText()))
        )
        trials = self.selected_trials()
        for i in trials:
            result = self.selected_results[i]
            y = result[(pop.type, cell_ind)][0]
            # SGC entries carry no voltage trace (None).
            if y is not None:
                self.cell_plot.plot(
                    self.selected_results[0]["t"],
                    y,
                    name="%s-%d" % self.selected_cell,
                    antialias=True,
                    pen=(i, len(self.selected_results) * 1.5),
                )

    def tuning_plot_clicked(self, event):
        """Select the stimulus whose tuning-map pixel contains the click.

        A pixel's lower-left corner is its stimulus (f0, dbspl). The chosen
        stimulus is the one with the largest f0 and the largest level that are
        both <= the clicked position.
        """
        spos = event.scenePos()
        stimpos = self.tuning_plot.plotItem.vb.mapSceneToView(spos)
        x = 10 ** stimpos.x()  # x axis is log10(frequency)
        y = stimpos.y()
        best = None
        for stim, _ in self.results.values():
            f0 = stim.opts["f0"]
            dbspl = stim.opts["dbspl"]
            if x < f0 or y < dbspl:
                continue
            if (
                best is None
                or f0 > best.opts["f0"]
                or dbspl > best.opts["dbspl"]
            ):
                best = stim
        if best is None:
            return
        self.select_stim(best.opts["f0"], best.opts["dbspl"])

    def nv_cell_selected(self, nv, cell):
        """Slot for NetworkVisualizer.cell_selected; *cell* is ``(pop, index)``."""
        self.select_cell(*cell)

    def stim_selected(self):
        """Load the stimulus chosen in the combo box and refresh the plots."""
        key = str(self.stim_combo.currentText())
        stim, iterations = self.results[key]
        self.selected_results = iterations
        self.selected_stim = stim
        self.update_stim_plot()
        self.update_raster_plot()
        self.update_cell_plot()
        self.stim_rect.setPos(np.log10(stim.opts["f0"]), stim.opts["dbspl"])

    def trial_selected(self):
        """Refresh the plots that depend on which trials are shown."""
        self.update_raster_plot()
        self.update_cell_plot()
        self.update_tuning()

    def selected_trials(self):
        """Return the list of trial indices selected in the trial combo box."""
        if self.trial_combo.currentIndex() == 0:
            return list(range(self.iterations))
        return [self.trial_combo.currentIndex() - 1]

    def select_stim(self, f0, dbspl):
        """Select stimulus ``(f0, dbspl)`` in the combo box (ValueError if unknown)."""
        i = self.stim_order.index((f0, dbspl))
        self.stim_combo.setCurrentIndex(i)

    def select_cell(self, pop, cell_id):
        """Make ``(pop, cell_id)`` the selected cell and refresh all cell plots."""
        self.selected_cell = pop, cell_id
        self.update_tuning()
        self.update_cell_plot()
        self.update_raster_plot()

    def update_tuning(self):
        """Redraw the frequency x level response map for the selected cell.

        Each pixel holds the spike count in the response window minus the
        count expected from the spontaneous rate, averaged over the selected
        trials. The spontaneous rate is pooled over the baseline window of
        every stimulus and every trial.
        """
        if self.selected_cell is None:
            return
        pop, ind = self.selected_cell
        cell_key = (pop.type, ind)

        # First get all frequencies and levels in the matrix.
        stims = [stim for stim, _ in self.results.values()]
        fvals = sorted({stim.key()["f0"] for stim in stims})
        lvals = sorted({stim.key()["dbspl"] for stim in stims})

        # Spontaneous rate statistics.
        b0, b1 = self.baseline
        spont_spikes = 0
        spont_time = 0
        for _, iterations in self.results.values():
            for vec in iterations.values():
                spikes = vec[cell_key][1]
                spont_spikes += ((spikes >= b0) & (spikes < b1)).sum()
                spont_time += b1 - b0
        spont_rate = spont_spikes / spont_time if spont_time > 0 else 0.0

        # Baseline-corrected spike counts at each point in the matrix.
        r0, r1 = self.response
        expected_spont = spont_rate * (r1 - r0)
        trials = self.selected_trials()
        matrix = np.zeros((len(fvals), len(lvals)))
        for stim, iterations in self.results.values():
            fi = fvals.index(stim.key()["f0"])
            li = lvals.index(stim.key()["dbspl"])
            for trial in trials:
                spikes = iterations[trial][cell_key][1]
                n_spikes = ((spikes >= r0) & (spikes < r1)).sum()
                matrix[fi, li] += n_spikes - expected_spont
        # Average over the trials actually summed: one trial or all of them.
        matrix /= len(trials)

        # Plot and scale the matrix image. The origin (lower left) of each
        # image pixel marks its actual test frequency/level.
        logf_min, logf_max = np.log10(fvals[0]), np.log10(fvals[-1])
        fstep = (logf_max - logf_min) / (len(fvals) - 1) if len(fvals) > 1 else 1.0
        lstep = (lvals[-1] - lvals[0]) / (len(lvals) - 1) if len(lvals) > 1 else 1.0
        self.tuning_img.setImage(matrix)
        self.tuning_img.resetTransform()
        self.tuning_img.setPos(logf_min, lvals[0])
        self.tuning_img.scale(fstep, lstep)
class NetworkTree(QtGui.QTreeWidget):
    """Two-column tree listing each instantiated cell per population.

    Under each cell, one child row is shown per presynaptic population, with
    the presynaptic cell indices in the second column.
    """

    def __init__(self, prot):
        self.prot = prot
        QtGui.QTreeWidget.__init__(self)
        self.setColumnCount(2)
        self.update_tree()

    def update_tree(self):
        """Add a top-level group for each known population present on the protocol."""
        names = ("bushy", "tstellate", "dstellate", "tuberculoventral", "sgc")
        for name in (n for n in names if hasattr(self.prot, n)):
            population = getattr(self.prot, name)
            group_item = QtGui.QTreeWidgetItem([name])
            self.addTopLevelItem(group_item)
            for cell_index in population.real_cells():
                self.add_cell(group_item, population, cell_index)

    def add_cell(self, grp_item, pop, cell):
        """Append *cell* under *grp_item*, with one child row per presynaptic population."""
        cell_item = QtGui.QTreeWidgetItem([str(cell)])
        grp_item.addChild(cell_item)
        connections = pop.cell_connections(cell)
        # 0 is used instead of a dict when there is nothing to list.
        if connections == 0:
            return
        for pre_pop, pre_cells in connections.items():
            cell_item.addChild(QtGui.QTreeWidgetItem([pre_pop.type, str(pre_cells)]))
class NetworkVisualizer(pg.PlotWidget):
    """Scatter plot of instantiated cells, with connection lines between them.

    Cells are placed by log10(CF) on x and by population on y. Presynaptic
    connections are drawn as lines.

    Clicking a cell highlights it, its presynaptic partners (red) and its
    postsynaptic partners (green). It then emits
    ``cell_selected(self, (pop, index))``.

    Side effect: every population passed in gains two attributes:
      - ``cell_spots``: a spot dict per cell, or None for cells that were not
        instantiated.
      - ``fwd_connections``: maps a presynaptic cell index to a list of
        ``(postsynaptic_pop, postsynaptic_index)``.
    """

    cell_selected = pg.QtCore.Signal(object, object)

    def __init__(self, populations):
        self.pops = populations
        pg.PlotWidget.__init__(self)
        self.setLogMode(x=True, y=False)

        self.cells = pg.ScatterPlotItem(clickable=True)
        self.cells.setZValue(10)
        self.addItem(self.cells)
        self.cells.sigClicked.connect(self.cells_clicked)

        self.selected = pg.ScatterPlotItem()
        self.selected.setZValue(20)
        self.addItem(self.selected)

        self.connections = pg.PlotCurveItem()
        self.addItem(self.connections)

        # First assign positions of all instantiated cells.
        cells = []
        for y, pop in enumerate(self.pops.values()):
            pop.cell_spots = []
            pop.fwd_connections = {}
            for i, cell in enumerate(pop._cells):
                if cell["cell"] == 0:
                    # Virtual (not instantiated) cell: not drawn.
                    pop.cell_spots.append(None)
                    continue
                spot = {
                    "x": np.log10(cell["cf"]),
                    "y": y,
                    "symbol": "o",
                    "brush": pg.mkBrush("b"),
                    "pen": None,
                    "data": (pop, i),
                }
                cells.append(spot)
                pop.cell_spots.append(spot)
        self.cells.setData(cells)
        self.getAxis("left").setTicks([list(enumerate(self.pops.keys()))])

        # Now assign connection lines and record forward connectivity.
        con_x = []
        con_y = []
        for pop in self.pops.values():
            for i, cell in enumerate(pop._cells):
                conns = cell["connections"]
                if conns == 0:
                    continue
                spot = pop.cell_spots[i]
                if spot is None:
                    continue
                for prepop, precells in conns.items():
                    for j in precells:
                        prepop.fwd_connections.setdefault(j, []).append((pop, i))
                        spot2 = prepop.cell_spots[j]
                        if spot2 is None:
                            # The presynaptic cell is not drawn. Skip only this
                            # line; the remaining connections must still be
                            # processed and drawn.
                            continue
                        con_x.extend([spot["x"], spot2["x"]])
                        con_y.extend([spot["y"], spot2["y"]])
        self.connections.setData(
            x=con_x, y=con_y, connect="pairs", pen=(255, 255, 255, 60)
        )

    def cells_clicked(self, *args):
        """Highlight the first instantiated cell among the clicked spots.

        ``args[1]`` is the sequence of clicked spots emitted by
        ``ScatterPlotItem.sigClicked``. If no instantiated cell was clicked,
        the highlight is hidden.
        """
        selected = None
        for spot in args[1]:
            pop, i = spot.data()
            if pop._cells[i]["cell"] != 0:
                selected = spot
                break
        if selected is None:
            self.selected.hide()
            return

        pop, i = selected.data()
        rec = pop._cells[i]
        pos = selected.pos()
        spots = [
            {
                "x": pos.x(),
                "y": pos.y(),
                "size": 15,
                "symbol": "o",
                "pen": "y",
                "brush": "b",
            }
        ]

        # Display presynaptic cells (only those that are drawn).
        if rec["connections"] != 0:
            for prepop, preinds in rec["connections"].items():
                for preind in preinds:
                    pre_spot = prepop.cell_spots[preind]
                    if pre_spot is None:
                        continue
                    spot = pre_spot.copy()
                    spot["size"] = 15
                    spot["brush"] = "r"
                    spots.append(spot)

        # Display postsynaptic cells.
        for postpop, postind in pop.fwd_connections.get(i, []):
            post_spot = postpop.cell_spots[postind]
            if post_spot is None:
                continue
            spot = post_spot.copy()
            spot["size"] = 15
            spot["brush"] = "g"
            spots.append(spot)

        self.selected.setData(spots)
        self.selected.show()
        self.cell_selected.emit(self, selected.data())
if __name__ == "__main__":
    import pickle

    app = pg.mkQApp()
    pg.dbg()

    # ---- Stimulus grid: log-spaced tone frequencies x linear levels ----
    parallel = True  # False -> run all tasks in a single worker
    nreps = 5  # repetitions of every (frequency, level) pair
    fmin = 4e3
    fmax = 32e3
    octavespacing = 1 / 8.0
    # octavespacing = 1.
    # round() rather than int(), so that a ratio like 23.9999999 is not
    # truncated to 23 and a frequency silently dropped.
    n_frequencies = int(round(np.log2(fmax / fmin) / octavespacing)) + 1
    fvals = (
        np.logspace(
            np.log2(fmin / 1000.0),
            np.log2(fmax / 1000.0),
            num=n_frequencies,
            endpoint=True,
            base=2,
        )
        * 1000.0
    )
    n_levels = 11
    # n_levels = 3
    levels = np.linspace(20, 100, n_levels)
    print("Frequencies:", fvals / 1000.0)
    print("Levels:", levels)

    syntype = "multisite"
    cachepath = os.path.join(os.path.dirname(__file__), "cache")
    os.makedirs(cachepath, exist_ok=True)

    seed = 34657845
    prot = CNSoundStim(seed=seed, synapsetype=syntype)
    start_time = timeit.default_timer()

    # Shorter alternative:
    # stimpar = {'dur': 0.06, 'pip': 0.025, 'start': [0.02], 'baseline': [10, 20], 'response': [20, 45]}
    stimpar = {
        "dur": 0.2,
        "pip": 0.04,
        "start": [0.1],
        "baseline": [50, 100],
        "response": [100, 140],
    }
    tasks = [(f, db, i) for f in fvals for db in levels for i in range(nreps)]
    results = {}
    workers = 1 if not parallel else None
    tot_runs = len(tasks)
    with mp.Parallelize(
        enumerate(tasks),
        results=results,
        progressDialog="Running parallel simulation..",
        workers=workers,
    ) as tasker:
        for i, task in tasker:
            f, db, iteration = task
            stim = sound.TonePip(
                rate=100e3,
                duration=stimpar["dur"],
                f0=f,
                dbspl=db,  # dura 0.2, pip_start 0.1 pipdur 0.04
                ramp_duration=2.5e-3,
                pip_duration=stimpar["pip"],
                pip_start=stimpar["start"],
            )
            print("=== Start run %d/%d ===" % (i + 1, tot_runs))
            cachefile = os.path.join(
                cachepath,
                "seed=%d_f0=%f_dbspl=%f_syntype=%s_iter=%d.pk"
                % (seed, f, db, syntype, iteration),
            )
            if "--ignore-cache" in sys.argv or not os.path.isfile(cachefile):
                result = prot.run(stim, seed=i)
                with open(cachefile, "wb") as fh:
                    pickle.dump(result, fh)
            else:
                print(" (Loading cached results)")
                # Cache files are written by this script. Unpickling can
                # execute arbitrary code, so never point cachepath at
                # untrusted data.
                with open(cachefile, "rb") as fh:
                    result = pickle.load(fh)
            tasker.results[(f, db, iteration)] = (stim, result)
            print("--- finished run %d/%d ---" % (i + 1, tot_runs))

    # Get time of run before display.
    elapsed = timeit.default_timer() - start_time
    # The synapse type is configured on the SGC population (see CNSoundStim).
    print(
        "Elapsed time for %d stimuli: %f (%f sec per stim), synapses: %s"
        % (len(tasks), elapsed, elapsed / len(tasks), prot.sgc._synapsetype)
    )
    nd = NetworkSimDisplay(
        prot, results, baseline=stimpar["baseline"], response=stimpar["response"]
    )
    nd.show()
    if sys.flags.interactive == 0:
        pg.QtGui.QApplication.exec_()