train.py
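"""Training entry point for a Listen, Attend and Spell (LAS) model.

The script reads speech data in TFRecord format (optionally in the
Tensor2Tensor ASR problem layout), builds its input pipeline and model through
the TensorFlow Estimator API, and can run on CPU/GPU or TPU.
"""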
import argparse
import tensorflow as tf
import multiprocessing
import os
from tensor2tensor.data_generators.text_encoder import PAD_ID, EOS_ID
import utils
from model_helper import las_model_fn


def parse_args():
    parser = argparse.ArgumentParser(
        description='Listen, Attend and Spell (LAS) implementation based on TensorFlow. '
                    'The model utilizes the input pipeline and Estimator API of TensorFlow, '
                    'which makes the training procedure truly end-to-end.')
    parser.add_argument('--train', type=str, required=True,
                        help='training data in TFRecord format')
    parser.add_argument('--valid', type=str,
                        help='validation data in TFRecord format')
    parser.add_argument('--t2t_format', action='store_true',
                        help='use a dataset in the format of Tensor2Tensor ASR problems; '
                             '--train should then be a directory')
    parser.add_argument('--t2t_problem_name', type=str,
                        help='problem name for data in T2T format')
    parser.add_argument('--mapping', type=str,
                        help='additional mapping applied during evaluation')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='path for saving the model')
    parser.add_argument('--eval_secs', type=int, default=300,
                        help='evaluate every N seconds; only used when --valid is specified')
    parser.add_argument('--encoder_units', type=int, default=128,
                        help='RNN hidden units of the encoder')
    parser.add_argument('--encoder_layers', type=int, default=3,
                        help='RNN layers of the encoder')
    parser.add_argument('--use_pyramidal', action='store_true',
                        help='whether to use a pyramidal RNN')
    parser.add_argument('--unidirectional', action='store_true',
                        help='use a unidirectional RNN')
    parser.add_argument('--decoder_units', type=int, default=128,
                        help='RNN hidden units of the decoder')
    parser.add_argument('--decoder_layers', type=int, default=2,
                        help='RNN layers of the decoder')
    parser.add_argument('--embedding_size', type=int, default=0,
                        help='embedding size of the target vocabulary; if 0, one-hot encoding is applied')
    parser.add_argument('--sampling_probability', type=float, default=0.1,
                        help='sampling probability of the decoder during training')
    parser.add_argument('--attention_type', type=str, default='luong',
                        choices=['luong', 'bahdanau', 'custom', 'luong_monotonic', 'bahdanau_monotonic'],
                        help='type of attention mechanism')
    parser.add_argument('--attention_layer_size', type=int,
                        help='size of the attention layer; see tensorflow.contrib.seq2seq.AttentionWrapper '
                             'for more details')
    parser.add_argument('--bottom_only', action='store_true',
                        help='apply the attention mechanism only at the bottommost RNN cell')
    parser.add_argument('--pass_hidden_state', action='store_true',
                        help='whether to pass the encoder state to the decoder')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='batch size')
    parser.add_argument('--num_parallel_calls', type=int, default=multiprocessing.cpu_count(),
                        help='number of elements to process in parallel during the dataset transformation')
    parser.add_argument('--num_channels', type=int,
                        help='number of input channels')
    parser.add_argument('--num_epochs', type=int, default=150,
                        help='number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--dropout', type=float, default=0.2,
                        help='dropout rate of the RNN cells')
    parser.add_argument('--l2_reg_scale', type=float, default=1e-6,
                        help='L2 regularization scale')
    parser.add_argument('--add_noise', type=int, default=0,
                        help='how often (in steps) to add Gaussian noise to the weights; 0 disables noise addition')
    parser.add_argument('--noise_std', type=float, default=0.1,
                        help='weight noise standard deviation')
    parser.add_argument('--binary_outputs', action='store_true',
                        help='make the projection layer output binary feature posteriors instead of phone posteriors')
    parser.add_argument('--output_ipa', action='store_true',
                        help='with --binary_outputs, make the graph output phones and '
                             'change the sampling algorithm during training')
    parser.add_argument('--binf_map', type=str, default='misc/binf_map.csv',
                        help='path to a CSV with the phoneme-to-binary-features map')
    parser.add_argument('--ctc_weight', type=float, default=-1.,
                        help='if positive, adds a CTC multitask target based on the encoder')
    parser.add_argument('--reset', action='store_true',
                        help='reset HParams')
    parser.add_argument('--binf_sampling', action='store_true',
                        help='with --output_ipa, do not use the IPA sampling algorithm for training, only for validation')
    parser.add_argument('--binf_projection', action='store_true',
                        help="with --binary_outputs and --output_ipa, use the binary features mapping "
                             "instead of the decoder's projection layer")
    parser.add_argument('--binf_projection_reg_weight', type=float, default=1.0,
                        help='with --binf_projection, weight of the regularization term for binary features log probabilities')
    parser.add_argument('--binf_trainable', action='store_true',
                        help='make the binary features matrix trainable')
    parser.add_argument('--multitask', action='store_true',
                        help='with --binary_outputs, use both binary features and IPA decoders')
    parser.add_argument('--tpu_name', type=str, default='',
                        help='TPU name; leave blank to disable TPU training')
    parser.add_argument('--max_frames', type=int, default=-1,
                        help='if positive, sets that many frames for each batch')
    parser.add_argument('--max_symbols', type=int, default=-1,
                        help='if positive, sets that many symbols for each batch')
    parser.add_argument('--tpu_checkpoints_interval', type=int, default=600,
                        help='interval for saving checkpoints on TPU, in steps')
    parser.add_argument('--t2t_features_hparams_override', type=str, default='',
                        help='string with overridden parameters used by the Tensor2Tensor problem')
    return parser.parse_args()
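
# Example invocation (a sketch only; the TFRecord paths, the model directory
# and the hyperparameter values below are hypothetical):
#
#   python train.py \
#       --train data/train.tfrecord --valid data/valid.tfrecord \
#       --model_dir exp/las_baseline --num_channels 39 \
#       --encoder_units 128 --encoder_layers 3 --use_pyramidal \
#       --decoder_units 128 --decoder_layers 2 \
#       --attention_type luong --batch_size 8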


def main(args):
    train_dir = os.path.dirname(args.train)
    vocab_name = os.path.join(train_dir, 'vocab.txt')
    norm_name = os.path.join(train_dir, 'norm.dmp')

    vocab_list = utils.load_vocab(vocab_name)
    binf2phone_np = None
    mapping = None
    vocab_size = len(vocab_list)
    binf_count = None
    if args.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, vocab_name)
            args.mapping = None
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf_count = len(binf2phone.index)
        if args.output_ipa:
            binf2phone_np = binf2phone.values

    if args.tpu_name:
        iterations_per_loop = 100
        tpu_cluster_resolver = None
        if args.tpu_name != 'fake':
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(args.tpu_name)
        config = tf.estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=args.model_dir,
            save_checkpoints_steps=max(args.tpu_checkpoints_interval, iterations_per_loop),
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=iterations_per_loop,
                per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2))
    else:
        config = tf.estimator.RunConfig(model_dir=args.model_dir)

    hparams = utils.create_hparams(
        args, vocab_size, binf_count,
        utils.SOS_ID if not args.t2t_format else PAD_ID,
        utils.EOS_ID if not args.t2t_format else EOS_ID)
    if mapping is not None:
        hparams.del_hparam('mapping')
        hparams.add_hparam('mapping', mapping)

    def model_fn(features, labels, mode, config, params):
        binf_map = binf2phone_np
        return las_model_fn(features, labels, mode, config, params,
                            binf2phone=binf_map)

    if args.tpu_name:
        model = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn, config=config, params=hparams, eval_on_tpu=False,
            train_batch_size=args.batch_size, use_tpu=args.tpu_name != 'fake')
    else:
        model = tf.estimator.Estimator(
            model_fn=model_fn,
            config=config,
            params=hparams)
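
    # Note on batch size: when training on a real TPU, TPUEstimator passes the
    # per-shard batch size to the input functions through `params`, so the
    # closures below prefer `params.batch_size` when it is present and fall
    # back to the --batch_size flag otherwise.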
    def create_input_fn(mode):
        if args.t2t_format:
            return lambda params: utils.input_fn_t2t(
                args.train, mode, hparams, args.t2t_problem_name,
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs if mode == tf.estimator.ModeKeys.TRAIN else 1,
                num_parallel_calls=64 if args.tpu_name and args.tpu_name != 'fake' else args.num_parallel_calls,
                max_frames=args.max_frames, max_symbols=args.max_symbols,
                features_hparams_override=args.t2t_features_hparams_override)
        else:
            return lambda params: utils.input_fn(
                args.valid if mode == tf.estimator.ModeKeys.EVAL and args.valid else args.train,
                vocab_name, norm_name,
                num_channels=args.num_channels if args.num_channels is not None else hparams.get_hparam('num_channels'),
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs if mode == tf.estimator.ModeKeys.TRAIN else 1,
                num_parallel_calls=64 if args.tpu_name and args.tpu_name != 'fake' else args.num_parallel_calls,
                max_frames=args.max_frames, max_symbols=args.max_symbols)

    if args.valid or args.t2t_format:
        train_spec = tf.estimator.TrainSpec(
            input_fn=create_input_fn(tf.estimator.ModeKeys.TRAIN),
            max_steps=args.num_epochs * 1000 * args.batch_size)
        eval_spec = tf.estimator.EvalSpec(
            input_fn=create_input_fn(tf.estimator.ModeKeys.EVAL),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)
        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        tf.logging.warning('Training without evaluation!')
        model.train(
            input_fn=create_input_fn(tf.estimator.ModeKeys.TRAIN),
            steps=args.num_epochs * 1000 * args.batch_size)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()
    main(args)