Completed: push to master (394368...090fba) by Raphael, 01:33

NeuralTrainer._run_valid() (grade A)

Complexity: Conditions 4
Size: Total Lines 20
Duplication: Lines 0, Ratio 0 %
Importance: Changes 3, Bugs 0, Features 0
Metric  Value  Label in the summary above
cc      4      conditions (cyclomatic complexity)
c       3      changes
b       0      bugs
f       0      features
dl      0      duplicated lines
loc     20     total lines
rs      9.2
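
For context, `_run_valid` (listed below) has cyclomatic complexity 4: one base path, its two `if` statements, and the `0 if dry_run else epoch` conditional expression. A minimal sketch that reproduces the figure with the radon library; the choice of radon is an assumption (the report does not name its analyzer), and the file path is hypothetical:

# Assumed tooling: radon; the path below is a hypothetical location of the file.
from radon.complexity import cc_visit

with open("deepy/trainers/trainer.py") as f:
    blocks = cc_visit(f.read())

for block in blocks:
    if block.name == "_run_valid":
        # Expect complexity 4, matching the "cc" row above.
        print("%s: cc=%d" % (block.name, block.complexity))
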
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import time
import numpy as np
import theano

from ..conf import TrainerConfig
from ..core import env, runtime
from ..utils import Timer
from ..dataset import Dataset
from .controllers import TrainingController

from abc import ABCMeta, abstractmethod

from logging import getLogger
logging = getLogger("trainer")

class NeuralTrainer(object):
    """
    The base class for all trainers.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None, validator=None, annealer=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        """
        super(NeuralTrainer, self).__init__()

        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        if isinstance(self.config.learning_rate, float):
            self.config.learning_rate = np.array(self.config.learning_rate, dtype=env.FLOATX)
        # "model" and "network" both refer to the same computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience

        self._iter_controllers = []
        self._epoch_controllers = []
        if annealer:
            annealer.bind(self)
            self._epoch_controllers.append(annealer)
        if validator:
            validator.bind(self)
            self._iter_controllers.append(validator)

        self.best_cost = 1e100
        self.best_epoch = 0
        self.best_params = self.copy_params()
        # Fallback checkpoint, so a NaN rollback that happens before the
        # first call to save_checkpoint() does not raise an AttributeError.
        self.checkpoint = self.best_params
        self._skip_batches = 0
        self._skip_epochs = 0
        self._progress = 0
        self.last_cost = 0
        self.last_run_costs = None
        self._report_time = True
        self._epoch = 0

        self._current_train_set = None
        self._current_valid_set = None
        self._current_test_set = None
        self._ended = False

    def _compile_evaluation_func(self):
        if not self.evaluation_func:
            logging.info("compile evaluation function")
            self.evaluation_func = theano.function(
                self.network.input_variables + self.network.target_variables,
                self.evaluation_variables,
                updates=self.network.updates,
                allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def skip(self, n_batches, n_epochs=0):
        """
        Skip a given number of batches and epochs at the start of training.
        """
        logging.info("skip %d epochs and %d batches" % (n_epochs, n_batches))
        self._skip_batches = n_batches
        self._skip_epochs = n_epochs

    def epoch(self):
        """
        Return the current epoch.
        """
        return self._epoch

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the training progress recorded by the train logger
        if self.network.train_logger.progress() > 0 or self.network.train_logger.epoch() > 0:
            self.skip(self.network.train_logger.progress(), self.network.train_logger.epoch() - 1)

    def copy_params(self):
        checkpoint = ([p.get_value().copy() for p in self.network.parameters],
                      [p.get_value().copy() for p in self.network.free_parameters])
        return checkpoint

    def exit(self):
        """
        End the training.
        """
        self._ended = True

    def add_iter_controllers(self, *controllers):
        """
        Add iteration callbacks (each is invoked with the trainer as its argument).
        :param controllers: each one can be a `TrainingController` or a function.
        :type controllers: list of TrainingController
        """
        for controller in controllers:
            if isinstance(controller, TrainingController):
                controller.bind(self)
            self._iter_controllers.append(controller)

    def add_epoch_controllers(self, *controllers):
        """
        Add epoch callbacks.
        :param controllers: each one can be a `TrainingController` or a function.
        """
        for controller in controllers:
            if isinstance(controller, TrainingController):
                controller.bind(self)
            self._epoch_controllers.append(controller)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model. This is a generator that yields a cost map after
        each epoch, so the caller decides when to stop consuming it.
        """
        self._epoch = 0
        while True:
            if self._skip_epochs > 0:
                logging.info("skipping one epoch ...")
                self._skip_epochs -= 1
                self._epoch += 1
                yield None
                continue
            # Test
            if not self._epoch % self.config.test_frequency and test_set:
                try:
                    self._run_test(self._epoch, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not self._epoch % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(self._epoch, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one epoch
            try:
                costs = self._run_train(self._epoch, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rolling back to the last checkpoint")
                self.set_params(*self.checkpoint)
            else:
                self._epoch += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost for the given data points.
        :param variables: the input and target arrays of one batch
        """

    def _run_test(self, epoch, test_set):
        """
        Run one test epoch.
        """
        costs = self.test_step(test_set)
        self.report(dict(costs), "test", epoch)
        self.last_run_costs = costs

    def _run_train(self, epoch, train_set, train_size=None):
        """
        Run one training epoch.
        """
        self.network.train_logger.record_epoch(epoch + 1)
        costs = self.train_step(train_set, train_size)
        if not epoch % self.config.monitor_frequency:
            self.report(dict(costs), "train", epoch)
        self.last_run_costs = costs
        return costs

    def _run_valid(self, epoch, valid_set, dry_run=False, save_path=None):
        """
        Run one validation epoch; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # A new best requires a relative improvement over the best cost:
        # (best_cost - J) / best_cost > min_improvement
        _, J = costs[0]
        new_best = False
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # Save the best parameters; the best cost and epoch are only
            # updated on a real run, so a dry run cannot consume patience.
            self.best_params = self.copy_params()
            new_best = True
            if not dry_run:
                self.best_cost = J
                self.best_epoch = epoch
            self.save_checkpoint(save_path)

        self.report(dict(costs), type="valid", epoch=0 if dry_run else epoch, new_best=new_best)
        self.last_run_costs = costs
        return epoch - self.best_epoch < self.patience

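    # Worked example of the early-stopping rule above (illustrative numbers,
    # not from the report): with best_cost=1.00, min_improvement=0.02 and
    # patience=10, a validation cost J=0.97 is a new best (1.00 - 0.97 >
    # 1.00 * 0.02) while J=0.99 is not; training stops once
    # epoch - best_epoch reaches patience, i.e. 10 epochs without a new best.
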
    def save_checkpoint(self, save_path=None):
        save_path = save_path if save_path else self.config.auto_save
        self.checkpoint = self.copy_params()
        if save_path and self._skip_batches == 0:
            self.network.train_logger.record_progress(self._progress)
            self.network.save_params(save_path, new_thread=True)

    def report(self, score_map, type="valid", epoch=-1, new_best=False):
        """
        Report the scores and record them in the log.
        """
        type_str = type.ljust(5)
        info = " ".join("%s=%.2f" % el for el in score_map.items())
        current_epoch = epoch if epoch > 0 else self.current_epoch()
        epoch_str = "epoch={}".format(current_epoch + 1)
        if epoch < 0:
            epoch_str = "dryrun"
            sys.stdout.write("\r")
            sys.stdout.flush()
        marker = " *" if new_best else ""
        message = "{} ({}) {}{}".format(type_str, epoch_str, info, marker)
        self.network.train_logger.record(message)
        logging.info(message)

    def test_step(self, test_set):
        runtime.switch_training(False)
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        runtime.switch_training(False)
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_controllers)
        cost_matrix = []
        exec_count = 0
        start_time = time.time()
        self._compile_time = 0
        self._progress = 0

        for x in train_set:
            runtime.switch_training(True)
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    # Still in "dirty trick" mode: train on each half of the
                    # batch separately to halve the peak memory usage.
                    cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("MemoryError detected, splitting batches in half for the next 30 updates")
                        dirty_trick_times = 30
                        # Dirty trick: retry this batch in two halves
                        cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                exec_count += 1
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_controllers:
                        if isinstance(func, TrainingController):
                            func.invoke()
                        else:
                            func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                spd = float(exec_count) / (time.time() - start_time - self._compile_time)
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f | spd=%.2f batch/s" % (self._progress * 100 / train_size, self.last_cost, spd))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def current_epoch(self):
        """
        Return the current epoch (alias of `epoch`).
        """
        return self._epoch

    def get_data(self, data_split="train"):
        """
        Return the specified data split, or None for an unknown split name.
        """
        if data_split == 'train':
            return self._current_train_set
        elif data_split == 'valid':
            return self._current_valid_set
        elif data_split == 'test':
            return self._current_test_set
        else:
            return None

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, epoch_controllers=None):
        """
        Run the training until it ends.
        :param epoch_controllers: deprecated, use `add_epoch_controllers` instead
        """
        epoch_controllers = list(epoch_controllers) if epoch_controllers else []
        epoch_controllers += self._epoch_controllers
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()
        self._current_train_set = train_set
        self._current_valid_set = valid_set
        self._current_test_set = test_set
        for controller in epoch_controllers:
            if isinstance(controller, TrainingController):
                controller.bind(self)
        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            for controller in epoch_controllers:
                if isinstance(controller, TrainingController):
                    controller.invoke()
                else:
                    controller(self)
            if self._ended:
                break
        if self._report_time:
            timer.report()
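
The class is abstract: `learn` is the only method a concrete trainer must provide, and `run` drives the whole train/valid/test loop. A minimal usage sketch; `MyTrainer`, `my_network`, `_learning_func` and the data sets are hypothetical stand-ins, not names from the listing above:

class MyTrainer(NeuralTrainer):
    """Hypothetical concrete trainer; only `learn` is required."""

    def learn(self, *variables):
        # One parameter update for one batch; must return a list of
        # [cost] + monitor values. `_learning_func` is a stand-in for a
        # compiled Theano update function built by the subclass.
        return self._learning_func(*variables)

trainer = MyTrainer(my_network)
# Iteration controllers can be plain functions that receive the trainer:
trainer.add_iter_controllers(lambda t: t.exit() if t.last_cost < 0.1 else None)
trainer.run(train_set, valid_set=valid_set, test_set=test_set, train_size=100)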