#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import time
import numpy as np
import theano

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.utils import Timer

from abc import ABCMeta, abstractmethod

from logging import getLogger
logging = getLogger("trainer")

class NeuralTrainer(object):
    """
    A base class for all trainers.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(NeuralTrainer, self).__init__()

        self.config = None
        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # Both `model` and `network` refer to the same computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_epoch = 0
        self.best_params = self.copy_params()
        # Checkpoint used for NaN rollback; initialize it here so a NaN cost
        # before the first validation pass does not hit an undefined attribute.
        self.checkpoint = self.copy_params()
        self._skip_batches = 0
        self._skip_epochs = 0
        self._progress = 0
        self.last_cost = 0
        self.last_run_costs = None
        self._report_time = True
        self._epoch = 0

    def _compile_evaluation_func(self):
        if not self.evaluation_func:
            logging.info("compile evaluation function")
            self.evaluation_func = theano.function(
                self.network.input_variables + self.network.target_variables,
                self.evaluation_variables,
                updates=self.network.updates,
                allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def skip(self, n_batches, n_epochs=0):
        """
        Skip the given numbers of batches and epochs in the training.
        """
        logging.info("skip %d epochs and %d batches" % (n_epochs, n_batches))
        self._skip_batches = n_batches
        self._skip_epochs = n_epochs

    def epoch(self):
        """
        Get current epoch.
        """
        return self._epoch

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

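    # For reference, the regularized objective built above is, in math form:
    #
    #   J' = J + weight_l1 * sum_w |w|
    #          + hidden_l1 * sum_h mean_batch(|h|)
    #          + hidden_l2 * sum_h mean_batch(h^2)
    #
    # where w ranges over the network parameters and h over the hidden-layer
    # outputs; the lambdas come from the TrainerConfig fields named above.
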
    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0 or self.network.train_logger.epoch() > 0:
            self.skip(self.network.train_logger.progress(), self.network.train_logger.epoch() - 1)

    def copy_params(self):
        # Use list(...) so the snapshot is a concrete copy under both Python 2
        # and 3 (a bare map() is a one-shot iterator on Python 3).
        checkpoint = (list(map(lambda p: p.get_value().copy(), self.network.parameters)),
                      list(map(lambda p: p.get_value().copy(), self.network.free_parameters)))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (it receives the trainer as its argument).
        :return:
        """
        self._iter_callbacks.append(func)

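    # Illustrative usage (a sketch; `trainer` stands for any concrete subclass
    # instance): the callback is invoked once after every processed batch.
    #
    #   def report(trainer):
    #       print("epoch %d, last cost %.4f" % (trainer.epoch(), trainer.last_cost))
    #
    #   trainer.add_iter_callback(report)
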
    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model. This is a generator: it yields the cost map after each
        epoch (or None for a skipped epoch).
        """
        self._epoch = 0
        while True:
            if self._skip_epochs > 0:
                logging.info("skipping one epoch ...")
                self._skip_epochs -= 1
                self._epoch += 1
                yield None
                continue
            # Test
            if not self._epoch % self.config.test_frequency and test_set:
                try:
                    self._run_test(self._epoch, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not self._epoch % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(self._epoch, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train for one epoch
            try:
                costs = self._run_train(self._epoch, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                self._epoch += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

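    # `train` must be driven by iteration; `run` below does that. A hand-rolled
    # loop would look like this (illustrative sketch only):
    #
    #   for costs in trainer.train(train_batches, valid_set=valid_batches):
    #       if costs is not None:
    #           print("J = %.4f" % costs["J"])
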
    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost with given data points.
        :param variables:
        :return:
        """

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test (epoch=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)
        self.last_run_costs = costs

    def _run_train(self, epoch, train_set, train_size=None):
        """
        Run one training iteration.
        """
        self.network.train_logger.record_epoch(epoch + 1)
        costs = self.train_step(train_set, train_size)
        if not epoch % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (epoch=%i) %s" % (epoch + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        self.last_run_costs = costs
        return costs

    def _run_valid(self, epoch, valid_set, dry_run=False, save_path=None):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # this test is the same as: (J_best - J) / J_best > min_improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_epoch = epoch

            save_path = save_path if save_path else self.config.auto_save
            if save_path and self._skip_batches == 0:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(save_path, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        epoch_str = "epoch=%d" % (epoch + 1)
        if dry_run:
            epoch_str = "dryrun" + " " * (len(epoch_str) - 6)
        message = "valid (%s) %s%s" % (epoch_str, info, marker)
        logging.info(message)
        self.last_run_costs = costs
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return epoch - self.best_epoch < self.patience

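    # Worked example of the improvement test above: with min_improvement = 0.01
    # and a current best cost of 2.00, a new validation cost J only counts as an
    # improvement if 2.00 - J > 0.02, i.e. J < 1.98.
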
    def test_step(self, test_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        exec_count = 0
        start_time = time.time()
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    # Still recovering from a MemoryError: train on the two
                    # halves of the batch separately.
                    cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick: split the batch in half to reduce memory usage
                        cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                exec_count += 1
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                spd = float(exec_count) / (time.time() - start_time)
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f | spd=%.2f batch/s" % (self._progress * 100 / train_size, self.last_cost, spd))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()
        if controllers:
            for controller in controllers:
                controller.bind(self)
        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        if self._report_time:
            timer.report()
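
# Illustrative usage (a sketch, not part of this module). The concrete trainer
# class named below is an assumption for illustration; any subclass that
# implements `learn` is driven the same way:
#
#   from deepy.trainers import MomentumTrainer
#   trainer = MomentumTrainer(network)
#   trainer.run(dataset)            # accepts a deepy.dataset.Dataset directly
#   trainer.save_params("model.gz")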