You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
444 lines
17 KiB
444 lines
17 KiB
2 years ago
|
import logging
|
||
|
import scipy.stats
|
||
|
import numpy as np
|
||
|
|
||
|
from .. import data
|
||
|
|
||
|
|
||
|
class Population(object):
    """
    A Population represents a group of cells all having the same type.

    Populations provide methods for:

    * Adding cells to the population with characteristic distributions.
    * Connecting the cells in one population to the cells in another.
    * Automatically adding cells to satisfy connectivity requirements when
      connecting populations together.

    Populations have a concept of a "natural" underlying distribution of
    neurons, and behave as if all neurons in this distribution already exist
    in the model. However, initially all neurons are virtual, and are only
    instantiated to become a part of the running model if the neuron provides
    synaptic input to another non-virtual neuron, or if the user explicitly
    requests a recording of the neuron.

    Subclasses represent populations for a specific cell type, and at least
    need to reimplement the `create_cell` and `connection_stats` methods.

    NOTE(review): `connection_stats` and `_get_cf_array` read ``self.type``
    (and ``pop.type``), which this base class never sets; presumably each
    subclass defines a ``type`` attribute -- confirm against the subclasses.
    """

    def __init__(self, species, size, fields, synapsetype="multisite", **kwds):
        """Create a population of *size* virtual cells.

        Parameters
        ----------
        species : object
            Stored and later passed as ``species=`` to the data-table lookups.
        size : int
            Number of cells (rows) in the population record array.
        fields : sequence of (name, dtype) tuples
            Extra per-cell fields appended to the built-in ones
            ('id', 'cell', 'input_resolved', 'connections').
        synapsetype : str
            Passed as ``type=`` when presynaptic cells are connected in
            `connect_pop_to_cell`.
        **kwds :
            Stored unchanged in ``self._cell_args``.
        """
        self._species = species
        self._post_connections = []  # populations this one connects to
        self._pre_connections = []  # populations connecting to this one
        self._synapsetype = synapsetype

        # numpy record array with information about each cell in the
        # population; list() lets callers pass any sequence of field specs.
        dtype = [
            ("id", int),
            ("cell", object),
            ("input_resolved", bool),
            ("connections", object),  # {pop: [cells], ...}
        ] + list(fields)
        self._cells = np.zeros(size, dtype=dtype)
        self._cells["id"] = np.arange(size)
        self._cell_indexes = {}  # maps cell:index
        self._cell_args = kwds

    @property
    def cells(self):
        """ The array of cells in this population.

        For all populations, this array has a 'cell' field that is either 0
        (for virtual cells) or a Cell instance (for real cells).

        Extra fields may be added by each Population subclass.

        A copy of the internal array is returned, so modifying the result
        does not alter the population's records.
        """
        return self._cells.copy()

    @property
    def species(self):
        """ The species value given at construction time."""
        return self._species

    def unresolved_cells(self):
        """ Return indexes of all real cells whose inputs have not been
        resolved.
        """
        is_real = self._cells["cell"] != 0
        is_resolved = self._cells["input_resolved"]
        return np.flatnonzero(is_real & ~is_resolved)

    def real_cells(self):
        """ Return indexes of all real cells in this population.

        Initially, all cells in the population are virtual--they are accounted
        for, but not actually instantiated as part of the NEURON simulation.
        Virtual cells can be made real by calling `get_cell()`. This method
        returns the indexes of all cells for which `get_cell()` has already
        been invoked.
        """
        return np.flatnonzero(self._cells["cell"] != 0)

    def connect(self, *pops):
        """ Connect this population to any number of other populations.

        A connection is unidirectional; calling ``pop1.connect(pop2)`` can only
        result in projections from pop1 to pop2.

        Note that the connection is purely symbolic at first; no cells are
        actually connected by synapses at this time.
        """
        self._post_connections.extend(pops)
        for pop in pops:
            pop._pre_connections.append(self)

    @property
    def pre_connections(self):
        """ The list of populations connected to this one (a shallow copy).
        """
        return self._pre_connections[:]

    def cell_connections(self, index):
        """ Return a dictionary containing, for each population, a list of
        cells connected to the cell in this population at *index*.

        The value is whatever is stored in the cell's 'connections' field; it
        is only set to a dictionary by `resolve_inputs`.
        """
        return self._cells[index]["connections"]

    def resolve_inputs(self, depth=1, showlog=False):
        """ For each _real_ cell in the population, select a set of
        presynaptic partners from each connected population and generate a
        synapse from each.

        Although it is allowed to call ``resolve_inputs`` multiple times for
        a single population, each individual cell will only resolve its inputs
        once. Therefore, it is recommended to create and connect all
        populations before making any calls to ``resolve_inputs``.

        Parameters
        ----------
        depth : int
            If greater than 1, inputs of the presynaptic populations are
            resolved recursively with ``depth - 1``.
        showlog : bool
            If True, progress is reported through ``logging.info``.
        """
        # loop over all cells whose presynaptic inputs have not been resolved
        for i in self.unresolved_cells():
            if showlog:
                logging.info("Resolving inputs for %s %d", self, i)

            # The dict is stored in the record first and then filled in place,
            # so partial results remain visible if a connection step raises.
            connections = {}
            self._cells["connections"][i] = connections

            # select cells from each population to connect to this cell
            for pop in self._pre_connections:
                pre_cells = self.connect_pop_to_cell(pop, i)
                # Check the invariant before len() is taken for logging.
                assert pre_cells is not None
                if showlog:
                    logging.info(" connected %d cells from %s", len(pre_cells), pop)
                connections[pop] = pre_cells
            self._cells["input_resolved"][i] = True

        # recursively resolve inputs in connected populations
        if depth > 1:
            for pop in self.pre_connections:
                pop.resolve_inputs(depth - 1, showlog=showlog)

    def connect_pop_to_cell(self, pop, cell_index):
        """ Connect cells in a presynaptic population to the cell in this
        population at *cell_index*, and return the presynaptic indexes of cells
        that were connected.

        This method is responsible for choosing pairs of cells to be connected
        by synapses, and may be overridden in subclasses.

        The default implementation calls `self.connection_stats()` to determine
        the number and selection criteria of presynaptic cells.
        """
        cell_rec = self._cells[cell_index]
        post_cell = cell_rec["cell"]
        size, dist = self.connection_stats(pop, cell_rec)

        # Select presynaptic cells from the distribution; they are
        # instantiated on demand by get_cell() below.
        # todo: select sgcs with similar spont. rate?
        pre_cells = pop.select(size=size, create=False, **dist)
        for j in pre_cells:
            pre_cell = pop.get_cell(j)
            # use default settings for connecting these.
            pre_cell.connect(post_cell, type=self._synapsetype)
        return pre_cells

    def connection_stats(self, pop, cell_rec):
        """ The population *pop* is being connected to the cell described in
        *cell_rec*.

        This method is responsible for deciding the distributions of presynaptic
        cell properties for any given postsynaptic cell (for example, a cell
        with cf=10kHz might receive SGC input from 10 cells selected from a
        normal distribution centered at 10kHz).

        The default implementation of this method uses the 'convergence' and
        'convergence_range' values from the data tables to specify a lognormal
        distribution of presynaptic cells around the postsynaptic cell's CF.

        This method must return a tuple (size, dist) with the following values:

        * size: integer giving the number of cells that should be selected from
          the presynaptic population and connected to the postsynaptic cell.
        * dist: dictionary of {property_name: distribution} pairs that describe
          how cells should be selected from the presynaptic population. See
          keyword arguments to `select()` for more information on the content
          of this dictionary.

        Raises
        ------
        TypeError
            If the data table has no 'convergence' or 'convergence_range'
            entry for this pre/post cell type pair.
        """
        cf = cell_rec["cf"]

        # Convergence distributions (how many presynaptic
        # cells to connect)
        try:
            n_connections = data.get(
                "convergence",
                species=self.species,
                pre_type=pop.type,
                post_type=self.type,
            )
        except KeyError as err:
            raise TypeError(
                "Cannot connect population %s to %s; no convergence specified in data table."
                % (pop, self)
            ) from err

        if isinstance(n_connections, tuple):
            # A tuple is used as (loc, scale) of a normal distribution; the
            # sampled count is clipped at zero.
            size_dist = scipy.stats.norm(loc=n_connections[0], scale=n_connections[1])
            size = max(0, size_dist.rvs())
        else:
            size = n_connections
        size = int(size)  # must be an integer at this point

        # Convergence ranges -- over what range of CFs should we
        # select presynaptic cells.
        try:
            input_range = data.get(
                "convergence_range",
                species=self.species,
                pre_type=pop.type,
                post_type=self.type,
            )
        except KeyError as err:
            raise TypeError(
                "Cannot connect population %s to %s; no convergence range specified in data table."
                % (pop, self)
            ) from err

        dist = {"cf": scipy.stats.lognorm(input_range, scale=cf)}
        return size, dist

    def get_sgcsr_array(self, freqs, species="mouse"):
        """
        Create an array with one entry per element of *freqs* (number of SGCs).
        Each entry is a value indicating the SR group, according to some
        proportion: 2 = high, 1 = medium, 0 = low

        Parameters
        ----------
        freqs : numpy array

        species : str (default: 'mouse')
            name of the species for the map. Only 'mouse' is supported.

        Returns
        -------
        numpy array
            An array matched to freqs, with SR's indicated numerically

        Raises
        ------
        ValueError
            If *species* is not 'mouse'.
        """
        if species != "mouse":  # only mice so far.
            raise ValueError("SR group proportions are only defined for 'mouse'")

        # uniform random distribution across frequency
        nhs = np.random.random_sample(freqs.shape[0])
        sr_array = np.zeros_like(freqs)  # build array - initially all low sponts
        # high spont (53% estimated from Taberner and Liberman, 2005)
        sr_array[nhs < 0.53] = 2
        # medium spont, about 24% (1-20 sp/sec)
        sr_array[(nhs >= 0.53) & (nhs < 0.77)] = 1
        # the rest have SR value of 0, corresponding to the low-spont group
        return sr_array

    def select_sgcsr_inputs(self, sr_array, weights):
        """
        Subsample an SR array to create a distribution for cells that might
        only get a subset of inputs (for example, only msr and lsr fibers)

        Parameters
        ----------
        sr_array : numpy array
            the SR array to draw the samples from

        weights : 3 element list
            Weights for [lsr, msr, hsr] ANFs. Proportions will be computed
            from these weights (e.g., [1,1,1] is uniform for all fibers)
            weights of [1,1,0] means all hsr fibers will be masked

        Returns
        -------
        numpy array of "dist"
            Values of 0 are sgcs masked from input, 1 are ok

        Raises
        ------
        ValueError
            If *weights* does not have 3 elements or sums to zero.
        """
        if len(weights) != 3:
            raise ValueError("weights must have exactly 3 elements: [lsr, msr, hsr]")
        weights = np.asarray(weights, dtype=float)
        total = weights.sum()
        if total == 0:
            raise ValueError("weights must not sum to zero")

        dist = np.zeros_like(sr_array)  # 0 = masked, 1 = available
        norm_wt = 3.0 * weights / total  # fraction from within each group
        for group, fraction in enumerate(norm_wt):
            dx = np.where(sr_array == group)[0]
            if len(dx) == 0:
                continue  # nothing to sample in this SR group
            # NOTE(review): sampling is with replacement, so a fraction of 1.0
            # does not guarantee that every fiber of the group is enabled --
            # confirm this is intended.
            ind = np.random.choice(len(dx), int(fraction * len(dx)))
            dist[dx[ind]] = 1
        return dist

    def _get_cf_array(self, species):
        """Return the array of CF values that should be used when instantiating
        this population.

        Values are spaced geometrically starting at 'cf_min' using the
        'n_cells', 'cf_min' and 'cf_max' entries of the 'populations' data
        table, then truncated to frequencies <= 40 kHz.

        Commonly used by subclasses during initialization.
        """
        size = data.get(
            "populations", species=species, cell_type=self.type, field="n_cells"
        )
        fmin = data.get(
            "populations", species=species, cell_type=self.type, field="cf_min"
        )
        fmax = data.get(
            "populations", species=species, cell_type=self.type, field="cf_max"
        )
        step = (fmax / fmin) ** (1.0 / size)
        freqs = fmin * step ** np.arange(size)

        # Cut off at 40kHz because the auditory nerve model only goes that far :(
        return freqs[freqs <= 40e3]

    def select(self, size, create=False, **kwds):
        """ Return a list of indexes for cells matching the selection criteria.

        The *size* argument specifies the number of cells to return. Fewer
        indexes may be returned when not enough eligible cells exist.

        If *create* is True, then any selected cells that are virtual will be
        instantiated.

        Each keyword argument must be the name of a field in self.cells. Values
        may be:

        * A distribution (see scipy.stats), in which case the distribution
          influences the selection of cells
        * An array giving the probability to assign to each cell in the
          population
        * A number, in which case the cell(s) with the closest match
          are returned. If this is used, it overrides all other criteria except
          where they evaluate to 0.

        If multiple distributions are provided, then the product of the
        normalized probability densities of all distributions determines the
        probability of selecting each cell.

        Raises
        ------
        TypeError
            If no criteria are given, or a criterion is neither a scalar, a
            frozen scipy distribution, nor a numpy array.
        Exception
            If more than one scalar criterion is given.
        """
        if not kwds:
            raise TypeError("Must specify at least one selection criteria")

        full_dist = np.ones(len(self._cells))
        nearest = None
        nearest_field = None
        for field, dist in kwds.items():
            if np.isscalar(dist):
                if nearest is not None:
                    raise Exception(
                        "May not specify multiple single-valued selection criteria."
                    )
                nearest = dist
                nearest_field = field
            elif isinstance(dist, scipy.stats.distributions.rv_frozen):
                vals = self._cells[field]
                # Weight the density by the spacing between consecutive field
                # values; the first spacing is reused for the first cell.
                dens = np.diff(vals)
                dens = np.concatenate([dens[:1], dens])
                pdf = dist.pdf(vals) * dens
                full_dist *= pdf / pdf.sum()
            elif isinstance(dist, np.ndarray):
                full_dist *= dist
            else:
                raise TypeError("Distributed criteria must be array or rv_frozen.")

        cells = []
        if nearest is not None:
            # Select cells nearest to the requested value, but only pick from
            # cells with nonzero probability.
            mask = full_dist == 0
            err = np.abs(self._cells[nearest_field] - nearest).astype(float)
            for _ in range(size):
                err[mask] = np.inf
                cell = np.argmin(err)
                if not np.isfinite(err[cell]):
                    # No eligible cell is left; argmin would otherwise hand
                    # back index 0 even though it is masked or already used.
                    break
                mask[cell] = True
                cells.append(cell)
        else:
            # Select cells randomly from the specified combined probability
            # distribution
            full_dist /= full_dist.sum()
            vals = np.random.uniform(size=size)
            vals.sort()
            cumulative = np.cumsum(full_dist)
            for val in vals:
                u = np.argwhere(cumulative >= val)
                # A draw above the last cumulative value (rounding) is dropped.
                if len(u) > 0:
                    cells.append(u[0, 0])

        if create:
            self.create_cells(cells)

        return cells

    def get_cell(self, i, create=True):
        """ Return the cell at index i. If the cell is virtual, then it will
        be instantiated first unless *create* is False (in which case 0 is
        returned for a virtual cell).
        """
        if create and self._cells["cell"][i] == 0:
            self.create_cells([i])
        return self._cells["cell"][i]

    def get_cell_index(self, cell):
        """Return the index of *cell*.

        Raises KeyError if *cell* was not created by this population.
        """
        return self._cell_indexes[cell]

    def create_cells(self, cell_inds):
        """ Instantiate each cell in *cell_inds*, which is a list of indexes into
        self.cells. Indexes whose cell already exists are skipped.
        """
        for i in cell_inds:
            if self._cells["cell"][i] != 0:
                continue
            cell = self.create_cell(self._cells[i])
            self._cells["cell"][i] = cell
            self._cell_indexes[cell] = i

    def create_cell(self, cell_rec):
        """ Return a single new cell to be used in this population. The
        *cell_rec* argument is the row from self.cells that describes the cell
        to be created.

        Subclasses must reimplement this method.
        """
        raise NotImplementedError()

    def __str__(self):
        return "<Population %s (%d/%d real)>" % (
            type(self).__name__,
            (self._cells["cell"] != 0).sum(),
            len(self._cells),
        )

    def __getstate__(self):
        """Return a picklable copy of self.__dict__.

        Note that we remove references to the actual cells in order to allow
        pickling: in a copy of the cell array, each real cell object is
        replaced by its ``type`` attribute. The live array is left untouched.
        """
        state = self.__dict__.copy()
        cells = state["_cells"].copy()
        state["_cells"] = cells

        cell_column = cells["cell"]
        for i in np.flatnonzero(cell_column != 0):
            # replace neuron object with just the cell type
            cell_column[i] = cell_column[i].type
        return state
|