Skip to content

Commit

Permalink
remove models_mlp importing
Browse files Browse the repository at this point in the history
  • Loading branch information
keunwoochoi committed Dec 31, 2017
2 parents d8db150 + 03c86a8 commit 797edb5
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 153 deletions.
324 changes: 176 additions & 148 deletions Example 1 - a pitch detection network with Dense layers.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Example_5-1.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def main(model_name, exp_name='fma'):
model = models_time_varying.model_convrnn(n_out=2)
elif model_name == 'lstm':
model = models_time_varying.model_lstm_leglaive_icassp2014(n_out=2, bidirectional=False)
elif model_name == 'lstm_bi'
elif model_name == 'lstm_bi':
model = models_time_varying.model_lstm_leglaive_icassp2014(n_out=2, bidirectional=True)

model.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])
Expand Down
8 changes: 4 additions & 4 deletions models_time_varying.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def model_convrnn(n_out, input_shape=(1, None), out_activation='softmax'):
"""
assert input_shape[0] == 1, 'Mono input please!'
model = Sequential()
n_mels = 64
n_mels = 40
model.add(Melspectrogram(sr=SR, n_mels=n_mels, power_melgram=2.0,
return_decibel_melgram=True,
input_shape=input_shape))
Expand All @@ -57,7 +57,7 @@ def model_convrnn(n_out, input_shape=(1, None), out_activation='softmax'):
model.add(BatchNormalization(axis=channel_axis))
model.add(Activation('relu'))

if K.image_dim_ordering() == 'channels_first': # (ch, freq, time)
if K.image_data_format() == 'channels_first': # (ch, freq, time)
model.add(Permute((3, 2, 1))) # (time, freq, ch)
else: # (freq, time, ch)
model.add(Permute((2, 1, 3))) # (time, ch, freq)
Expand Down Expand Up @@ -98,13 +98,13 @@ def model_lstm_leglaive_icassp2014(n_out, input_shape=(1, None),
return_decibel_melgram=True,
input_shape=input_shape))

model.add(BatchNormalization(axis=channel_axis))

if K.image_data_format() == 'channels_first':
model.add(Permute((3, 2, 1))) # ch, freq, time -> time, freq, ch
else:
model.add(Permute((2, 1, 3))) # freq, time, ch -> time, freq, ch

model.add(BatchNormalization(axis=channel_axis))

# Reshape for LSTM
model.add(Lambda(lambda x: K.squeeze(x, axis=3),
output_shape=squeeze_output_shape))
Expand Down

0 comments on commit 797edb5

Please sign in to comment.