1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
import sys |
4
|
|
|
import numpy as np |
5
|
|
|
import theano |
6
|
|
|
|
7
|
|
|
from deepy.conf import TrainerConfig |
8
|
|
|
from deepy.dataset import Dataset |
9
|
|
|
from deepy.utils import Timer |
10
|
|
|
|
11
|
|
|
from abc import ABCMeta, abstractmethod |
12
|
|
|
|
13
|
|
|
from logging import getLogger |
14
|
|
|
logging = getLogger("trainer") |
15
|
|
|
|
16
|
|
|
class NeuralTrainer(object):
    """
    A base class for all trainers.

    The trainer drives the epoch loop (test / validate / train), keeps track
    of the best validation cost with patience-based stopping, and can
    snapshot and restore the parameters of the network.  Subclasses must
    implement :meth:`learn`.
    """
    # Python 2 style metaclass declaration.  Python 3 ignores this attribute,
    # so the abstract-method check is only enforced under Python 2.
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.

        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :param config: a ``TrainerConfig``, a dict used to build one, or any
            other value (e.g. ``None``) to fall back to ``TrainerConfig()``.
        """
        super(NeuralTrainer, self).__init__()

        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # Model and network all refer to the computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        # Compiled lazily by _compile_evaluation_func().
        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_epoch = 0
        self.best_params = self.copy_params()
        # Rollback target used by train() when a NaN cost shows up.  It must
        # exist before the first validation run, otherwise training without a
        # validation set fails with AttributeError on the first NaN.
        self.checkpoint = self.best_params
        self._skip_batches = 0
        self._skip_epochs = 0
        self._progress = 0
        self.last_cost = 0
        self.last_run_costs = None
        self._report_time = True
        self._epoch = 0

    def _compile_evaluation_func(self):
        """
        Compile the theano evaluation function on first use.

        Does nothing when ``self.evaluation_func`` is already set.
        """
        if self.evaluation_func:
            return
        logging.info("compile evaluation function")
        self.evaluation_func = theano.function(
            self.network.input_variables + self.network.target_variables,
            self.evaluation_variables,
            updates=self.network.updates,
            allow_input_downcast=True,
            mode=self.config.get("theano_mode", None))

    def skip(self, n_batches, n_epochs=0):
        """
        Skip N batches in the training.

        :param n_batches: number of batches for train_step() to pass over
        :param n_epochs: number of whole epochs for train() to pass over
        """
        logging.info("skip %d epochs and %d batches", n_epochs, n_batches)
        self._skip_batches = n_batches
        self._skip_epochs = n_epochs

    def epoch(self):
        """
        Get current epoch.
        """
        return self._epoch

    def _setup_costs(self):
        """
        Build the cost / monitor lists used for training and evaluation.

        The first entry of each list is the regularized cost, named ``'J'``,
        followed by the monitors declared on the network.
        """
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)

        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s", ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        """
        Add the regularization terms enabled in the config to ``cost``.

        :param cost: the cost expression to extend
        :return: the cost with the L1 weight, L1 hidden and L2 hidden terms
            added for every coefficient that is greater than zero
        """
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f", self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f", self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f", self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        """
        Assign values to the network parameters.

        :param targets: values for ``network.parameters``, in the same order
        :param free_params: optional values for ``network.free_parameters``
        """
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        """
        Restore the best parameters into the network and save them to ``path``.
        """
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Keep the NaN-rollback target in sync with the freshly loaded values.
        self.checkpoint = self.best_params
        # Resume the progress
        train_logger = self.network.train_logger
        if train_logger.progress() > 0 or train_logger.epoch() > 0:
            self.skip(train_logger.progress(), train_logger.epoch() - 1)

    def copy_params(self):
        """
        Snapshot the current parameter values.

        :return: a ``(parameters, free_parameters)`` tuple of lists holding
            copies of the values, suitable for ``set_params(*snapshot)``
        """
        # Real lists are required here: on Python 3 ``map`` is lazy and
        # single-use, so the values would be read only when the snapshot is
        # restored and the snapshot could not be restored twice.
        params = [p.get_value().copy() for p in self.network.parameters]
        free_params = [p.get_value().copy() for p in self.network.free_parameters]
        return params, free_params

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (receives an argument of the trainer).
        :return:
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and return costs.

        This is a generator: it yields ``None`` for every skipped epoch and a
        dict of the training costs after every trained epoch.  The loop ends
        when validation reports that patience has elapsed or on
        ``KeyboardInterrupt``.
        """
        self._epoch = 0
        while True:
            if self._skip_epochs > 0:
                logging.info("skipping one epoch ...")
                self._skip_epochs -= 1
                self._epoch += 1
                yield None
                continue
            # Test
            if not self._epoch % self.config.test_frequency and test_set:
                try:
                    self._run_test(self._epoch, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not self._epoch % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(self._epoch, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(self._epoch, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                # The epoch counter is not advanced, so this epoch is retried
                # from the restored parameters.
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                self._epoch += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost with given data points.
        :param variables:
        :return:
        """

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration, then log and record the costs.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (epoch=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)
        self.last_run_costs = costs

    def _run_train(self, epoch, train_set, train_size=None):
        """
        Run one training iteration.

        :return: list of ``(name, value)`` cost pairs from train_step()
        """
        self.network.train_logger.record_epoch(epoch + 1)
        costs = self.train_step(train_set, train_size)
        if not epoch % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (epoch=%i) %s" % (epoch + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        self.last_run_costs = costs
        return costs

    def _run_valid(self, epoch, valid_set, dry_run=False):
        """
        Run one valid iteration, return true if to continue training.

        :param dry_run: when true, the best cost and best epoch are left
            unchanged even if the validation cost improved
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_epoch = epoch

            # NOTE(review): nesting reconstructed from a source whose
            # indentation was lost -- confirm that auto-save is meant to run
            # only when a new best was found.
            if self.config.auto_save and self._skip_batches == 0:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        epoch_str = "epoch=%d" % (epoch + 1)
        if dry_run:
            # Pad so that the dry-run label is as wide as the epoch label.
            epoch_str = "dryrun" + " " * (len(epoch_str) - 6)
        message = "valid   (%s) %s%s" % (epoch_str, info, marker)
        logging.info(message)
        self.last_run_costs = costs
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return epoch - self.best_epoch < self.patience

    def _evaluate(self, dataset):
        """
        Average the evaluation outputs over every batch of ``dataset``.

        :return: list of ``(name, value)`` pairs, one per evaluation name
        """
        self._compile_evaluation_func()
        per_batch = [self.evaluation_func(*x) for x in dataset]
        return list(zip(self.evaluation_names, np.mean(per_batch, axis=0)))

    def test_step(self, test_set):
        """
        Evaluate the network on the test set.
        """
        return self._evaluate(test_set)

    def valid_step(self, valid_set):
        """
        Evaluate the network on the validation set.
        """
        return self._evaluate(valid_set)

    def _learn_in_halves(self, batch, cost_matrix):
        """
        Learn from one batch as two half-size batches.

        The cost of the first half is appended to ``cost_matrix``; the cost
        of the second half is returned so that the caller can record it.
        """
        # Floor division keeps the split index an integer on Python 3 too.
        first_cost = self.learn(*[t[:t.shape[0] // 2] for t in batch])
        cost_matrix.append(first_cost)
        return self.learn(*[t[t.shape[0] // 2:] for t in batch])

    def train_step(self, train_set, train_size=None):
        """
        Run :meth:`learn` over every batch of ``train_set``.

        Batches are passed over while ``self._skip_batches`` is positive.
        When ``learn`` raises ``MemoryError`` the batch is retried in two
        halves, and so are the following 30 batches.

        :param train_size: number of batches, used only to print progress
        :return: list of ``(name, mean value)`` pairs over the learned batches
        """
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    cost_x = self._learn_in_halves(x, cost_matrix)
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick
                        cost_x = self._learn_in_halves(x, cost_matrix)
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (self._progress * 100 / train_size, self.last_cost))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.

        :param train_set: training batches, or a ``Dataset`` from which the
            train / valid / test sets and the train size are all taken
        :param controllers: optional objects that are bound to this trainer;
            after every epoch each one that has an ``invoke`` method is
            called, and training stops when any of them returns a true value
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()
        if controllers:
            for controller in controllers:
                controller.bind(self)
        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                # Every controller is invoked, even after one asked to stop.
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        if self._report_time:
            timer.report()
347
|
|
|
|