#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers

import sys
import numpy as np
import theano
import theano.tensor as T

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.trainers.optimize import optimize_updates
from deepy.utils import Timer


logging = loggers.getLogger(__name__)

THEANO_LINKER = 'cvm'

# Debugging hooks (in the style of theano.compile.MonitorMode) that print each
# node's input and output values as a compiled function runs.
def inspect_inputs(i, node, fn):
    print i, node, "input(s) value(s):", [input[0] for input in fn.inputs],

def inspect_outputs(i, node, fn):
    print "output(s) value(s):", [output[0] for output in fn.outputs]

def default_mapper(f, dataset, *args, **kwargs):
    '''Apply a function to each element of a dataset.'''
    return [f(x, *args, **kwargs) for x in dataset]

def ipcluster_mapper(client):
    '''Get a mapper from an IPython.parallel cluster client.'''
    view = client.load_balanced_view()
    def mapper(f, dataset, *args, **kwargs):
        def ff(x):
            return f(x, *args, **kwargs)
        return view.map(ff, dataset).get()
    return mapper

def save_network_params(network, path):
    network.save_params(path)


class NeuralTrainer(object):
    '''This is a base class for all trainers.'''

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        """
        super(NeuralTrainer, self).__init__()

        self.config = None
        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        self.network = network

        self.network.prepare_training()
        self._setup_costs()

        logging.info("compile evaluation function")
        self.evaluation_func = theano.function(
            network.input_variables + network.target_variables,
            self.evaluation_variables, updates=network.updates,
            allow_input_downcast=True, mode=self.config.get("theano_mode", None))
        self.learning_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()
        self._skip_batches = 0
        self._progress = 0

    def skip(self, n_batches):
        """
        Skip N batches in the training.
        """
        logging.info("Skip %d batches" % n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

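    # The main cost is tracked under the name 'J'; any (name, variable) pairs the
    # network exposes as training_monitors / testing_monitors are appended to the
    # corresponding lists and reported alongside it.
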
    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0:
            self.skip(self.network.train_logger.progress())

    def copy_params(self):
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (receives the trainer as its argument).
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and return costs.
        """
        if not self.learning_func:
            raise NotImplementedError
        iteration = 0
        while True:
            # Test
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self._run_test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(iteration, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                iteration += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

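    # train() is a generator: each iteration yields a dict of epoch costs keyed by
    # the monitor names set up in _setup_costs() (the main cost is under 'J').
    # A minimal sketch of consuming it directly instead of calling run()
    # (hypothetical names; assumes `trainer` and the batch iterables already exist):
    #
    #     for costs in trainer.train(train_batches, valid_set=valid_batches):
    #         if costs['J'] < 0.05:   # custom stopping rule
    #             break
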
    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test (iter=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.
        """
        costs = self.train_step(train_set, train_size)

        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (iter=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min_improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        iter_str = "iter=%d" % (iteration + 1)
        if dry_run:
            iter_str = "dryrun" + " " * (len(iter_str) - 6)
        message = "valid (%s) %s%s" % (iter_str, info, marker)
        logging.info(message)
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return iteration - self.best_iter < self.patience

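    # Early stopping, as implemented above: a validation cost J only counts as an
    # improvement when it beats best_cost by more than best_cost * min_improvement
    # (e.g. with min_improvement = 0.01, J must be at least 1% below the best so far).
    # Training continues while (iteration - best_iter) < patience, i.e. until no such
    # improvement has been seen for `patience` validation-counted iterations.
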
    def test_step(self, test_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learning_func(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick
                        cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                cost_matrix.append(cost_x)
                if network_callback:
                    self.last_score = cost_x[0]
                    self.network.training_callback()
                if trainer_callback:
                    self.last_score = cost_x[0]
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\r> %d%%" % (self._progress * 100 / train_size))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

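    # Note on the "dirty trick" above: when a batch raises MemoryError, the batch is
    # split in half along the first (batch) axis and fed through learning_func in two
    # calls, and the same halving is applied to the next 30 batches before reverting
    # to full batches. Integer division (t.shape[0]/2) is intentional under Python 2.
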
    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()

        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        timer.report()
        return

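    # run() drives the train() generator to completion. A "controller" can be any
    # object exposing an invoke() method; when invoke() returns True after an
    # iteration, the loop stops early (e.g. a time- or cost-based stopping object).
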
class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.
    """
    def __init__(self, network, config=None, method=None):
        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        variables = network.input_variables + network.target_variables
        givens = None

        self.learning_func = theano.function(
            variables,
            map(lambda v: theano.Out(v, borrow=True), self.training_variables),
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=givens)

    def learning_updates(self):
        """
        Return the parameter updates used during training.
        """
        params = self.network.parameters
        # Freeze parameters
        if self.config.fixed_parameters:
            logging.info("fixed parameters: %s" % ", ".join(map(str, self.config.fixed_parameters)))
            params = [p for p in params if p not in self.config.fixed_parameters]
        gradients = T.grad(self.cost, params)
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization" % len(free_parameters))
        return updates


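# The concrete update rule is selected by `config.method`, which optimize_updates()
# consumes inside learning_updates(); the subclasses below simply fix that method
# name (e.g. "SGD", "ADADELTA", "ADAM") and otherwise reuse GeneralNeuralTrainer.
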
class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.
    """
    def __init__(self, network, config=None):
        super(SGDTrainer, self).__init__(network, config, "SGD")


class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.
    """
    def __init__(self, network, config=None):
        super(AdaDeltaTrainer, self).__init__(network, config, "ADADELTA")


class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(AdaGradTrainer, self).__init__(network, config, "ADAGRAD")


class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")


class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.
    """
    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")


class RmspropTrainer(GeneralNeuralTrainer):
    """
    RmsProp trainer.
    """
    def __init__(self, network, config=None):
        super(RmspropTrainer, self).__init__(network, config, "RMSPROP")


class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.
    """
    def __init__(self, network, config=None):
        super(MomentumTrainer, self).__init__(network, config, "MOMENTUM")


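# A minimal usage sketch (illustrative only; `network` is a deepy NeuralNetwork and
# `dataset` a deepy.dataset.Dataset built elsewhere, and the file name here is an
# assumption, not part of this module):
#
#     config = TrainerConfig()
#     trainer = MomentumTrainer(network, config)
#     trainer.run(dataset)            # or trainer.run(train_batches, valid_set=...)
#     trainer.save_params("model.gz")
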
class SSGD2Trainer(NeuralTrainer):
    """
    Trainer using the SSGD2 optimization method.
    """

    def __init__(self, network, config=None):
        super(SSGD2Trainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def ssgd2(self, loss, all_params, learning_rate=0.01, chaos_energy=0.01, alpha=0.9):
        from theano.tensor.shared_randomstreams import RandomStreams

        chaos_energy = T.constant(chaos_energy, dtype="float32")
        alpha = T.constant(alpha, dtype="float32")
        learning_rate = T.constant(learning_rate, dtype="float32")

        srng = RandomStreams(seed=3)
        updates = []
        all_grads = T.grad(loss, all_params)
        for p, g in zip(all_params, all_grads):
            rand_v = (srng.uniform(p.get_value().shape) * 2 - 1) * chaos_energy
            g_ratio_vec = g / g.norm(L=2)
            ratio_sum = theano.shared(np.ones(np.array(p.get_value().shape), dtype="float32"),
                                      name="ssgd2_r_sum_%s" % p.name)
            abs_ratio_sum = T.abs_(ratio_sum)
            updates.append((ratio_sum, ratio_sum * alpha + (1 - alpha) * g_ratio_vec))
            updates.append((p, p - learning_rate * (abs_ratio_sum * g + (1 - abs_ratio_sum) * rand_v)))
        return updates

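    # Update rule implemented by ssgd2() above: ratio_sum keeps an exponential moving
    # average (decay alpha) of the L2-normalized gradient for each parameter; the
    # parameter then moves by -learning_rate * (|ratio_sum| * g + (1 - |ratio_sum|) * noise),
    # where noise is uniform in [-chaos_energy, chaos_energy]. Where the averaged
    # gradient direction is weak, more of the step comes from random exploration.
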
    def learning_updates(self):
        return self.ssgd2(self.cost, self.network.parameters, learning_rate=self.learning_rate)


class FakeTrainer(NeuralTrainer):
    """
    Fake trainer that applies no learning updates.
    """

    def __init__(self, network, config=None):
        super(FakeTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = []
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))