Completed: Push — master ( c48f07...4ce1c1 ) by Raphael, created 01:33

NeuralTrainer.epoch()   A

Complexity
  Conditions: 1

Size
  Total Lines: 5

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric   Value
cc       1
c        0
b        0
f        0
dl       0
loc      5
rs       9.4285
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import numpy as np
import theano

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.utils import Timer

from abc import ABCMeta, abstractmethod

from logging import getLogger
logging = getLogger("trainer")

class NeuralTrainer(object):
    """
    A base class for all trainers.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        """
        super(NeuralTrainer, self).__init__()

        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # Model and network both refer to the same computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_epoch = 0
        self.best_params = self.copy_params()
        self._skip_batches = 0
        self._skip_epochs = 0
        self._progress = 0
        self.last_cost = 0
        self.last_run_costs = None
        self._report_time = True
        self._epoch = 0

    def _compile_evaluation_func(self):
        if not self.evaluation_func:
            logging.info("compile evaluation function")
            self.evaluation_func = theano.function(
                self.network.input_variables + self.network.target_variables,
                self.evaluation_variables,
                updates=self.network.updates,
                allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def skip(self, n_batches, n_epochs=0):
        """
        Skip the given numbers of batches and epochs in the training.
        """
        logging.info("skip %d epochs and %d batches" % (n_epochs, n_batches))
        self._skip_batches = n_batches
        self._skip_epochs = n_epochs

    def epoch(self):
        """
        Get the current epoch.
        """
        return self._epoch

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        """
        Set the network parameters (and optionally free parameters) to the given values.
        """
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        """
        Restore the best parameters and save them to the given path.
        """
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0 or self.network.train_logger.epoch() > 0:
            self.skip(self.network.train_logger.progress(), self.network.train_logger.epoch() - 1)

    def copy_params(self):
        """
        Return a copy of the current parameter and free-parameter values.
        """
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (it receives the trainer as its argument).
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model, yielding the costs after each epoch.
        """
        self._epoch = 0
        while True:
            if self._skip_epochs > 0:
                logging.info("skipping one epoch ...")
                self._skip_epochs -= 1
                self._epoch += 1
                yield None
                continue
            # Test
            if not self._epoch % self.config.test_frequency and test_set:
                try:
                    self._run_test(self._epoch, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not self._epoch % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(self._epoch, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(self._epoch, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                self._epoch += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost with given data points.
        :param variables:
        """

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (epoch=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)
        self.last_run_costs = costs

    def _run_train(self, epoch, train_set, train_size=None):
        """
        Run one training iteration.
        """
        self.network.train_logger.record_epoch(epoch + 1)
        costs = self.train_step(train_set, train_size)
        if not epoch % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (epoch=%i) %s" % (epoch + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        self.last_run_costs = costs
        return costs

    def _run_valid(self, epoch, valid_set, dry_run=False):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min_improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_epoch = epoch

            if self.config.auto_save and self._skip_batches == 0:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        epoch_str = "epoch=%d" % (epoch + 1)
        if dry_run:
            epoch_str = "dryrun" + " " * (len(epoch_str) - 6)
        message = "valid   (%s) %s%s" % (epoch_str, info, marker)
        logging.info(message)
        self.last_run_costs = costs
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return epoch - self.best_epoch < self.patience

    def test_step(self, test_set):
        """
        Evaluate on the test set and return the averaged costs.
        """
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        """
        Evaluate on the validation set and return the averaged costs.
        """
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        """
        Train on every batch of the training set and return the averaged costs.
        """
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    # Still recovering from a memory error: train on each half of the batch
                    cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick: split the batch in half to reduce memory usage
                        cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (self._progress * 100 / train_size, self.last_cost))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()
        if controllers:
            for controller in controllers:
                controller.bind(self)
        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        if self._report_time:
            timer.report()
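
For orientation, here is a minimal usage sketch of the class above. It is not part of the analyzed file: the subclass name SimpleSGDTrainer, its learning_rate argument, and the trailing usage comments are hypothetical; only NeuralTrainer, learn(), add_iter_callback(), run(), and the network/config attributes they use come from the code being reviewed.

# Hypothetical sketch: a concrete trainer only needs to implement learn();
# validation, patience, checkpointing and progress display are inherited.
import theano
import theano.tensor as T


class SimpleSGDTrainer(NeuralTrainer):
    """Plain SGD on the (regularized) cost; illustrative only."""

    def __init__(self, network, config=None, learning_rate=0.01):
        super(SimpleSGDTrainer, self).__init__(network, config)
        # One SGD update per parameter of the wrapped network.
        sgd_updates = [(p, p - learning_rate * T.grad(self.cost, p))
                       for p in self.network.parameters]
        self._learn_func = theano.function(
            self.network.input_variables + self.network.target_variables,
            self.training_variables,
            updates=sgd_updates,
            allow_input_downcast=True)

    def learn(self, *variables):
        # Must return values aligned with self.training_names ('J' first).
        return self._learn_func(*variables)


# Typical call sequence (network/dataset construction omitted):
#   trainer = SimpleSGDTrainer(network, {"patience": 10})
#   trainer.add_iter_callback(lambda t: None)   # callback receives the trainer
#   trainer.run(dataset)                        # dataset may be a deepy Dataset

The base class then drives everything through run(), which accepts either raw mini-batch iterables or a deepy.dataset.Dataset.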