from shared import *
from hrtf_analysis import *
from models import *
import gc


class ApproximateFilteringModel(object):
    '''
    Initialise this object with an hrtfset, a cochlear frequency range
    (cfmin, cfmax, cfN) and, optionally:
    a model for the coincidence detector neurons (cd_model),
    a model for the filter neurons (filtergroup_model),
    whether or not to use the correlation-based best delays (use_delays),
    whether or not to use the best gains (use_gains),
    whether or not to use only the phase information, i.e. delays with
    phases constrained to [-pi, pi] (use_only_phase),
    an alternative set of ITD/ILD pairs (itdild; see the function
    hrtfset_itd_ild in hrtf_analysis.py for more information).
    The __call__ method returns a spike count (see the docstring of that
    method).
    '''
    def __init__(self, hrtfset, cfmin, cfmax, cfN,
                 cd_model=standard_cd_model,
                 filtergroup_model=standard_filtergroup_model,
                 use_delays=True, use_gains=True, use_only_phase=False,
                 itdild=None,
                 ):
        self.hrtfset = hrtfset
        self.cfmin, self.cfmax, self.cfN = cfmin, cfmax, cfN
        self.cd_model = cd_model
        self.filtergroup_model = filtergroup_model
        self.num_indices = num_indices = hrtfset.num_indices
        cf = erbspace(cfmin, cfmax, cfN)
        # extract ITDs/ILDs in the right form
        if itdild is None:
            all_itds, all_ilds = hrtfset_itd_ild(hrtfset, cfmin, cfmax, cfN)
        else:
            all_itds, all_ilds = itdild
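        # flatten so that element i*num_indices+j is the value for CF i, HRTF location j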
        d = array([all_itds[j][i] for i in xrange(cfN) for j in xrange(num_indices)])
        g = array([all_ilds[j][i] for i in xrange(cfN) for j in xrange(num_indices)])
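        # convert the gain pairs (1/g, g) to dB and shift each pair so that the
        # louder channel has 0 dB gain, i.e. the gains are purely attenuating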
        gains = hstack((1/g, g))
        gains_dB = 20*log10(gains)
        abs_gains_dB = abs(gains_dB)
        r = -abs_gains_dB[:len(gains)/2]
        r = hstack((r, r))
        gains_dB += r
        gains = 10**(gains_dB/20)
        gains = reshape(gains, (1, len(gains)))
        if use_only_phase:
            d_cf = repeat(cf, num_indices)
            # delays constrained to having their phase in [-pi, pi]
            d = imag(log(exp(1j*2*pi*d*d_cf)))/(2*pi*d_cf)
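        # split each ITD into a pair of non-negative delays: positive ITDs delay
        # the right-channel input, negative ITDs delay the left-channel input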
        delays_L = where(d>=0, zeros(len(d)), -d)
        delays_R = where(d>=0, d, zeros(len(d)))
        delay_max = max(amax(delays_L), amax(delays_R))*second
        if not use_gains:
            gains = ones(len(gains))
        if not use_delays:
            delays_L = delays_R = zeros(len(d))
            delay_max = 2/samplerate
        # dummy sound, replaced when the model is applied via __call__
        sound = Sound((silence(1*ms), silence(1*ms)))
        soundinput = DoNothingFilterbank(sound)
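        # cochlear processing chain: gammatone filterbank at the CFs for each of
        # the two channels, repeated over the num_indices HRTF locations, scaled
        # by the per-location gains, then half-wave rectified and compressed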
        gfb = Gammatone(Repeat(soundinput, cfN), hstack((cf, cf)))
        gains_fb = FunctionFilterbank(Repeat(gfb, num_indices),
                                      lambda x: x*gains)
        compress = filtergroup_model['compress']
        cochlea = FunctionFilterbank(gains_fb, lambda x: compress(clip(x, 0, Inf)))
        # create the filterbank group
        eqs = Equations(filtergroup_model['eqs'], **filtergroup_model['parameters'])
        G = FilterbankGroup(cochlea, 'target_var', eqs,
                            threshold=filtergroup_model['threshold'],
                            reset=filtergroup_model['reset'],
                            refractory=filtergroup_model['refractory'])
        # create the synchrony group
        cd_eqs = Equations(cd_model['eqs'], **cd_model['parameters'])
        cd = NeuronGroup(num_indices*cfN, cd_eqs,
                         threshold=cd_model['threshold'],
                         reset=cd_model['reset'],
                         refractory=cd_model['refractory'],
                         clock=G.clock)
        # set up the synaptic connectivity
        cd_weight = cd_model['weight']
        C = Connection(G, cd, 'target_var', delay=True, max_delay=delay_max)
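        # coincidence detector i receives the left-channel filter neuron i and the
        # right-channel filter neuron i+num_indices*cfN, each with its best delay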
        for i in xrange(num_indices*cfN):
            C[i, i] = cd_weight
            C[i+num_indices*cfN, i] = cd_weight
            C.delay[i, i] = delays_L[i]
            C.delay[i+cfN*num_indices, i] = delays_R[i]
        self.soundinput = soundinput
        self.filtergroup = G
        self.synchronygroup = cd
        self.synapses = C
        self.counter = SpikeCounter(cd)
        self.network = Network(G, cd, C, self.counter)
    def __call__(self, sound, index=None, **indexkwds):
        '''
        Apply the approximate filtering model to the given sound, which should
        be a stereo sound unless you specify the HRTF index (or the coordinates
        of the HRTF index as keyword arguments), in which case it should be a
        mono sound and the given HRTF will be applied to it. You can also pass
        an HRTF object directly as index. Returns the spike count of the
        neurons in the synchrony group, with shape (cfN, num_indices).
        '''
        hrtf = None
        if isinstance(index, HRTF):
            # an HRTF object was passed directly
            hrtf = index
        elif index is not None:
            hrtf = self.hrtfset[index]
        elif len(indexkwds):
            hrtf = self.hrtfset(**indexkwds)
        if hrtf is not None:
            sound = hrtf(sound)
        self.soundinput.source = sound
        self.network.reinit()
        self.filtergroup_model['init'](self.filtergroup,
                                       self.filtergroup_model['parameters'])
        self.cd_model['init'](self.synchronygroup, self.cd_model['parameters'])
        self.network.run(sound.duration, report='stderr')
        count = reshape(self.counter.count, (self.cfN, self.num_indices))
        return count
if __name__ == '__main__':
    from plot_count import ircam_plot_count
    hrtfdb = get_ircam()
    subject = 1002
    hrtfset = hrtfdb.load_subject(subject)
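    # pick a random HRTF location, run the model on 500 ms of white noise and
    # plot the resulting count map over locations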
    index = randint(hrtfset.num_indices)
    cfmin, cfmax, cfN = 150*Hz, 5*kHz, 80
    sound = whitenoise(500*ms)
    afmodel = ApproximateFilteringModel(hrtfset, cfmin, cfmax, cfN)
    count = afmodel(sound, index)
    ircam_plot_count(hrtfset, count, index=index)
    show()