You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
704 lines
24 KiB
704 lines
24 KiB
""" |
|
Test principal cell responses to tone pips of varying frequency and intensity.

This is an example of model construction from a very high level--we specify
only which populations of cells are present and which ones should be connected.
The population and cell classes take care of all the details of generating the
network needed to support a small number of output cells.

Note: run time for this example can be very long. To speed things up, reduce
n_frequencies or n_levels, or reduce the number of selected output cells (see
cells_per_band).
|
|
|
""" |
|
|
|
import os, sys, time |
|
from collections import OrderedDict |
|
import numpy as np |
|
import scipy.stats |
|
from neuron import h |
|
import pyqtgraph as pg |
|
import pyqtgraph.multiprocess as mp |
|
from pyqtgraph.Qt import QtGui, QtCore |
|
from cnmodel import populations |
|
from cnmodel.util import sound, random_seed |
|
from cnmodel.protocols import Protocol |
|
import timeit |
|
|
|
|
|
class CNSoundStim(Protocol):
    """Protocol that assembles a small network of populations and drives it with sound.

    Construction declares the SGC, bushy, D-stellate, T-stellate and
    tuberculoventral populations, connects them at the population level,
    instantiates a single bushy cell and then resolves the presynaptic
    circuitry required to drive that cell.
    """

    def __init__(self, seed, temp=34.0, dt=0.025, synapsetype="simple"):
        """Declare populations, wire them together and pick the recorded cell.

        Parameters
        ----------
        seed : int
            Seed applied before the network is generated; it is also combined
            with the per-run seed in :meth:`run`.
        temp : float
            Value assigned to ``h.celsius`` at the start of every run.
        dt : float
            Value assigned to ``h.dt`` at the start of every run.
        synapsetype : str
            Stored on the SGC population as ``_synapsetype``
            ("simple" is fast, "multisite" is slower).
        """
        Protocol.__init__(self)

        self.seed = seed
        self.temp = temp
        self.dt = dt

        # Seed before anything is generated so the network layout is stable.
        random_seed.set_seed(seed)

        # Declare every population. Only _virtual_ cells exist after this
        # step; nothing has been instantiated yet.
        self.sgc = populations.SGC(model="dummy")
        self.bushy = populations.Bushy()
        self.dstellate = populations.DStellate()
        self.tstellate = populations.TStellate()
        self.tuberculoventral = populations.Tuberculoventral()

        self.populations = OrderedDict(
            (member.type, member)
            for member in (
                self.sgc,
                self.dstellate,
                self.tuberculoventral,
                self.tstellate,
                self.bushy,
            )
        )

        # Synapse type is only configured on the SGC population
        # (eventually this could be done for all synapse types).
        self.sgc._synapsetype = synapsetype

        # Population-level connectivity; no synapses are created here.
        self.sgc.connect(
            self.bushy, self.dstellate, self.tuberculoventral, self.tstellate
        )
        # should dstellate connect to dstellate as well?
        self.dstellate.connect(self.bushy, self.tstellate)
        self.tuberculoventral.connect(self.bushy, self.tstellate)
        self.tstellate.connect(self.bushy)

        # Choose the output cell: among bushy cells whose "sgc_sr" field is 1
        # (medium-SR inputs), take the one with CF closest to 16 kHz.
        bushy_table = self.bushy.cells
        msr_rows = bushy_table[bushy_table["sgc_sr"] == 1]
        nearest = np.argmin(np.abs(msr_rows["cf"] - 16e3))
        chosen_id = msr_rows[nearest]["id"]
        self.bushy.create_cells([chosen_id])  # instantiate just this one cell

        # Build the supporting circuitry. depth=2 is the recursion level:
        # level 1 creates the cells presynaptic to the bushy population,
        # level 2 creates the cells presynaptic to those.
        self.bushy.resolve_inputs(depth=2)

    def run(self, stim, seed):
        """Run the network simulation with *stim* as the sound source and a unique
        *seed* used to configure the random number generators.

        Returns a dict holding the time vector under ``"t"`` and, for every
        recorded cell, ``(population_name, cell_index) -> [vsoma, spike_times]``.
        SGC entries are ``[None, spike_train]``.
        """
        self.reset()

        # Derive two fresh seeds: one for the global generator, one for the
        # SGC spike generator.
        rng = np.random.RandomState()
        rng.seed(self.seed ^ seed)
        global_seed, sgc_seed = rng.randint(0, 2 ** 32, 2)
        random_seed.set_seed(global_seed)
        self.sgc.set_seed(sgc_seed)

        self.sgc.set_sound_stim(stim, parallel=False)

        # Register a somatic voltage recording for every instantiated cell.
        recorded_pops = (
            self.bushy,
            self.dstellate,
            self.tstellate,
            self.tuberculoventral,
        )
        for population in recorded_pops:
            for cell_index in population.real_cells():
                neuron = population.get_cell(cell_index)
                self[neuron] = neuron.soma(0.5)._ref_v
        self["t"] = h._ref_t

        h.tstop = stim.duration * 1000
        h.celsius = self.temp
        h.dt = self.dt

        print("init..")
        self.custom_init()
        print("start..")
        reported_at = time.time()
        while h.t < h.tstop:
            h.fadvance()
            current = time.time()
            # Progress line at most once per wall-clock second.
            if current - reported_at > 1.0:
                print("%0.2f / %0.2f" % (h.t, h.tstop))
                reported_at = current

        # Collect vsoma and spike times for all recorded cells.
        vec = {}
        for key in self._vectors:
            trace = self[key].copy()
            if key == "t":
                vec[key] = trace
                continue
            # A spike is an upward crossing of -20 between adjacent samples.
            now_above = trace[1:] > -20
            was_below = trace[:-1] <= -20
            crossings = np.argwhere(now_above & was_below)[:, 0]
            spike_times = self["t"][crossings]
            pop_name = key.type
            index = getattr(self, pop_name).get_cell_index(key)
            vec[(pop_name, index)] = [trace, spike_times]

        # SGC cells contribute their spike trains only.
        for cell_index in self.sgc.real_cells():
            sgc_cell = self.sgc.get_cell(cell_index)
            vec[("sgc", cell_index)] = [None, sgc_cell._spiketrain]

        return vec
|
|
|
|
|
class NetworkSimDisplay(pg.QtGui.QSplitter):
    """Interactive viewer for the results of a frequency/level sweep.

    The left pane holds the network view, stimulus and trial selectors and a
    frequency/level tuning map for the selected cell. The right pane shows the
    stimulus waveform, the selected cell's membrane potential and a raster of
    its presynaptic spike times.
    """

    def __init__(self, prot, results, baseline, response):
        """Build the widgets and index *results* by stimulus.

        Parameters
        ----------
        prot : CNSoundStim
            Protocol whose ``populations``, ``sgc``, ``dstellate`` and
            ``tuberculoventral`` attributes are read for display.
        results : dict
            Maps ``(f0, dbspl, iteration)`` to ``(stim, result)``.
        baseline : sequence
            ``(start, stop)`` window used to estimate the spontaneous rate.
            Presumably in the same time units as the recorded spike times --
            confirm against the caller.
        response : sequence
            ``(start, stop)`` window in which driven spikes are counted.
        """
        pg.QtGui.QSplitter.__init__(self, QtCore.Qt.Horizontal)
        self.selected_cell = None

        self.prot = prot
        self.baseline = baseline  # (start, stop)
        self.response = response  # (start, stop)

        self.ctrl = QtGui.QWidget()
        self.layout = pg.QtGui.QVBoxLayout()
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.ctrl.setLayout(self.layout)
        self.addWidget(self.ctrl)

        self.nv = NetworkVisualizer(prot.populations)
        self.layout.addWidget(self.nv)
        self.nv.cell_selected.connect(self.nv_cell_selected)

        self.stim_combo = pg.QtGui.QComboBox()
        self.layout.addWidget(self.stim_combo)
        self.trial_combo = pg.QtGui.QComboBox()
        self.layout.addWidget(self.trial_combo)

        # self.results maps a display key to [stim, {iteration: result}].
        self.results = OrderedDict()
        self.stim_order = []
        freqs = set()
        levels = set()
        max_iter = 0
        for (f0, dbspl, iteration), (stim, result) in results.items():
            max_iter = max(max_iter, iteration)
            key = "f0: %0.0f dBspl: %0.0f" % (f0, dbspl)
            if key not in self.results:
                # Register each stimulus exactly once. stim_order and the
                # combo box must stay index-aligned (see select_stim), and
                # repeated iterations must not create duplicate entries.
                self.results[key] = [stim, {}]
                self.stim_order.append((f0, dbspl))
                self.stim_combo.addItem(key)
            self.results[key][1][iteration] = result
            freqs.add(f0)
            levels.add(dbspl)
        self.freqs = sorted(freqs)
        self.levels = sorted(levels)
        self.iterations = max_iter + 1
        self.trial_combo.addItem("all trials")
        for i in range(self.iterations):
            self.trial_combo.addItem(str(i))

        # Connected only after the combos are filled, so population of the
        # combos does not trigger the handlers.
        self.stim_combo.currentIndexChanged.connect(self.stim_selected)
        self.trial_combo.currentIndexChanged.connect(self.trial_selected)

        self.tuning_plot = pg.PlotWidget()
        self.tuning_plot.setLogMode(x=True, y=False)
        self.tuning_plot.scene().sigMouseClicked.connect(self.tuning_plot_clicked)
        self.layout.addWidget(self.tuning_plot)

        self.tuning_img = pg.ImageItem()
        self.tuning_plot.addItem(self.tuning_img)

        # Outline marking the selected stimulus: one grid step wide and tall.
        # With a single frequency or level there is no step to measure, so a
        # unit-sized outline is used instead of failing.
        if len(self.freqs) > 1:
            df = np.log10(self.freqs[1]) - np.log10(self.freqs[0])
        else:
            df = 1.0
        if len(self.levels) > 1:
            dl = self.levels[1] - self.levels[0]
        else:
            dl = 1.0
        self.stim_rect = QtGui.QGraphicsRectItem(QtCore.QRectF(0, 0, df, dl))
        self.stim_rect.setPen(pg.mkPen("c"))
        self.stim_rect.setZValue(20)
        self.tuning_plot.addItem(self.stim_rect)

        # self.network_tree = NetworkTree(self.prot)
        # self.layout.addWidget(self.network_tree)

        self.pw = pg.GraphicsLayoutWidget()
        self.addWidget(self.pw)

        self.stim_plot = self.pw.addPlot()
        self.pw.ci.layout.setRowFixedHeight(0, 100)

        self.pw.nextRow()
        self.cell_plot = self.pw.addPlot(labels={"left": "Vm"})

        self.pw.nextRow()
        self.input_plot = self.pw.addPlot(
            labels={"left": "input #", "bottom": "time"}, title="Input spike times"
        )
        self.input_plot.setXLink(self.cell_plot)
        self.stim_plot.setXLink(self.cell_plot)

        self.stim_selected()

    @staticmethod
    def _mean_step(lo, hi, count):
        """Return the average spacing of *count* grid points spanning lo..hi.

        Falls back to 1.0 when there are fewer than two points, where no
        spacing is defined.
        """
        if count > 1:
            return (hi - lo) / (count - 1)
        return 1.0

    def update_stim_plot(self):
        """Redraw the waveform of the currently selected stimulus."""
        stim = self.selected_stim
        self.stim_plot.plot(stim.time * 1000, stim.sound, clear=True, antialias=True)

    def update_raster_plot(self):
        """Redraw spike rasters for every input of the selected cell.

        Each presynaptic cell occupies one integer row. Trials are offset
        within the row by ``trial / (2 * number_of_trials)``.
        """
        self.input_plot.clear()
        if self.selected_cell is None:
            return
        pop, ind = self.selected_cell

        rec = pop._cells[ind]
        # The "connections" field is 0 until inputs have been resolved.
        if rec["connections"] == 0:
            return

        pop_colors = {
            "dstellate": "y",
            "tuberculoventral": "r",
            "sgc": "g",
            "tstellate": "b",
        }
        pop_order = [self.prot.sgc, self.prot.dstellate, self.prot.tuberculoventral]
        trials = self.selected_trials()
        n_results = len(self.selected_results)
        labels = []
        for prepop in pop_order:
            for preind in rec["connections"].get(prepop, []):
                row = len(labels)
                for j in trials:
                    spikes = self.selected_results[j][(prepop.type, preind)][1]
                    y = np.ones(len(spikes)) * row + j / (2.0 * n_results)
                    self.input_plot.plot(
                        spikes,
                        y,
                        pen=None,
                        symbolBrush=pop_colors[prepop.type],
                        symbol="+",
                        symbolPen=None,
                    )
                labels.append(prepop.type + " " + str(preind))
        self.input_plot.getAxis("left").setTicks([list(enumerate(labels))])

    def update_cell_plot(self):
        """Redraw the selected cell's voltage trace for the selected trials."""
        self.cell_plot.clear()
        if self.selected_cell is None:
            return
        pop, cell_ind = self.selected_cell

        self.cell_plot.setTitle(
            "%s %d %s" % (pop.type, cell_ind, str(self.stim_combo.currentText()))
        )
        n_results = len(self.selected_results)
        for i in self.selected_trials():
            result = self.selected_results[i]
            y = result[(pop.type, cell_ind)][0]
            if y is None:
                # Entries such as SGC cells store spike times only.
                continue
            self.cell_plot.plot(
                # Use the time base recorded with this very trial rather than
                # assuming that trial 0 exists and matches.
                result["t"],
                y,
                name="%s-%d" % (pop.type, cell_ind),
                antialias=True,
                pen=(i, n_results * 1.5),
            )

    def tuning_plot_clicked(self, event):
        """Select the stimulus whose tuning-map pixel contains the click.

        A pixel's lower-left corner is its test frequency/level, so the match
        is the stimulus with the largest f0 and level that do not exceed the
        clicked position.
        """
        spos = event.scenePos()
        stimpos = self.tuning_plot.plotItem.vb.mapSceneToView(spos)
        x = 10 ** stimpos.x()  # the x axis is in log mode
        y = stimpos.y()

        best = None
        for stim, _ in self.results.values():
            f0 = stim.opts["f0"]
            dbspl = stim.opts["dbspl"]
            if x < f0 or y < dbspl:
                continue
            if best is None or f0 > best.opts["f0"] or dbspl > best.opts["dbspl"]:
                best = stim

        if best is None:
            return
        self.select_stim(best.opts["f0"], best.opts["dbspl"])

    def nv_cell_selected(self, nv, cell):
        """Slot for the network view's ``cell_selected`` signal."""
        self.select_cell(*cell)

    def stim_selected(self):
        """Load the stimulus chosen in the combo box and refresh all plots."""
        key = str(self.stim_combo.currentText())
        results = self.results[key]
        self.selected_results = results[1]
        self.selected_stim = results[0]
        self.update_stim_plot()
        self.update_raster_plot()
        self.update_cell_plot()

        self.stim_rect.setPos(np.log10(results[0].opts["f0"]), results[0].opts["dbspl"])

    def trial_selected(self):
        """Refresh the plots that depend on the trial selection."""
        self.update_raster_plot()
        self.update_cell_plot()
        self.update_tuning()

    def selected_trials(self):
        """Return the list of trial indices to display.

        Combo index 0 is "all trials"; any other index ``k`` is trial ``k - 1``.
        """
        if self.trial_combo.currentIndex() == 0:
            return list(range(self.iterations))
        return [self.trial_combo.currentIndex() - 1]

    def select_stim(self, f0, dbspl):
        """Make ``(f0, dbspl)`` the current stimulus in the combo box.

        Raises ValueError if no such stimulus was registered.
        """
        i = self.stim_order.index((f0, dbspl))
        self.stim_combo.setCurrentIndex(i)

    def select_cell(self, pop, cell_id):
        """Make ``(pop, cell_id)`` the displayed cell and refresh the plots."""
        self.selected_cell = pop, cell_id
        self.update_tuning()
        self.update_cell_plot()
        self.update_raster_plot()

    def update_tuning(self):
        """Recompute and draw the tuning map of the selected cell.

        Each pixel is the mean, over the displayed trials, of the spike count
        in the response window minus the count expected from the spontaneous
        rate (estimated from the baseline window of every stimulus and every
        iteration).
        """
        if self.selected_cell is None:
            return

        pop, ind = self.selected_cell
        cell_key = (pop.type, ind)

        # All frequencies and levels present in the matrix.
        fvals = sorted({stim.key()["f0"] for stim, _ in self.results.values()})
        lvals = sorted({stim.key()["dbspl"] for stim, _ in self.results.values()})

        # Spontaneous rate statistics.
        spont_spikes = 0
        spont_time = 0
        for _, iterations in self.results.values():
            for vec in iterations.values():
                spikes = vec[cell_key][1]
                spont_spikes += (
                    (spikes >= self.baseline[0]) & (spikes < self.baseline[1])
                ).sum()
                spont_time += self.baseline[1] - self.baseline[0]
        # An empty baseline window gives no estimate; treat the rate as zero.
        spont_rate = spont_spikes / spont_time if spont_time else 0.0

        # Baseline-corrected spike count at each point of the matrix.
        matrix = np.zeros((len(fvals), len(lvals)))
        trials = self.selected_trials()
        expected_spont = spont_rate * (self.response[1] - self.response[0])
        for stim, iterations in self.results.values():
            row = fvals.index(stim.key()["f0"])
            col = lvals.index(stim.key()["dbspl"])
            for trial in trials:
                spikes = iterations[trial][cell_key][1]
                n_spikes = (
                    (spikes >= self.response[0]) & (spikes < self.response[1])
                ).sum()
                matrix[row, col] += n_spikes - expected_spont
        # Average over the trials actually summed, not the total number of
        # iterations, so that a single-trial view is not scaled down.
        matrix /= len(trials)

        # Plot and scale the matrix image. The origin (lower left) of each
        # image pixel indicates its actual test freq/level.
        self.tuning_img.setImage(matrix)
        self.tuning_img.resetTransform()
        self.tuning_img.setPos(np.log10(fvals[0]), lvals[0])
        self.tuning_img.scale(
            self._mean_step(np.log10(fvals[0]), np.log10(fvals[-1]), len(fvals)),
            self._mean_step(lvals[0], lvals[-1], len(lvals)),
        )
|
|
|
|
|
class NetworkTree(QtGui.QTreeWidget):
    """Two-column tree of a protocol's populations, their cells and inputs."""

    def __init__(self, prot):
        """Store *prot* and fill the tree from its populations."""
        self.prot = prot
        QtGui.QTreeWidget.__init__(self)
        self.setColumnCount(2)

        self.update_tree()

    def update_tree(self):
        """Add one top-level group per population found on the protocol.

        Population names the protocol has no attribute for are skipped.
        """
        candidate_names = ("bushy", "tstellate", "dstellate", "tuberculoventral", "sgc")
        for name in candidate_names:
            if hasattr(self.prot, name):
                population = getattr(self.prot, name)
                group = QtGui.QTreeWidgetItem([name])
                self.addTopLevelItem(group)
                for cell in population.real_cells():
                    self.add_cell(group, population, cell)

    def add_cell(self, grp_item, pop, cell):
        """Append *cell* under *grp_item*, with one child row per input population.

        Each child row shows the presynaptic population's type and the string
        form of the connected cells.
        """
        cell_item = QtGui.QTreeWidgetItem([str(cell)])
        grp_item.addChild(cell_item)
        inputs = pop.cell_connections(cell)
        # A value of 0 means the cell has no connection record.
        if inputs == 0:
            return
        for source_pop, source_cells in inputs.items():
            row = QtGui.QTreeWidgetItem([source_pop.type, str(source_cells)])
            cell_item.addChild(row)
|
|
|
|
|
class NetworkVisualizer(pg.PlotWidget):
    """Scatter view of the instantiated cells and the connections between them.

    Each population occupies one row (its position in the *populations*
    mapping); a cell's x position is ``log10`` of its ``cf`` field. Clicking
    a cell highlights it together with its presynaptic (red) and postsynaptic
    (green) partners and emits ``cell_selected(self, (population, index))``.
    """

    cell_selected = pg.QtCore.Signal(object, object)

    def __init__(self, populations):
        """Lay out the cells of every population and draw their connections.

        As a side effect each population receives two attributes:
        ``cell_spots`` (one spot dict per cell, ``None`` for cells that are
        not instantiated) and ``fwd_connections`` (presynaptic cell index ->
        list of ``(postsynaptic population, index)``).
        """
        self.pops = populations
        pg.PlotWidget.__init__(self)
        self.setLogMode(x=True, y=False)

        self.cells = pg.ScatterPlotItem(clickable=True)
        self.cells.setZValue(10)
        self.addItem(self.cells)
        self.cells.sigClicked.connect(self.cells_clicked)

        self.selected = pg.ScatterPlotItem()
        self.selected.setZValue(20)
        self.addItem(self.selected)

        self.connections = pg.PlotCurveItem()
        self.addItem(self.connections)

        # First assign positions of all instantiated cells.
        cells = []
        for y, pop in enumerate(self.pops.values()):
            pop.cell_spots = []
            pop.fwd_connections = {}
            for i, cell in enumerate(pop._cells):
                is_real = cell["cell"] != 0
                if not is_real:
                    # Keep cell_spots index-aligned with pop._cells.
                    pop.cell_spots.append(None)
                    continue
                spot = {
                    "x": np.log10(cell["cf"]),
                    "y": y,
                    "symbol": "o",
                    "brush": pg.mkBrush("b"),
                    "pen": None,
                    "data": (pop, i),
                }
                cells.append(spot)
                pop.cell_spots.append(spot)

        self.cells.setData(cells)

        self.getAxis("left").setTicks([list(enumerate(self.pops.keys()))])

        # Now assign connection lines and record forward connectivity.
        con_x = []
        con_y = []
        for pop in self.pops.values():
            for i, cell in enumerate(pop._cells):
                conns = cell["connections"]
                if conns == 0:
                    continue
                spot = pop.cell_spots[i]
                if spot is None:
                    continue
                for prepop, precells in conns.items():
                    for j in precells:
                        prepop.fwd_connections.setdefault(j, []).append((pop, i))
                        spot2 = prepop.cell_spots[j]
                        if spot2 is None:
                            # No spot to draw a line to. Skip only this
                            # connection: returning here would abort the
                            # constructor and leave every line undrawn.
                            continue
                        con_x.extend([spot["x"], spot2["x"]])
                        con_y.extend([spot["y"], spot2["y"]])
        self.connections.setData(
            x=con_x, y=con_y, connect="pairs", pen=(255, 255, 255, 60)
        )

    def cells_clicked(self, *args):
        """Highlight the clicked cell and its partners, then emit ``cell_selected``.

        ``args[1]`` is the sequence of clicked spots; the first one that
        refers to an instantiated cell is selected. When none does, the
        highlight layer is hidden and nothing is emitted.
        """
        selected = None
        for spot in args[1]:
            pop, i = spot.data()
            if pop._cells[i]["cell"] != 0:
                selected = spot
                break
        if selected is None:
            self.selected.hide()
            return

        pop, i = selected.data()
        rec = pop._cells[i]
        pos = selected.pos()
        spots = [
            {
                "x": pos.x(),
                "y": pos.y(),
                "size": 15,
                "symbol": "o",
                "pen": "y",
                "brush": "b",
            }
        ]

        def highlight(other_pop, other_ind, color):
            """Queue an enlarged, recolored copy of another cell's spot."""
            base = other_pop.cell_spots[other_ind]
            if base is None:
                # That cell was never given a spot; nothing to highlight.
                return
            marked = base.copy()
            marked["size"] = 15
            marked["brush"] = color
            spots.append(marked)

        # Presynaptic cells.
        if rec["connections"] != 0:
            for prepop, preinds in rec["connections"].items():
                for preind in preinds:
                    highlight(prepop, preind, "r")

        # Postsynaptic cells.
        for postpop, postind in pop.fwd_connections.get(i, []):
            highlight(postpop, postind, "g")

        self.selected.setData(spots)
        self.selected.show()

        self.cell_selected.emit(self, selected.data())
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: sweep tone frequency and level, run (or load from
    # cache) every trial in parallel, then open the interactive display.
    import os
    import pickle
    import sys

    app = pg.mkQApp()
    pg.dbg()

    parallel = True

    # Stimulus grid: frequencies spaced by a fixed fraction of an octave
    # between fmin and fmax, and linearly spaced levels.
    nreps = 5
    fmin = 4e3
    fmax = 32e3
    octavespacing = 1 / 8.0
    # octavespacing = 1.
    n_frequencies = int(np.log2(fmax / fmin) / octavespacing) + 1
    fvals = (
        np.logspace(
            np.log2(fmin / 1000.0),
            np.log2(fmax / 1000.0),
            num=n_frequencies,
            endpoint=True,
            base=2,
        )
        * 1000.0
    )

    n_levels = 11
    # n_levels = 3
    levels = np.linspace(20, 100, n_levels)

    print("Frequencies:", fvals / 1000.0)
    print("Levels:", levels)

    syntype = "multisite"
    path = os.path.dirname(__file__)
    cachepath = os.path.join(path, "cache")
    os.makedirs(cachepath, exist_ok=True)

    seed = 34657845
    prot = CNSoundStim(seed=seed, synapsetype=syntype)

    start_time = timeit.default_timer()

    # Shorter alternative:
    # stimpar = {'dur': 0.06, 'pip': 0.025, 'start': [0.02], 'baseline': [10, 20], 'response': [20, 45]}
    stimpar = {
        "dur": 0.2,
        "pip": 0.04,
        "start": [0.1],
        "baseline": [50, 100],
        "response": [100, 140],
    }

    # One task per (frequency, level, repetition).
    tasks = [(f, db, rep) for f in fvals for db in levels for rep in range(nreps)]

    results = {}
    workers = None if parallel else 1
    tot_runs = len(tasks)
    ignore_cache = "--ignore-cache" in sys.argv
    with mp.Parallelize(
        enumerate(tasks),
        results=results,
        progressDialog="Running parallel simulation..",
        workers=workers,
    ) as tasker:
        for i, task in tasker:
            f, db, iteration = task
            stim = sound.TonePip(
                rate=100e3,
                duration=stimpar["dur"],
                f0=f,
                dbspl=db,  # dura 0.2, pip_start 0.1 pipdur 0.04
                ramp_duration=2.5e-3,
                pip_duration=stimpar["pip"],
                pip_start=stimpar["start"],
            )

            print("=== Start run %d/%d ===" % (i + 1, tot_runs))
            cachefile = os.path.join(
                cachepath,
                "seed=%d_f0=%f_dbspl=%f_syntype=%s_iter=%d.pk"
                % (seed, f, db, syntype, iteration),
            )
            if ignore_cache or not os.path.isfile(cachefile):
                result = prot.run(stim, seed=i)
                # Close the file explicitly so the pickle is completely
                # written before this task finishes.
                with open(cachefile, "wb") as cache_fh:
                    pickle.dump(result, cache_fh)
            else:
                print(" (Loading cached results)")
                # NOTE: unpickling executes code from the file; the cache
                # directory must only contain files written by this script.
                with open(cachefile, "rb") as cache_fh:
                    result = pickle.load(cache_fh)
            tasker.results[(f, db, iteration)] = (stim, result)
            print("--- finished run %d/%d ---" % (i + 1, tot_runs))

    # get time of run before display
    elapsed = timeit.default_timer() - start_time
    print(
        "Elapsed time for %d stimuli: %f (%f sec per stim), synapses: %s"
        # Report the SGC population's synapse type: that is the one the
        # protocol actually configures from *synapsetype*.
        % (len(tasks), elapsed, elapsed / len(tasks), prot.sgc._synapsetype)
    )

    nd = NetworkSimDisplay(
        prot, results, baseline=stimpar["baseline"], response=stimpar["response"]
    )
    nd.show()

    if sys.flags.interactive == 0:
        pg.QtGui.QApplication.exec_()
|
|
|