You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
478 lines
16 KiB
478 lines
16 KiB
2 years ago
|
"""
ManageANSpikes reads the output of the Zilany et al. (2009) auditory-nerve (AN)
model into Python and provides services to access those data.

Basic usage is to create an instance of the class and, if necessary, specify the
data directory. You may then get the data as a list using one of the "get"
routines. The data are pulled from the nearest CF or from the specified CF.
"""

from __future__ import print_function

__author__ = "pbmanis"
|
||
|
import os
|
||
|
import re
|
||
|
import numpy as np
|
||
|
from collections import OrderedDict
|
||
|
import scipy.io
|
||
|
import matplotlib.pyplot as MP
|
||
|
|
||
|
|
||
|
class ManageANSpikes:
    """
    Read the output of the Zilany et al. (2009) auditory-nerve (AN) model
    (MATLAB ``.mat`` files) into Python and provide access to those data.

    Basic usage: create an instance, point it at the data directory with
    ``set_data_dir`` if the default is not right, read a data set with one of
    the ``read_*`` methods, then retrieve spike trains with one of the
    ``get*`` / ``retrieve*`` methods. Data are taken from the entry in
    ``CF_map`` nearest to the requested CF (or, when no CF is given, nearest
    to the tone frequency).
    """

    def __init__(self):
        # Default data location: $HOME/Desktop/Matlab/ANData/ (trailing
        # separator kept). expanduser() avoids a KeyError when HOME is unset.
        self.data_dir = os.path.join(
            os.path.expanduser("~"), "Desktop", "Matlab", "ANData", ""
        )
        self.data_read_flag = False
        self.dataType = "RI"
        self.set_CF_map(4000, 38000, 25)  # the default list
        self.all_AN = None
        # Filled in by read_AN_data(); defined here so the attributes always exist.
        self.spikelist = None
        self.SPLs = None
        self.n_reps = None

    def get_data_dir(self):
        """Return the directory searched for the model ``.mat`` files."""
        return self.data_dir

    def set_data_dir(self, directory):
        """
        Set the directory searched for the model ``.mat`` files.

        :param directory: path to an existing directory
        :raises ValueError: if *directory* is not an existing directory
        """
        if not os.path.isdir(directory):
            raise ValueError(
                "ManageANSpikes.set_data_dir: Path %s is not a directory" % (directory,)
            )
        self.data_dir = directory

    def get_data_read(self):
        """Return True once a data set has been read into this instance."""
        return self.data_read_flag

    def get_CF_map(self):
        """Return the array of characteristic frequencies (Hz) in use."""
        return self.CF_map

    def set_CF_map(self, low, high, nfreq):
        """
        Set CF_map to *nfreq* log-spaced frequencies from *low* to *high*
        (inclusive), rounded to whole numbers.
        """
        self.CF_map = np.round(np.logspace(np.log10(low), np.log10(high), nfreq))

    def get_dataType(self):
        """Return the data-type key ('RI' or 'PL') used to select files and records."""
        return self.dataType

    def set_dataType(self, datatype):
        """
        Select the data type ('RI' or 'PL').

        :raises ValueError: for any other value
        """
        if datatype in ["RI", "PL"]:  # only for recognized data types
            self.dataType = datatype
        else:
            raise ValueError(
                "get_anspikes.set_dataType: unrecognized type %s " % datatype
            )

    def plot_RI_vs_F(self, freq=10000.0, spontclass="HS", display=False):
        """
        For every CF in CF_map, read the data set for a tone at *freq* and
        count the spikes (summed over all repetitions) at each SPL.

        Note that this re-reads data for each CF, so the instance is left
        holding the data set of the last CF in CF_map.

        :param freq: tone frequency, in Hz
        :param spontclass: 'HS', 'MS' or 'LS'
        :param display: if True, plot spike count vs. SPL (one line per CF)
            into figure 10
        :return: OrderedDict mapping each CF to an array of spike counts,
            one entry per SPL (in the order of self.SPLs)
        """
        counts = OrderedDict()
        if display:
            MP.figure(10)
        for cf in self.CF_map:
            self.read_AN_data(freq=freq, CF=cf, spontclass=spontclass, ignoreflag=True)
            nsp = np.zeros(len(self.SPLs))
            for k in range(len(self.SPLs)):
                nsp[k] = len(self.combine_reps(self.spikelist[k]))
            counts[cf] = nsp
            if display:
                MP.plot(self.SPLs, nsp)
        return counts

    def plot_Rate_vs_F(self, freq=10000.0, spontclass="HS", display=False, SPL=0.0):
        """
        Compute the mean number of spikes per repetition as a function of CF,
        one curve per SPL, from the bank loaded by read_all_ANdata().

        Assumes read_all_ANdata() has been called. The data come from that
        bank, so *freq*, *spontclass* and *SPL* are not used here; they are
        kept for interface compatibility.

        :param display: if True, plot each curve into figure 10
        :return: OrderedDict mapping each SPL (from self.SPLs) to an array of
            spikes-per-repetition, one entry per CF in CF_map
        """
        rates = OrderedDict()
        if display:
            MP.figure(10)
        for db in self.SPLs:  # plot will have lines at each SPL
            nsp = np.zeros(len(self.CF_map))
            for j, cf in enumerate(self.CF_map):
                trials = self.retrieve_from_all_AN(cf, db)
                if len(trials) > 0:  # avoid dividing by zero when there are no reps
                    # float() keeps true division under Python 2 as well.
                    nsp[j] = float(len(self.combine_reps(trials))) / len(trials)
            rates[db] = nsp
            if display:
                MP.plot(self.CF_map, nsp)
        return rates

    def getANatFandSPL(self, spontclass="MS", freq=10000.0, CF=None, SPL=0):
        """
        Get the AN data for a tone at *freq* and a given SPL.

        The tone frequency is specified, but the CF of the fiber may differ.
        If *CF* is None, the fiber whose CF (in CF_map) is closest to the tone
        frequency is used. Otherwise the fiber closest to *CF* is used.

        :param spontclass: 'HS', 'MS', or 'LS' for high, middle or low spont
            groups (1, 2, 3 in the Zilany et al. model)
        :param freq: the stimulus frequency, in Hz
        :param CF: the 'characteristic frequency' of the desired AN fiber, in Hz
        :param SPL: the sound pressure level of the stimulus, in dB SPL
        :return: list of nReps spike-time trains
        :raises IOError: if the data file does not exist
        :raises ValueError: if *SPL* is not in the data set
        """
        target = freq if CF is None else CF
        closest = int(np.argmin(np.abs(np.asarray(self.CF_map) - target)))
        CF = self.CF_map[closest]
        print("closest f: ", CF)
        # Read the data for the requested stimulus frequency (was hard-coded
        # to 10 kHz, silently ignoring *freq*).
        self.read_AN_data(spontclass=spontclass, freq=freq, CF=CF)
        return self.get_AN_at_SPL(SPL)

    def _spl_index(self, spl, caller="getANatSPL"):
        """
        Return the index of *spl* in self.SPLs.

        :raises ValueError: if *spl* (including None) is not in self.SPLs
        """
        matches = np.flatnonzero(np.asarray(self.SPLs) == spl)
        if matches.size == 0:
            raise ValueError(
                "get_anspikes::%s: spl=%s not in list of spls: %s"
                % (caller, spl, self.SPLs)
            )
        return int(matches[0])

    @staticmethod
    def _clean_trials(trials):
        """
        Return *trials* as a list in which every repetition is iterable.

        A repetition without a length (e.g. a lone scalar spike time,
        presumably from loadmat(squeeze_me=True) squeezing a one-element
        array) is wrapped as ``[np.array(value)]``.
        """
        cleaned = []
        for trial in trials:
            try:
                len(trial)
            except TypeError:
                trial = [np.array(trial)]
            cleaned.append(trial)
        return cleaned

    def get_AN_at_SPL(self, spl=None):
        """
        Get the AN data for the requested SPL from the data set currently loaded.

        :param spl: sound pressure level, in dB SPL
        :return: spike trains (times) as a list of nReps entries, or None
            (with a message) if no data have been read yet
        :raises ValueError: if *spl* is not in the loaded data set
        """
        if not self.data_read_flag:
            print("getANatSPL: No data read yet")
            return None
        k = self._spl_index(spl, "getANatSPL")
        # Clean up so every trial is iterable on return.
        return self._clean_trials(self.spikelist[k])

    def read_all_ANdata(
        self, freq=10000.0, CFList=None, spontclass="HS", stim="BFTone"
    ):
        """
        Read a bank of AN data, across CF and intensity, for one stimulus.

        Assumption: the bank of data is consistent in terms of nReps and SPLs.

        :param freq: tone stimulus frequency (if stim is Tone or BFTone)
        :param CFList: the CFs to read; defaults to self.CF_map
        :param spontclass: 'HS', 'MS', or 'LS', corresponding to spont rate groups
        :param stim: stimulus type - Tone/BFTone, or Noise
        :return: nothing. The data are stored in self.all_AN (keyed by CF)
            and are read back with retrieve_from_all_AN().
        """
        if CFList is None:
            CFList = self.CF_map
        # Access through a dictionary keyed by CF, in the order requested.
        self.all_AN = OrderedDict((cf, None) for cf in CFList)
        for cf in CFList:
            self.read_AN_data(
                freq=freq, CF=cf, spontclass=spontclass, stim=stim, setflag=False
            )
            self.all_AN[cf] = self.spikelist
        self.data_read_flag = True  # inform and block

    def retrieve_from_all_AN(self, cf, SPL):
        """
        Get the spike trains at one CF and SPL from the bank loaded by
        read_all_ANdata().

        :param cf: a CF used as a key when the bank was read
        :param SPL: sound pressure level, in dB SPL
        :return: list of nReps spike trains, each guaranteed iterable
        :raises RuntimeError: if no bank has been read
        :raises ValueError: if *SPL* is not in the data set
        """
        if not self.data_read_flag or self.all_AN is None:
            raise RuntimeError("get_anspikes::retrieve_from_all_AN: No data read yet")
        ispl = self._spl_index(float(SPL), "retrieve_from_all_AN")
        return self._clean_trials(self.all_AN[cf][ispl])

    def read_AN_data(
        self,
        freq=10000,
        CF=5300,
        spontclass="HS",
        display=False,
        stim="BFTone",
        setflag=True,
        ignoreflag=True,
    ):
        """
        Read responses of the auditory nerve model of Zilany et al. (2009).

        Reads the response to a stimulus at *freq* for the fiber with *CF*
        and the given spont rate class. This version is for rate-intensity
        runs (March 2014). The file read from self.data_dir is
        ``<dataType>_F<freq kHz>_CF<CF kHz>_<spontclass>.mat`` for
        Tone/BFTone, or ``<dataType>_Noise_CF<CF kHz>_<spontclass>.mat`` for
        Noise.

        On return, self.spikelist holds one entry per SPL, self.SPLs the SPL
        values and self.n_reps the 'nrep' field of the record.

        :param display: True histograms the spikes for every SPL and shows the plot
        :param setflag: if True, mark the instance as holding data
        :param ignoreflag: if False, refuse to replace data already read
        :return: nothing
        :raises RuntimeError: if data were already read and *ignoreflag* is False
        :raises ValueError: for an unrecognized *stim*
        :raises IOError: if the data file cannot be opened
        """
        if not ignoreflag and self.data_read_flag:
            raise RuntimeError(
                "get_anspikes::read_AN_data: data already read; each instance of "
                "ManageANSpikes is allowed ONE data set to manage"
            )
        if stim in ["BFTone", "Tone"]:
            fname = "%s_F%06.3f_CF%06.3f_%2s.mat" % (
                self.dataType,
                freq / 1000.0,
                CF / 1000.0,
                spontclass,
            )
        elif stim == "Noise":
            fname = "%s_Noise_CF%06.3f_%2s.mat" % (
                self.dataType,
                CF / 1000.0,
                spontclass,
            )
        else:
            raise ValueError(
                "get_anspikes::read_AN_data: unrecognized stim %r" % (stim,)
            )

        try:
            mfile = scipy.io.loadmat(
                os.path.join(self.data_dir, fname), squeeze_me=True
            )
        except IOError:
            # Raise instead of exit(): exit() reported success to the shell.
            raise IOError(
                "get_anspikes::read_AN_data: Failed to find data file %s "
                "(Freq: %f CF: %f spontaneous rate class: %s)"
                % (fname, freq, CF, spontclass)
            )

        record = mfile[self.dataType]
        n_spl = len(record)
        if display:
            n = int(np.sqrt(n_spl)) + 1
            mg = np.meshgrid(range(n), range(n))
            mg = [mg[0].flatten(), mg[1].flatten()]
        spkl = [None] * n_spl
        spl = np.zeros(n_spl)
        for k in range(n_spl):
            spkl[k] = record["data"][k]  # get data for one SPL
            spl[k] = record["SPL"][k]  # get SPL for this set of runs
            if display:
                self.display(spkl[k], k, n, mg)

        if display:
            MP.show()
        self.spikelist = spkl  # single point from which the trains are parsed
        self.SPLs = spl
        self.n_reps = record["nrep"]
        if setflag:
            self.data_read_flag = True

    def getANatSPL(self, spl=None):
        """Backward-compatible alias for get_AN_at_SPL()."""
        return self.get_AN_at_SPL(spl)

    def get_AN_info(self):
        """
        :return: dictionary of nReps and SPLs in the current data set, or
            None (with a message) if no data have been read yet
        """
        if not self.data_read_flag:
            print("getANatSPL: No data read yet")
            return None
        return {"nReps": self.n_reps, "SPLs": self.SPLs}

    def read_cmrr_data(self, P, signal="S0", display=False):
        """
        Read CMR-experiment responses of the Zilany et al. (2009) AN model.

        The required parameters are passed via *P*, which must provide:
            P.s2m  -- the signal-to-masker ratio (SN) to select
            P.mode -- 'CMR', 'CMD' (deviant), 'REF' (reference, no flanking
                      bands), 'Tone', or 'Noise'
            P.SR   -- the spont rate group
            P.modF -- modulation frequency; 10 selects the unmodulated
                      directory name, anything else adds an ``_M`` suffix

        :param signal: 'S0' (signal present) or 'NS' (no signal)
        :param display: True histograms the spikes of every file and shows the plot
        :return: tuple (spike list, frequency list), one entry per matching
            file, in file-name order
        """
        if P.modF == 10:
            datablock = "%s_F4000.0" % (P.mode)
        else:
            datablock = "%s_F4000.0_M%06.1f" % (
                P.mode,
                P.modF,
            )  # 100 Hz modulation directory
        blockdir = os.path.join(self.data_dir, datablock)
        s_sn = "SN%03d" % (P.s2m)
        pattern = re.compile(r"(\S+)_%s_%s_%s_%s.mat" % (s_sn, P.mode, signal, P.SR))
        # Sorted so the returned lists do not depend on directory order.
        matches = (pattern.match(name) for name in sorted(os.listdir(blockdir)))
        fl = [m.group(0) for m in matches if m is not None]  # matching files
        fr = [float(f[1:7]) for f in fl]  # frequency list parsed from the names
        if display:
            self.make_fig()
            n = int(np.sqrt(len(fl)))
            mg = np.meshgrid(range(n), range(n))
            mg = [mg[0].flatten(), mg[1].flatten()]
        spkl = []
        for k, fname in enumerate(fl):
            mfile = scipy.io.loadmat(os.path.join(blockdir, fname), squeeze_me=True)
            spkl.append(mfile["Cpsth"])
            if display:
                self.display(spkl[k], k, n, mg)

        if display:
            MP.show()
        return (spkl, fr)

    def make_fig(self):
        """Create (or activate) matplotlib figure 1 and keep it as self.fig."""
        self.fig = MP.figure(1)

    def combine_reps(self, spkl):
        """
        Turn the spike trains of all repetitions into one linear sequence.

        :param spkl: a spike train set (nReps of arrays, lists, lone scalar
            spike times, or nested combinations of these); None is treated
            as empty
        :return: list of all the spike times (empty if there are none)
        """
        if spkl is None:
            return []
        return self.flatten(spkl)

    def display(self, spkl, k, n, mg):
        """
        Histogram (100 bins) all spike times of *spkl* into the current figure.

        *k*, *n* and *mg* (panel index, grid size, meshgrid) are accepted for
        interface compatibility but are not used.
        """
        spks = self.combine_reps(spkl)
        if spks:
            MP.hist(np.sort(spks), 100)

    def flatten(self, x):
        """
        flatten(sequence) -> list

        Return a single, flat list containing all elements of *x* and of all
        recursively contained sub-sequences (lists, tuples, numpy arrays of
        any dtype/rank). Strings and non-iterable scalars are leaves; a
        scalar argument yields a one-element list.
        """
        if isinstance(x, (str, bytes)):
            return [x]
        if isinstance(x, np.ndarray):
            if x.dtype != object:
                return list(x.ravel())  # numeric array: already flat once raveled
            items = x.ravel()  # object array: elements may themselves be arrays
        else:
            try:
                items = iter(x)
            except TypeError:
                return [x]  # scalar leaf
        result = []
        for el in items:
            result.extend(self.flatten(el))
        return result
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    from cnmodel.util import Params

    def _demo_all(mgr):
        """Read the full CF bank for a 10 kHz tone, then plot rate vs. CF."""
        mgr.read_all_ANdata(
            freq=10000.0, CFList=mgr.CF_map, spontclass="HS", stim="BFTone"
        )
        mgr.plot_Rate_vs_F(freq=mgr.CF_map[0], display=True, spontclass="HS")
        MP.show()

    def _demo_rate_intensity(mgr):
        """Fetch MS-fiber spikes at 50 dB SPL near 10 kHz; plot RI curves and a histogram."""
        trains = mgr.getANatFandSPL(spontclass="MS", freq=10000.0, CF=None, SPL=50)
        pooled = mgr.combine_reps(trains)
        if pooled != []:
            print("spikes.")
            mgr.plot_RI_vs_F(freq=10000.0, display=True, spontclass="MS")
            MP.figure(11)
            ordered = np.sort(pooled)
            MP.hist(ordered, 100)
            MP.show()

    modes = ["CMR", "CMD", "REF"]
    sr_types = ["H", "M", "L"]
    sr_type = sr_types[0]
    # Spontaneous-rate groups (AN fiber input selection).
    sr_names = dict(H="HS", M="MS", L="LS")

    # Parameters for reading/testing; only a subset of the full parameter set.
    P = Params(
        mode=modes[0],
        s2m=0,
        n_rep=0,
        start_rep=0,
        sr=sr_names["L"],
        modulation_freq=10,
        DS_phase=0,
        dataset="test",
        sr_types=sr_types,
        sr_names=sr_names,
        fileVersion=1.5,
    )

    manager = ManageANSpikes()  # the instance that manages one data set
    print(dir(manager))
    test = "RI"

    if test == "all":
        _demo_all(manager)

    if test == manager.dataType:
        _demo_rate_intensity(manager)
|