1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
import sys |
4
|
|
|
import time |
5
|
|
|
import numpy as np |
6
|
|
|
import theano |
7
|
|
|
|
8
|
|
|
from ..conf import TrainerConfig |
9
|
|
|
from ..core import env, runtime |
10
|
|
|
from ..utils import Timer |
11
|
|
|
from ..dataset import Dataset |
12
|
|
|
from controllers import TrainingController |
13
|
|
|
|
14
|
|
|
from abc import ABCMeta, abstractmethod |
15
|
|
|
|
16
|
|
|
from logging import getLogger |
17
|
|
|
logging = getLogger("trainer") |
18
|
|
|
|
19
|
|
|
class NeuralTrainer(object):
    """
    A base class for all trainers.

    Subclasses implement :meth:`learn`. This class drives the
    test / validate / train loop around it, keeps the best validation cost
    together with a snapshot of the matching parameters, and invokes the
    registered iteration and epoch controllers.
    """
    # NOTE: `__metaclass__` is only honoured by Python 2. On Python 3 the
    # attribute is ignored, so `learn` is not enforced as abstract there.
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None, validator=None, annealer=None):
        """
        Basic neural network trainer.

        :param network: the network to train; ``prepare_training()`` is called on it
        :type network: deepy.NeuralNetwork
        :param config: a ``TrainerConfig``, a dict used to build one, or anything
            else (e.g. ``None``) to fall back to ``TrainerConfig()``
        :type config: deepy.conf.TrainerConfig
        :param validator: optional controller, bound to this trainer and
            registered as an iteration controller
        :param annealer: optional controller, bound to this trainer and
            registered as an epoch controller
        """
        super(NeuralTrainer, self).__init__()

        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # A plain Python float is converted to an array of the configured float type.
        if isinstance(self.config.learning_rate, float):
            self.config.learning_rate = np.array(self.config.learning_rate, dtype=env.FLOATX)
        # Model and network all refer to the computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        # Compiled lazily by `_compile_evaluation_func`.
        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience

        self._iter_controllers = []
        self._epoch_controllers = []
        if annealer:
            annealer.bind(self)
            self._epoch_controllers.append(annealer)
        if validator:
            validator.bind(self)
            self._iter_controllers.append(validator)

        self.best_cost = 1e100
        self.best_epoch = 0
        self.best_params = self.copy_params()
        # `train` rolls back to `self.checkpoint` when a NaN cost shows up, so it
        # must exist before the first `save_checkpoint`. Both snapshots are only
        # ever read or replaced as a whole, so sharing one object is safe.
        self.checkpoint = self.best_params
        self._skip_batches = 0
        self._skip_epochs = 0
        self._progress = 0
        self.last_cost = 0
        self.last_run_costs = None
        self._report_time = True
        self._epoch = 0

        self._current_train_set = None
        self._current_valid_set = None
        self._current_test_set = None
        self._ended = False

    def _compile_evaluation_func(self):
        """
        Compile the theano evaluation function on first use.

        Does nothing when ``self.evaluation_func`` is already set.
        """
        if not self.evaluation_func:
            logging.info("compile evaluation function")
            self.evaluation_func = theano.function(
                self.network.input_variables + self.network.target_variables,
                self.evaluation_variables,
                updates=self.network.updates,
                allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def skip(self, n_batches, n_epochs=0):
        """
        Skip N batches in the training.

        :param n_batches: number of batches `train_step` will pass over
        :param n_epochs: number of epochs `train` will pass over
        """
        logging.info("skip %d epochs and %d batches" % (n_epochs, n_batches))
        self._skip_batches = n_batches
        self._skip_epochs = n_epochs

    def epoch(self):
        """
        Get current epoch.
        """
        return self._epoch

    def _setup_costs(self):
        """
        Build the regularized costs and the monitor lists.

        Sets ``cost`` / ``test_cost`` and the parallel
        ``training_names`` / ``training_variables`` and
        ``evaluation_names`` / ``evaluation_variables`` lists, each starting
        with the cost under the name ``'J'``.
        """
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)

        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        """
        Add the regularization terms enabled in the config to ``cost``.

        A term is added only when its coefficient (``weight_l1``,
        ``hidden_l1``, ``hidden_l2``) is greater than zero.

        :return: the cost including the enabled terms
        """
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        """
        Write values back into the network parameters.

        :param targets: values for ``network.parameters``, in the same order
        :param free_params: optional values for ``network.free_parameters``
        """
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        """
        Restore the best parameters into the network, then save them to ``path``.
        """
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # A NaN rollback must return to the loaded values, not to the ones
        # captured at construction time.
        self.checkpoint = self.best_params
        # Resume the progress
        train_logger = self.network.train_logger
        if train_logger.progress() > 0 or train_logger.epoch() > 0:
            # When only batch progress was recorded the epoch is 0; never ask
            # for a negative number of epochs.
            self.skip(train_logger.progress(), max(train_logger.epoch() - 1, 0))

    def copy_params(self):
        """
        Snapshot the current parameter values.

        :return: a ``(parameters, free_parameters)`` tuple of lists holding
            copies of the values
        """
        # Real lists, not `map`: on Python 3 `map` is lazy and single-use, which
        # would postpone the copy and exhaust the snapshot after one restore.
        return ([p.get_value().copy() for p in self.network.parameters],
                [p.get_value().copy() for p in self.network.free_parameters])

    def exit(self):
        """
        End the training.
        """
        self._ended = True

    def add_iter_controllers(self, *controllers):
        """
        Add iteration callbacks function (receives an argument of the trainer).
        :param controllers: can be a `TrainingController` or a function.
        :type controllers: list of TrainingController
        """
        for controller in controllers:
            if isinstance(controller, TrainingController):
                controller.bind(self)
            self._iter_controllers.append(controller)

    def add_epoch_controllers(self, *controllers):
        """
        Add epoch callbacks function.
        :param controllers: can be a `TrainingController` or a function.
        """
        for controller in controllers:
            if isinstance(controller, TrainingController):
                controller.bind(self)
            self._epoch_controllers.append(controller)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and return costs.

        This is a generator: it yields ``None`` for every skipped epoch and a
        ``dict`` of the training costs after every trained epoch. The loop ends
        on ``KeyboardInterrupt`` or when ``_run_valid`` reports that patience
        has elapsed.
        """
        self._epoch = 0
        while True:
            if self._skip_epochs > 0:
                logging.info("skipping one epoch ...")
                self._skip_epochs -= 1
                self._epoch += 1
                yield None
                continue
            # Test
            if not self._epoch % self.config.test_frequency and test_set:
                try:
                    self._run_test(self._epoch, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not self._epoch % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(self._epoch, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(self._epoch, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                # The epoch counter is not advanced, so this epoch is retried.
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                self._epoch += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost with given data points.
        :param variables:
        :return:
        """

    def _run_test(self, epoch, test_set):
        """
        Run on test epoch.
        """
        costs = self.test_step(test_set)
        self.report(dict(costs), "test", epoch)
        self.last_run_costs = costs

    def _run_train(self, epoch, train_set, train_size=None):
        """
        Run one training iteration.

        :return: list of ``(name, value)`` pairs from `train_step`
        """
        self.network.train_logger.record_epoch(epoch + 1)
        costs = self.train_step(train_set, train_size)
        if not epoch % self.config.monitor_frequency:
            self.report(dict(costs), "train", epoch)
        self.last_run_costs = costs
        return costs

    def _run_valid(self, epoch, valid_set, dry_run=False, save_path=None):
        """
        Run one valid iteration, return true if to continue training.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min improvement
        _, J = costs[0]
        new_best = False
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            new_best = True
            if not dry_run:
                self.best_cost = J
                self.best_epoch = epoch
            # NOTE(review): checkpointing is assumed to happen on every
            # improvement, dry run or not - confirm the intended nesting.
            self.save_checkpoint(save_path)

        self.report(dict(costs), type="valid", epoch=0 if dry_run else epoch, new_best=new_best)
        self.last_run_costs = costs
        return epoch - self.best_epoch < self.patience

    def save_checkpoint(self, save_path=None):
        """
        Snapshot the parameters into ``self.checkpoint`` and optionally save them.

        :param save_path: target path; falls back to ``config.auto_save``.
            Nothing is written while batches are still being skipped.
        """
        save_path = save_path if save_path else self.config.auto_save
        self.checkpoint = self.copy_params()
        if save_path and self._skip_batches == 0:
            self.network.train_logger.record_progress(self._progress)
            self.network.save_params(save_path, new_thread=True)

    def report(self, score_map, type="valid", epoch=-1, new_best=False):
        """
        Report the scores and record them in the log.

        :param score_map: mapping of score name to numeric value
        :param type: label of the run; padded with spaces to five characters
        :param epoch: epoch to report; a negative value is reported as ``dryrun``
        :param new_best: append a ``*`` marker to the message
        """
        type_str = type.ljust(5)
        info = " ".join("%s=%.2f" % el for el in score_map.items())
        current_epoch = epoch if epoch > 0 else self.current_epoch()
        epoch_str = "epoch={}".format(current_epoch + 1)
        if epoch < 0:
            epoch_str = "dryrun"
            # NOTE(review): the carriage return is assumed to belong to the
            # dry-run branch only - confirm the intended nesting.
            sys.stdout.write("\r")
            sys.stdout.flush()
        marker = " *" if new_best else ""
        message = "{} ({}) {}{}".format(type_str, epoch_str, info, marker)
        self.network.train_logger.record(message)
        logging.info(message)

    def _evaluate(self, data_set):
        """
        Average every evaluation output over the batches of ``data_set``.

        :return: list of ``(name, mean value)`` pairs following ``evaluation_names``
        """
        runtime.switch_training(False)
        self._compile_evaluation_func()
        return list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in data_set], axis=0)))

    def test_step(self, test_set):
        """
        Evaluate the network on the test set; see `_evaluate`.
        """
        return self._evaluate(test_set)

    def valid_step(self, valid_set):
        """
        Evaluate the network on the validation set; see `_evaluate`.
        """
        return self._evaluate(valid_set)

    def _learn_in_halves(self, x, cost_matrix):
        """
        Learn on the two halves of batch ``x`` one after the other.

        The cost of the first half is appended to ``cost_matrix``; the cost of
        the second half is returned for the caller to record.
        """
        # Floor division keeps the split point an int on Python 2 and 3 alike.
        first_cost = self.learn(*[t[:t.shape[0] // 2] for t in x])
        cost_matrix.append(first_cost)
        return self.learn(*[t[t.shape[0] // 2:] for t in x])

    def train_step(self, train_set, train_size=None):
        """
        Run `learn` over every batch of ``train_set``.

        After a ``MemoryError`` the failing batch and the next 30 batches are
        learned in two halves. Iteration controllers and the network training
        callback are invoked after each learned batch. When ``train_size`` is
        given, a progress line is written to stdout.

        :return: list of ``(name, mean value)`` pairs following ``training_names``
        """
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_controllers)
        cost_matrix = []
        exec_count = 0
        start_time = time.time()
        self._compile_time = 0
        self._progress = 0

        for x in train_set:
            runtime.switch_training(True)
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    cost_x = self._learn_in_halves(x, cost_matrix)
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick
                        cost_x = self._learn_in_halves(x, cost_matrix)
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                exec_count += 1
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_controllers:
                        if isinstance(func, TrainingController):
                            func.invoke()
                        else:
                            func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                # A coarse clock can report zero elapsed time for the first batches.
                elapsed = time.time() - start_time - self._compile_time
                spd = float(exec_count) / elapsed if elapsed > 0 else 0.0
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f | spd=%.2f batch/s" % (self._progress * 100 / train_size, self.last_cost, spd))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def current_epoch(self):
        """
        Get current epoch.
        """
        return self._epoch

    def get_data(self, data_split="train"):
        """
        Get specified split of data.

        :param data_split: ``'train'``, ``'valid'`` or ``'test'``
        :return: the split handed to the latest `run`, or ``None`` for any other name
        """
        if data_split == 'train':
            return self._current_train_set
        elif data_split == 'valid':
            return self._current_valid_set
        elif data_split == 'test':
            return self._current_test_set
        else:
            return None

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, epoch_controllers=None):
        """
        Run until the end.

        :param train_set: training batches, or a ``Dataset`` from which all
            splits and the train size are taken
        :param epoch_controllers: deprecated
        """
        # Build a new list so the caller's list is not extended in place.
        controllers = list(epoch_controllers or []) + self._epoch_controllers
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()
        self._current_train_set = train_set
        self._current_valid_set = valid_set
        self._current_test_set = test_set
        # `add_epoch_controllers` also accepts plain functions, which have
        # neither `bind` nor `invoke`; those are called with the trainer.
        for controller in controllers:
            if hasattr(controller, "bind"):
                controller.bind(self)
        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            for controller in controllers:
                if hasattr(controller, "invoke"):
                    controller.invoke()
                else:
                    controller(self)
            if self._ended:
                break
        if self._report_time:
            timer.report()
431
|
|
|
|