from __future__ import print_function __author__ = "pbmanis" """ ManageANSpikes is a class to read the output of the Zilany et al. 2009 AN model into python, and provides services to access that data. Basic usage is to create an instance of the class, and specify the data directory if necessary. You may then get the data in the format of a list using one of the "get" routines. The data is pulled from the nearest CF or the specified CF. """ import os import re import numpy as np from collections import OrderedDict import scipy.io import matplotlib.pyplot as MP class ManageANSpikes: """ ManageANSpikes is a class to read the output of the Zilany et al. 2009 AN model into python, and provides services to access that data. Basic usage is to create an instance of the class, and specify the data directory if necessary. You may then get the data in the format of a list using one of the "get" routines. The data is pulled from the nearest CF or the specified CF. """ def __init__(self): # self.datadir = environment['HOME'] + '/Desktop/Matlab/ZilanyCarney-JASAcode-2009/' self.data_dir = os.environ["HOME"] + "/Desktop/Matlab/ANData/" self.data_read_flag = False self.dataType = "RI" self.set_CF_map(4000, 38000, 25) # the default list self.all_AN = None def get_data_dir(self): return self.data_dir def set_data_dir(self, directory): if os.path.isdir(directory): self.data_dir = directory else: raise ValueError("ManageANSpikes.set_data_dir: Path %d is not a directory") def get_data_read(self): return self.data_read_flag def get_CF_map(self): return self.CF_map def set_CF_map(self, low, high, nfreq): self.CF_map = np.round(np.logspace(np.log10(low), np.log10(high), nfreq)) def get_dataType(self): return self.dataType def set_dataType(self, datatype): if datatype in ["RI", "PL"]: # only for recognized data types self.dataType = datatype else: raise ValueError( "get_anspikes.set_dataType: unrecognized type %s " % datatype ) def plot_RI_vs_F(self, freq=10000.0, spontclass="HS", display=False): cfd = {} if display: MP.figure(10) for i, fr in enumerate(self.CF_map): self.read_AN_data(freq=freq, CF=fr, spontclass=spontclass, ignoreflag=True) nsp = np.zeros(len(self.SPLs)) # this is the same regardless for i, db in enumerate(self.SPLs): spkl = self.combine_reps(self.spikelist[i]) nsp[i] = len(spkl) # /self.SPLs[i] if display: MP.plot(self.SPLs, nsp) def plot_Rate_vs_F(self, freq=10000.0, spontclass="HS", display=False, SPL=0.0): """ assumes we have done a "read all" :param freq: :param spontclass: :param display: :param SPL: :return: """ cfd = {} if display: MP.figure(10) for i, db in enumerate(self.SPLs): # plot will have lines at each SPL # retrieve_from_all_AN(self, cf, SPL) # self.read_AN_data(freq=freq, CF=fr, spontclass=spontclass) nsp = np.zeros(len(self.CF_map)) # this is the same regardless for j, fr in enumerate(self.CF_map): spikelist = self.retrieve_from_all_AN(fr, db) spkl = self.combine_reps(spikelist) nsp[j] = len(spkl) / len(spikelist) if display: MP.plot(self.CF_map, nsp) def getANatFandSPL(self, spontclass="MS", freq=10000.0, CF=None, SPL=0): """ Get the AN data at a particular frequency and SPL. Note that the tone freq is specified, but the CF of the fiber might be different. If CF is None, we try to get the closest fiber data to the tone Freq. Other arguments are the spontaneous rate and the SPL level. The data must exist, or we epically fail. :param spontclass: 'HS', 'MS', or 'LS' for high, middle or low spont groups (1,2,3 in Zilany et al model) :param freq: The stimulus frequency, in Hz :param CF: The 'characteristic frequency' of the desired AN fiber, in Hz :param SPL: The sound pressure level, in dB SPL for the stimulus :return: an array of nReps of spike times. """ if CF is None: closest = np.argmin( np.abs(self.CF_map - freq) ) # find closest to the stim freq else: closest = np.argmin( np.abs(self.CF_map - CF) ) # find closest to the stim freq CF = self.CF_map[closest] print("closest f: ", CF) self.read_AN_data(spontclass=spontclass, freq=10000.0, CF=CF) return self.get_AN_at_SPL(SPL) def get_AN_at_SPL(self, spl=None): """ grabs the AN data for the requested SPL for the data set currently loaded The data must exist, or we epically fail. :param spl: sound pressure level, in dB SPL :return: spike trains (times) as a list of nReps numpy arrays. """ if not self.data_read_flag: print("getANatSPL: No data read yet") return None try: k = int(np.where(self.SPLs == spl)[0]) except (ValueError, TypeError) as e: print( "get_anspikes::getANatSPL: spl=%6.1f not in list of spls:" % (spl), self.SPLs, ) exit() # no match # now clean up array to keep it iterable upon return spkl = [[]] * len(self.spikelist[k]) for i in xrange(len(self.spikelist[k])): try: len(self.spikelist[k][i]) spkl[i] = self.spikelist[k][i] except: spkl[i] = [np.array(self.spikelist[k][i])] return spkl # returns all the trials in the list def read_all_ANdata( self, freq=10000.0, CFList=None, spontclass="HS", stim="BFTone" ): """ Reads a bank of AN data, across frequency and intensity, for a given stimulus frequency Assumptions: the bank of data is consistent in terms of nReps and SPLs :param freq: Tone stimulus frequency (if stim is Tone or BFTone) :param CFList: A list of the CFs to read :param spontclass: 'HS', 'MS', or 'LS', corresponding to spont rate groups :param stim: Stimulus type - Tone/BFTone, or Noise :return: Nothing. The data are stored in an array accessible directly or through a selector function """ self.all_AN = OrderedDict( [(f, None) for f in CFList] ) # access through dictionary with keys as freq for cf in CFList: self.read_AN_data( freq=freq, CF=cf, spontclass=spontclass, stim=stim, setflag=False ) self.all_AN[cf] = self.spikelist self.data_read_flag = True # inform and block def retrieve_from_all_AN(self, cf, SPL): """ :param cf: :param SPL: :return: spike list (all nreps) at the cf and SPL requested, for the loaded stimulus set """ if not self.data_read_flag or self.all_AN is None: print("get_anspikes::retrieve_from_all_AN: No data read yet: ") exit() ispl = int(np.where(self.SPLs == float(SPL))[0]) icf = self.all_AN.keys().index(cf) spikelist = self.all_AN[cf][ispl] spkl = [[]] * len(spikelist) # print len(spikelist) for i in xrange(len(spikelist)): # across trials try: len(spikelist[i]) spkl[i] = spikelist[i] except: spkl[i] = [np.array(spikelist[i])] # print 'cf: %f spl: %d nspk: %f' % (cf, SPL, len(spkl[i])) return spikelist def read_AN_data( self, freq=10000, CF=5300, spontclass="HS", display=False, stim="BFTone", setflag=True, ignoreflag=True, ): """ read responses of auditory nerve model of Zilany et al. (2009). display = True plots the psth's for all ANF channels in the dataset Request response to stimulus at Freq, for fiber with CF and specified spont rate This version is for rate-intensity runs March 2014. Returns: Nothing """ if not ignoreflag: assert self.data_read_flag == False # print 'Each instance of ManageANSPikes is allowed ONE data set to manage' # print 'This is a design decision to avoid confusion about which data is in the instance' # exit() if stim in ["BFTone", "Tone"]: fname = "%s_F%06.3f_CF%06.3f_%2s.mat" % ( self.dataType, freq / 1000.0, CF / 1000.0, spontclass, ) elif stim == "Noise": fname = "%s_Noise_CF%06.3f_%2s.mat" % ( self.dataType, CF / 1000.0, spontclass, ) # print 'Reading: %s' % (fname) try: mfile = scipy.io.loadmat( os.path.join(self.data_dir, fname), squeeze_me=True ) except IOError: print("get_anspikes::read_AN_data: Failed to find data file %s" % (fname)) print( "Corresponding to Freq: %f CF: %f spontaneous rate class: %s" % (freq, CF, spontclass) ) exit() n_spl = len(mfile[self.dataType]) if display: n = int(np.sqrt(n_spl)) + 1 mg = np.meshgrid(range(n), range(n)) mg = [mg[0].flatten(), mg[1].flatten()] spkl = [[]] * n_spl spl = np.zeros(n_spl) for k in range(n_spl): spkl[k] = mfile[self.dataType]["data"][k] # get data for one SPL spl[k] = mfile[self.dataType]["SPL"][k] # get SPL for this set of runs # print spkl[k].shape if display: self.display(spkl[k], k, n, mg) if display: MP.show() self.spikelist = spkl # save these so we can have a single point to parse them self.SPLs = spl self.n_reps = mfile[self.dataType]["nrep"] if setflag: self.data_read_flag = True # return spkl, spl, mfile['RI']['nrep'] def getANatSPL(self, spl=None): if not self.dataRead: print("getANatSPL: No data read yet") return None try: k = int(np.where(self.SPLs == spl)[0]) except (ValueError, TypeError) as e: print( "get_anspikes::getANatSPL: spl=%6.1f not in list of spls:" % (spl), self.SPLs, ) exit() # no match # now clean up array to keep it iterable upon return spkl = [[]] * len(self.spikelist[k]) for i in xrange(len(self.spikelist[k])): try: len(self.spikelist[k][i]) spkl[i] = self.spikelist[k][i] except: spkl[i] = [np.array(self.spikelist[k][i])] return spkl # returns all the trials in the list def get_AN_info(self): """ :return: Dictionary of nReps and SPLs that are in the current data set (instance). """ if not self.data_read_flag: print("getANatSPL: No data read yet") return None return {"nReps": self.n_reps, "SPLs": self.SPLs} def read_cmrr_data(self, P, signal="S0", display=False): """ read responses of auditory nerve model of Zilany et al. (2009). The required parameters are passed via the class P. This includes: the SN (s2m) is the signal-to-masker ratio to select the mode is 'CMR', 'CMD' (deviant) or 'REF' (reference, no flanking bands), 'Tone', or 'Noise' the Spont rate group signal is either 'S0' (signal present) or 'NS' (no signal) display = True plots the psth's for all ANF channels in the dataset Returns: tuple of (spikelist and frequency list for modulated tones) """ if P.modF == 10: datablock = "%s_F4000.0" % (P.mode) else: datablock = "%s_F4000.0_M%06.1f" % ( P.mode, P.modF, ) # 100 Hz modulation directory # print P.mode fs = os.listdir(self.datadir + datablock) s_sn = "SN%03d" % (P.s2m) fntemplate = "(\S+)_%s_%s_%s_%s.mat" % (s_sn, P.mode, signal, P.SR) p = re.compile(fntemplate) fl = [re.match(p, file) for file in fs] fl = [f.group(0) for f in fl if f != None] # returns list of files matching... fr = [float(f[1:7]) for f in fl] # frequency list if display: self.makeFig() i = 0 j = 0 spkl = [[]] * len(fl) n = int(np.sqrt(len(fl))) mg = np.meshgrid(range(n), range(n)) mg = [mg[0].flatten(), mg[1].flatten()] for k, fi in enumerate(fl): mfile = scipy.io.loadmat( self.datadir + datablock + "/" + fi, squeeze_me=True ) spkl[k] = mfile["Cpsth"] if display: self.display(spkl[k], k, n, mg) if display: MP.show() return (spkl, fr) def make_fig(self): self.fig = MP.figure(1) def combine_reps(self, spkl): """ Just turns the spike list into one big linear sequence. :param spkl: a spike train (nreps of numpy arrays) :return: all the spikes in one long array. """ # print spkl allsp = np.array(self.flatten(spkl)) # print allsp.shape if allsp.shape == (): # nothing to show return # print allsp spks = [] for i in range(len(allsp)): if allsp[i] == (): continue if isinstance(allsp[i], float): spks.append(allsp[i]) else: spks.extend(np.array(allsp[i])) return spks def display(self, spkl, k, n, mg): # print spkl spks = self.combine_reps(spkl) if spks != []: spks = np.sort(spks) MP.hist(spks, 100) return def flatten(self, x): """ flatten(sequence) -> list Returns a single, flat list which contains all elements retrieved from the sequence and all recursively contained sub-sequences (iterables). """ result = [] if isinstance(x, float) or len([x]) <= 1: return x for el in x: if hasattr(el, "__iter__") and not isinstance(el, basestring): result.extend(self.flatten(el)) else: result.append(el) return result if __name__ == "__main__": from cnmodel.util import Params modes = ["CMR", "CMD", "REF"] sr_types = ["H", "M", "L"] sr_type = sr_types[0] sr_names = { "H": "HS", "M": "MS", "L": "LS", } # spontaneous rate groups (AN fiber input selection) # define params for reading/testing. This is a subset of params... P = Params( mode=modes[0], s2m=0, n_rep=0, start_rep=0, sr=sr_names["L"], modulation_freq=10, DS_phase=0, dataset="test", sr_types=sr_types, sr_names=sr_names, fileVersion=1.5, ) manager = ManageANSpikes() # create instance of the manager print(dir(manager)) test = "RI" if test == "all": manager.read_all_ANdata( freq=10000.0, CFList=manager.CF_map, spontclass="HS", stim="BFTone" ) # spkl = manager.retrieve_from_all_AN(manager.CF_map[0], 40.) # spks = manager.combine_reps(spkl) manager.plot_Rate_vs_F(freq=manager.CF_map[0], display=True, spontclass="HS") MP.show() if test == manager.dataType: spikes = manager.getANatFandSPL(spontclass="MS", freq=10000.0, CF=None, SPL=50) spks = manager.combine_reps(spikes) if spks != []: print("spikes.") manager.plot_RI_vs_F(freq=10000.0, display=True, spontclass="MS") MP.figure(11) spks = np.sort(spks) MP.hist(spks, 100) MP.show()