#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers

import sys
import numpy as np
import theano
import theano.tensor as T

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.trainers.optimize import optimize_updates
from deepy.utils import Timer


logging = loggers.getLogger(__name__)

THEANO_LINKER = 'cvm'

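# Debugging helpers: inspect_inputs and inspect_outputs follow the (i, node, fn)
# callback signature used by theano.compile.MonitorMode, e.g.
# mode=MonitorMode(pre_func=inspect_inputs, post_func=inspect_outputs),
# which prints every node's inputs and outputs while a compiled function runs.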
def inspect_inputs(i, node, fn):
    print i, node, "input(s) value(s):", [input[0] for input in fn.inputs],

def inspect_outputs(i, node, fn):
    print "output(s) value(s):", [output[0] for output in fn.outputs]

def default_mapper(f, dataset, *args, **kwargs):
    '''Apply a function to each element of a dataset.'''
    return [f(x, *args, **kwargs) for x in dataset]

def ipcluster_mapper(client):
    '''Get a mapper from an IPython.parallel cluster client.'''
    view = client.load_balanced_view()
    def mapper(f, dataset, *args, **kwargs):
        def ff(x):
            return f(x, *args, **kwargs)
        return view.map(ff, dataset).get()
    return mapper
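# Usage sketch (illustrative only): with a running IPython.parallel cluster,
# `mapper = ipcluster_mapper(Client())` yields a load-balanced map function;
# `Client` refers to IPython.parallel.Client and is an assumed import that
# depends on the installed IPython version.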

def save_network_params(network, path):
    network.save_params(path)


class NeuralTrainer(object):
    '''This is a base class for all trainers.'''

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        """
        super(NeuralTrainer, self).__init__()

        self.config = None
        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        self.network = network

        self.network.prepare_training()
        self._setup_costs()

        logging.info("compiling evaluation function")
        self.evaluation_func = theano.function(
            network.input_variables + network.target_variables, self.evaluation_variables, updates=network.updates,
            allow_input_downcast=True, mode=self.config.get("theano_mode", None))
        self.learning_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()
        self._skip_batches = 0
        self._progress = 0

    def skip(self, n_batches):
        """
        Skip the first n_batches batches of training (used to resume progress).
        """
        logging.info("Skip %d batches" % n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

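    # Optional penalties controlled by TrainerConfig: weight_l1 adds an L1 penalty
    # on the weights, hidden_l1 / hidden_l2 add sparsity penalties on the hidden
    # activations.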
    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0:
            self.skip(self.network.train_logger.progress())

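    # copy_params snapshots both regular and free parameters; it backs both the
    # best-model tracking in _run_valid and the NaN rollback in train().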
    def copy_params(self):
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback; the function is called with the trainer as its only argument.
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model; yields a dict of monitored costs after each iteration.
        """
        if not self.learning_func:
            raise NotImplementedError
        iteration = 0
        while True:
            # Test
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self._run_test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(iteration, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                iteration += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test (iter=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.
        """
        costs = self.train_step(train_set, train_size)

        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (iter=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        _, J = costs[0]
        marker = ""
        # this is the same as: (J_i - J_f) / J_i > min_improvement
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        iter_str = "iter=%d" % (iteration + 1)
        if dry_run:
            iter_str = "dryrun" + " " * (len(iter_str) - 6)
        message = "valid (%s) %s%s" % (iter_str, info, marker)
        logging.info(message)
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return iteration - self.best_iter < self.patience

    def test_step(self, test_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

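        # Each element of train_set is a tuple of input arrays for one batch.
        # Batches are skipped while _skip_batches > 0 (resuming), and after a
        # MemoryError each batch is fed in two halves ("dirty trick") for a while.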
        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learning_func(*x)
                    except MemoryError:
                        logging.info("MemoryError detected, splitting each batch in half for the next 30 batches")
                        dirty_trick_times = 30
                        # Dirty trick
                        cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                cost_matrix.append(cost_x)
                if network_callback:
                    self.last_score = cost_x[0]
                    self.network.training_callback()
                if trainer_callback:
                    self.last_score = cost_x[0]
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\r> %d%%" % (self._progress * 100 / train_size))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run training until it completes or a controller requests a stop.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()

        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        timer.report()
        return

class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.
    """
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        variables = network.input_variables + network.target_variables
        givens = None

        self.learning_func = theano.function(
            variables,
            map(lambda v: theano.Out(v, borrow=True), self.training_variables),
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=givens)

    def learning_updates(self):
        """
        Return the parameter update pairs used for training.
        """
        params = self.network.parameters
        # Freeze parameters
        if self.config.fixed_parameters:
            logging.info("fixed parameters: %s" % ", ".join(map(str, self.config.fixed_parameters)))
            params = [p for p in params if p not in self.config.fixed_parameters]
        gradients = T.grad(self.cost, params)
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization" % len(free_parameters))
        return updates


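# The subclasses below only select the optimization method string that
# optimize_updates() dispatches on; all other behaviour comes from
# GeneralNeuralTrainer.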
class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.
    """
    def __init__(self, network, config=None):
        super(SGDTrainer, self).__init__(network, config, "SGD")

class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.
    """
    def __init__(self, network, config=None):
        super(AdaDeltaTrainer, self).__init__(network, config, "ADADELTA")


class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(AdaGradTrainer, self).__init__(network, config, "ADAGRAD")

class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")

class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.
    """
    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")

class RmspropTrainer(GeneralNeuralTrainer):
    """
    RmsProp trainer.
    """
    def __init__(self, network, config=None):
        super(RmspropTrainer, self).__init__(network, config, "RMSPROP")

class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.
    """
    def __init__(self, network, config=None):
        super(MomentumTrainer, self).__init__(network, config, "MOMENTUM")


class SSGD2Trainer(NeuralTrainer):
    """
    Trainer implementing the SSGD2 optimization method.
    """

    def __init__(self, network, config=None):
        super(SSGD2Trainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def ssgd2(self, loss, all_params, learning_rate=0.01, chaos_energy=0.01, alpha=0.9):
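        # Sketch of the update rule implemented below: keep a running average
        # (ratio_sum) of the L2-normalized gradient for each parameter, then step
        # along a mixture of the raw gradient and uniform noise of magnitude
        # chaos_energy, weighted by |ratio_sum| and 1 - |ratio_sum| respectively.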
        from theano.tensor.shared_randomstreams import RandomStreams

        chaos_energy = T.constant(chaos_energy, dtype="float32")
        alpha = T.constant(alpha, dtype="float32")
        learning_rate = T.constant(learning_rate, dtype="float32")

        srng = RandomStreams(seed=3)
        updates = []
        all_grads = T.grad(loss, all_params)
        for p, g in zip(all_params, all_grads):
            rand_v = (srng.uniform(p.get_value().shape)*2 - 1) * chaos_energy
            g_ratio_vec = g / g.norm(L=2)
            ratio_sum = theano.shared(np.ones(np.array(p.get_value().shape), dtype="float32"), name="ssgd2_r_sum_%s" % p.name)
            abs_ratio_sum = T.abs_(ratio_sum)
            updates.append((ratio_sum, ratio_sum * alpha + (1 - alpha) * g_ratio_vec))
            updates.append((p, p - learning_rate*((abs_ratio_sum)*g + (1-abs_ratio_sum)*rand_v)))
        return updates

    def learning_updates(self):
        return self.ssgd2(self.cost, self.network.parameters, learning_rate=self.learning_rate)

class FakeTrainer(NeuralTrainer):
    """
    Fake trainer that adds no learning updates of its own.
    """

    def __init__(self, network, config=None):
        super(FakeTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = []
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))
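
# Minimal usage sketch (illustrative only): `model` stands for a constructed
# deepy network and `mnist` for a deepy.dataset.Dataset; neither is defined in
# this module, and the config key shown is an assumption.
#
#     trainer = MomentumTrainer(model, {"learning_rate": 0.01})
#     trainer.run(mnist)
#     trainer.save_params("model_params.gz")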