You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
443 lines
17 KiB
443 lines
17 KiB
import logging |
|
import scipy.stats |
|
import numpy as np |
|
|
|
from .. import data |
|
|
|
|
|
class Population(object):
    """
    A Population represents a group of cells all having the same type.

    Populations provide methods for:

    * Adding cells to the population with characteristic distributions.
    * Connecting the cells in one population to the cells in another.
    * Automatically adding cells to satisfy connectivity requirements when
      connecting populations together.

    Populations have a concept of a "natural" underlying distribution of
    neurons, and behave as if all neurons in this distribution already exist
    in the model. However, initially all neurons are virtual, and are only
    instantiated to become a part of the running model if the neuron provides
    synaptic input to another non-virtual neuron, or if the user explicitly
    requests a recording of the neuron.

    Subclasses represent populations for a specific cell type. They must
    implement `create_cell`. The default `connection_stats` (and
    `_get_cf_array`) additionally read a ``type`` attribute and a ``cf`` cell
    field, so subclasses either provide those or override `connection_stats`.
    """

    def __init__(self, species, size, fields, synapsetype="multisite", **kwds):
        """
        Parameters
        ----------
        species :
            Species identifier; used as the ``species`` key for ``data.get``
            lookups in `connection_stats`.
        size : int
            Number of (initially virtual) cells in the population.
        fields : list of (name, dtype) tuples
            Extra per-cell record fields, appended after the standard
            'id', 'cell', 'input_resolved' and 'connections' fields.
        synapsetype : str
            Passed as ``type=`` to ``pre_cell.connect`` in
            `connect_pop_to_cell`.
        **kwds
            Stored as ``self._cell_args``; not read by this class (presumably
            consumed by subclass `create_cell` implementations — verify).
        """
        self._species = species
        self._post_connections = []  # populations this one connects to
        self._pre_connections = []  # populations connecting to this one
        self._synapsetype = synapsetype

        # numpy record array with information about each cell in the
        # population. np.zeros initialises the object fields to 0, which is
        # how virtual (not yet instantiated) cells are recognised.
        fields = [
            ("id", int),
            ("cell", object),
            ("input_resolved", bool),
            ("connections", object),  # {pop: [cells], ...}
        ] + fields
        self._cells = np.zeros(size, dtype=fields)
        self._cells["id"] = np.arange(size)
        self._cell_indexes = {}  # maps cell:index
        self._cell_args = kwds

    @property
    def cells(self):
        """ The array of cells in this population (a copy).

        For all populations, this array has a 'cell' field that is either 0
        (for virtual cells) or a Cell instance (for real cells).

        Extra fields may be added by each Population subclass.
        """
        return self._cells.copy()

    @property
    def species(self):
        """ The species this population was created with. """
        return self._species

    def unresolved_cells(self):
        """ Return indexes of all real cells whose inputs have not been
        resolved.
        """
        real = self._cells["cell"] != 0
        unresolved = ~self._cells["input_resolved"]
        return np.flatnonzero(real & unresolved)

    def real_cells(self):
        """ Return indexes of all real cells in this population.

        Initially, all cells in the population are virtual--they are accounted
        for, but not actually instantiated as part of the NEURON simulation.
        Virtual cells can be made real by calling `get_cell()`. This method
        returns the indexes of all cells for which `get_cell()` has already
        been invoked.
        """
        return np.flatnonzero(self._cells["cell"] != 0)

    def connect(self, *pops):
        """ Connect this population to any number of other populations.

        A connection is unidirectional; calling ``pop1.connect(pop2)`` can only
        result in projections from pop1 to pop2.

        Note that the connection is purely symbolic at first; no cells are
        actually connected by synapses at this time.
        """
        self._post_connections.extend(pops)
        for pop in pops:
            pop._pre_connections.append(self)

    @property
    def pre_connections(self):
        """ A copy of the list of populations connected to this one.
        """
        return self._pre_connections[:]

    def cell_connections(self, index):
        """ Return a dictionary containing, for each population, a list of
        cells connected to the cell in this population at *index*.

        The dictionary returned is the one stored in the cell record (not a
        copy). Until the cell's inputs are resolved the field still holds its
        initial value 0.
        """
        return self._cells[index]["connections"]

    def resolve_inputs(self, depth=1, showlog=False):
        """ For each _real_ cell in the population, select a set of
        presynaptic partners from each connected population and generate a
        synapse from each.

        Although it is allowed to call ``resolve_inputs`` multiple times for
        a single population, each individual cell will only resolve its inputs
        once. Therefore, it is recommended to create and connect all
        populations before making any calls to ``resolve_inputs``.

        Parameters
        ----------
        depth : int
            If greater than 1, presynaptic populations are resolved
            recursively with ``depth - 1``.
        showlog : bool
            Emit progress messages through ``logging.info``.
        """
        # loop over all cells whose presynaptic inputs have not been resolved
        for i in self.unresolved_cells():
            if showlog:
                logging.info("Resolving inputs for %s %d", self, i)
            connections = {}
            self._cells[i]["connections"] = connections

            # select cells from each population to connect to this cell
            for pop in self._pre_connections:
                pre_cells = self.connect_pop_to_cell(pop, i)
                # Check before logging so a None result is reported by this
                # assertion rather than as a TypeError from len().
                assert pre_cells is not None
                if showlog:
                    logging.info(" connected %d cells from %s", len(pre_cells), pop)
                connections[pop] = pre_cells
            self._cells[i]["input_resolved"] = True

        # recursively resolve inputs in connected populations
        if depth > 1:
            for pop in self.pre_connections:
                pop.resolve_inputs(depth - 1, showlog=showlog)

    def connect_pop_to_cell(self, pop, cell_index):
        """ Connect cells in a presynaptic population to the cell in this
        population at *cell_index*, and return the presynaptic indexes of cells
        that were connected.

        This method is responsible for choosing pairs of cells to be connected
        by synapses, and may be overridden in subclasses.

        The default implementation calls `self.connection_stats()` to determine
        the number and selection criteria of presynaptic cells, then
        instantiates each selected presynaptic cell (``pop.get_cell``) and
        connects it to this cell using this population's synapse type.
        """
        cell_rec = self._cells[cell_index]
        cell = cell_rec["cell"]
        size, dist = self.connection_stats(pop, cell_rec)
        # Select SGCs from distribution, create, and connect to this cell
        # TODO: select sgcs with similar spont. rate?
        pre_cells = pop.select(size=size, create=False, **dist)
        for j in pre_cells:
            pre_cell = pop.get_cell(j)
            # use default settings for connecting these.
            pre_cell.connect(cell, type=self._synapsetype)
        return pre_cells

    def connection_stats(self, pop, cell_rec):
        """ The population *pop* is being connected to the cell described in
        *cell_rec*.

        This method is responsible for deciding the distributions of presynaptic
        cell properties for any given postsynaptic cell (for example, a cell
        with cf=10kHz might receive SGC input from 10 cells selected from a
        normal distribution centered at 10kHz).

        The default implementation of this method uses the 'convergence' and
        'convergence_range' values from the data tables to specify a lognormal
        distribution of presynaptic cells around the postsynaptic cell's CF.
        It reads ``cell_rec["cf"]`` and the ``type`` attribute of both
        populations.

        This method must return a tuple (size, dist) with the following values:

        * size: integer giving the number of cells that should be selected from
          the presynaptic population and connected to the postsynaptic cell.
        * dist: dictionary of {property_name: distribution} pairs that describe
          how cells should be selected from the presynaptic population. See
          keyword arguments to `select()` for more information on the content
          of this dictionary.

        Raises
        ------
        TypeError
            If ``data.get`` raises KeyError for the 'convergence' or
            'convergence_range' entry of this pair of cell types.
        """
        cf = cell_rec["cf"]

        # Convergence distributions (how many presynaptic cells to connect).
        # A tuple is treated as (mean, sd) of a normal distribution that is
        # sampled once and clipped at zero; any other value is used directly.
        try:
            n_connections = data.get(
                "convergence",
                species=self.species,
                pre_type=pop.type,
                post_type=self.type,
            )
        except KeyError:
            raise TypeError(
                "Cannot connect population %s to %s; no convergence specified in data table."
                % (pop, self)
            )

        if isinstance(n_connections, tuple):
            size_dist = scipy.stats.norm(loc=n_connections[0], scale=n_connections[1])
            size = max(0, size_dist.rvs())
        else:
            size = n_connections
        size = int(size)  # must be an integer at this point

        # Convergence ranges -- over what range of CFs should we
        # select presynaptic cells.
        try:
            input_range = data.get(
                "convergence_range",
                species=self.species,
                pre_type=pop.type,
                post_type=self.type,
            )
        except KeyError:
            raise TypeError(
                "Cannot connect population %s to %s; no convergence range specified in data table."
                % (pop, self)
            )

        # Lognormal with shape parameter input_range; scale=cf puts its
        # median at the postsynaptic cell's CF.
        dist = {"cf": scipy.stats.lognorm(input_range, scale=cf)}
        return size, dist

    def get_sgcsr_array(self, freqs, species="mouse"):
        """
        Create an array matched to *freqs* (one entry per SGC) in which each
        entry indicates the SR group, drawn at random with fixed proportions:
        2 = high (~53%), 1 = medium (~24%), 0 = low (the remainder).

        Parameters
        ----------
        freqs : numpy array
            Only its shape and dtype are used; the random draw is sized by
            ``freqs.shape[0]``.
        species : str (default: 'mouse')
            name of the species for the map. Only 'mouse' is supported.

        Returns
        -------
        numpy array
            An array matched to freqs, with SR's indicated numerically
        """
        assert species == "mouse"  # only mice so far.
        # one uniform random draw per fiber, independent of frequency
        nhs = np.random.random_sample(freqs.shape[0])
        sr_array = np.zeros_like(freqs)  # build array - initially all low sponts
        # high spont (53% estimated from Taberner and Liberman, 2005)
        sr_array[nhs < 0.53] = 2
        # medium spont, about 24% (1-20 sp/sec)
        sr_array[(nhs >= 0.53) & (nhs < 0.77)] = 1
        # the rest have SR value of 0, corresponding to the low-spont group
        return sr_array

    def select_sgcsr_inputs(self, sr_array, weights):
        """
        Subsample the arrays above to create a distribution for cells that might only get
        a subset of inputs (for example, only msr and lsr fibers)

        Parameters
        ----------
        sr_array : numpy array
            the SR array to draw the samples from (values 0, 1, 2 as
            produced by `get_sgcsr_array`)
        weights : 3 element sequence (list, tuple or array)
            Weights for [lsr, msr, hsr] ANFs. Proportions will be computed
            from these weights (e.g., [1,1,1] is uniform for all fibers)
            weights of [1,1,0] means all hsr fibers will be masked

        Returns
        -------
        numpy array of "dist"
            Same shape and dtype as *sr_array*. Values of 0 are sgcs masked
            from input, 1 are ok

        Raises
        ------
        ValueError
            If the weights sum to zero.
        """
        assert len(weights) == 3

        # asarray so plain lists work (list / scalar raises TypeError)
        weights = np.asarray(weights, dtype=float)
        total = weights.sum()
        if total == 0:
            raise ValueError("weights must not sum to zero")
        # fraction from within each group; uniform weights give 1.0 each
        norm_wt = 3.0 * (weights / total)

        dist = np.zeros_like(sr_array)  # same dtype as sr_array; 0 = masked
        for group, wt in enumerate(norm_wt):
            dx = np.where(sr_array == group)[0]
            if len(dx) == 0:
                continue  # no fibers in this SR group; nothing to sample
            # draws are with replacement, so the unmasked fraction is
            # approximate
            ind = np.random.choice(len(dx), int(wt * len(dx)))
            dist[dx[ind]] = 1
        return dist

    def _get_cf_array(self, species):
        """Return the array of CF values that should be used when instantiating
        this population.

        Commonly used by subclasses during initialization. ``n_cells`` CFs are
        log-spaced starting at ``cf_min``; the last one is ``cf_max`` divided
        by the spacing ratio, so ``cf_max`` itself is not included. CFs above
        40 kHz are then dropped, so the result may be shorter than ``n_cells``.
        """
        size = data.get(
            "populations", species=species, cell_type=self.type, field="n_cells"
        )
        fmin = data.get(
            "populations", species=species, cell_type=self.type, field="cf_min"
        )
        fmax = data.get(
            "populations", species=species, cell_type=self.type, field="cf_max"
        )
        s = (fmax / fmin) ** (1.0 / size)
        freqs = fmin * s ** np.arange(size)
        # Cut off at 40kHz because the auditory nerve model only goes that far :(
        freqs = freqs[freqs <= 40e3]
        return freqs

    def select(self, size, create=False, **kwds):
        """ Return a list of indexes for cells matching the selection criteria.

        The *size* argument specifies the number of cells to return; fewer
        are returned if not enough eligible cells exist.

        If *create* is True, then any selected cells that are virtual will be
        instantiated.

        Each keyword argument must be the name of a field in self.cells. Values
        may be:

        * A distribution (see scipy.stats), in which case the distribution
          influences the selection of cells
        * An array giving the probability to assign to each cell in the
          population
        * A number, in which case the cell(s) with the closest match
          are returned. If this is used, it overrides all other criteria except
          where they evaluate to 0.

        If multiple distributions are provided, then the product of the
        per-criterion probabilities determines the probability of selecting
        each cell.

        Raises
        ------
        TypeError
            If no criteria are given, or a criterion has an unsupported type.
        ValueError
            If more than one single-valued criterion is given.
        """
        if len(kwds) == 0:
            raise TypeError("Must specify at least one selection criteria")

        full_dist = np.ones(len(self._cells))
        nearest = None
        nearest_field = None
        for field, dist in kwds.items():
            if np.isscalar(dist):
                if nearest is not None:
                    raise ValueError(
                        "May not specify multiple single-valued selection criteria."
                    )
                nearest = dist
                nearest_field = field
            elif isinstance(dist, np.ndarray):
                full_dist *= dist
            elif isinstance(dist, scipy.stats.distributions.rv_frozen):
                vals = self._cells[field]
                # Multiply the pdf by the spacing between neighbouring field
                # values to approximate each cell's probability mass (assumes
                # the field values are sorted ascending — TODO confirm).
                dens = np.diff(vals)
                dens = np.concatenate([dens[:1], dens])
                pdf = dist.pdf(vals) * dens
                full_dist *= pdf / pdf.sum()
            else:
                raise TypeError("Distributed criteria must be array or rv_frozen.")

        cells = []
        if nearest is not None:
            # Select cells nearest to the requested value, but only pick from
            # cells with nonzero probability.
            mask = full_dist == 0
            err = np.abs(self._cells[nearest_field] - nearest).astype(float)
            for _ in range(size):
                if mask.all():
                    # Every eligible cell has been chosen; stop rather than
                    # letting argmin fall back to a masked cell.
                    break
                err[mask] = np.inf
                cell = np.argmin(err)
                mask[cell] = True
                cells.append(cell)
        else:
            # Select cells randomly from the specified combined probability
            # distribution (inverse-CDF lookup of sorted uniform samples).
            full_dist /= full_dist.sum()
            vals = np.random.uniform(size=size)
            vals.sort()
            cumulative = np.cumsum(full_dist)
            for val in vals:
                u = np.argwhere(cumulative >= val)
                if len(u) > 0:
                    cells.append(u[0, 0])

        if create:
            self.create_cells(cells)

        return cells

    def get_cell(self, i, create=True):
        """ Return the cell at index i. If the cell is virtual, then it will
        be instantiated first unless *create* is False (in which case 0 is
        returned for a virtual cell).
        """
        if create and self._cells[i]["cell"] == 0:
            self.create_cells([i])
        return self._cells[i]["cell"]

    def get_cell_index(self, cell):
        """Return the index of *cell*.

        Raises KeyError if *cell* was not instantiated by this population.
        """
        return self._cell_indexes[cell]

    def create_cells(self, cell_inds):
        """ Instantiate each cell in *cell_inds*, which is an iterable of
        indexes into self.cells. Cells that are already real are skipped.
        """
        for i in cell_inds:
            if self._cells[i]["cell"] != 0:
                continue
            cell = self.create_cell(self._cells[i])
            self._cells[i]["cell"] = cell
            self._cell_indexes[cell] = i

    def create_cell(self, cell_rec):
        """ Return a single new cell to be used in this population. The
        *cell_rec* argument is the row from self.cells that describes the cell
        to be created.

        Subclasses must reimplement this method.
        """
        raise NotImplementedError()

    def __str__(self):
        return "<Population %s (%d/%d real)>" % (
            type(self).__name__,
            (self._cells["cell"] != 0).sum(),
            len(self._cells),
        )

    def __getstate__(self):
        """Return a picklable copy of self.__dict__.

        References to the actual cells are removed in order to allow pickling:
        in a copy of the record array each real cell is replaced by its
        ``type`` attribute, and the cell -> index map (whose keys are the cell
        objects themselves) is emptied. The live population is not modified.
        """
        state = self.__dict__.copy()
        cells = state["_cells"].copy()
        for rec in cells:
            if rec["cell"] != 0:
                # replace neuron object with just the cell type
                rec["cell"] = rec["cell"].type
        state["_cells"] = cells
        # _cell_indexes is keyed by the cell objects, so keeping it would
        # still drag them into the pickle.
        state["_cell_indexes"] = {}
        return state
|
|
|