Completed push to master ( 48255b...bf2b0c ) by Raphael, created 01:13

NeuralTrainer._run_valid()   D

Complexity:   conditions (cc) 8
Size:         total lines (loc) 31
Duplication:  lines (dl) 0, ratio 0 %
Importance:   changes (c) 3, bugs (b) 0, features (f) 0
Other:        rs 4
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import time
import numpy as np
import theano

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.utils import Timer

from abc import ABCMeta, abstractmethod

from logging import getLogger
logging = getLogger("trainer")

class NeuralTrainer(object):
    """
    A base class for all trainers.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        """
        super(NeuralTrainer, self).__init__()

        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # "model" and "network" both refer to the same computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_epoch = 0
        self.best_params = self.copy_params()
        self._skip_batches = 0
        self._skip_epochs = 0
        self._progress = 0
        self.last_cost = 0
        self.last_run_costs = None
        self._report_time = True
        self._epoch = 0

    def _compile_evaluation_func(self):
        if not self.evaluation_func:
            logging.info("compile evaluation function")
            self.evaluation_func = theano.function(
                self.network.input_variables + self.network.target_variables,
                self.evaluation_variables,
                updates=self.network.updates,
                allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def skip(self, n_batches, n_epochs=0):
        """
        Skip the first N batches (and optionally the first M epochs) of the training.
        """
        logging.info("skip %d epochs and %d batches" % (n_epochs, n_batches))
        self._skip_batches = n_batches
        self._skip_epochs = n_epochs

    def epoch(self):
        """
        Get the current epoch.
        """
        return self._epoch

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
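        # The penalty terms below (each added only when its coefficient is
        # positive) are:
        #   weight_l1 * sum(|w|)                 over all parameters w
        #   hidden_l1 * sum(mean(|h|, axis=0))   over all hidden outputs h
        #   hidden_l2 * sum(mean(h * h, axis=0)) over all hidden outputs h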
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0 or self.network.train_logger.epoch() > 0:
            self.skip(self.network.train_logger.progress(), self.network.train_logger.epoch() - 1)

    def copy_params(self):
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (called once per batch, with the trainer as its argument).
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and yield a dictionary of costs after each epoch.
        """
        self._epoch = 0
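        # One loop iteration per epoch: periodically evaluate on the test and
        # validation sets, then train for one epoch and yield its costs.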
        while True:
            if self._skip_epochs > 0:
                logging.info("skipping one epoch ...")
                self._skip_epochs -= 1
                self._epoch += 1
                yield None
                continue
            # Test
            if not self._epoch % self.config.test_frequency and test_set:
                try:
                    self._run_test(self._epoch, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not self._epoch % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(self._epoch, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(self._epoch, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
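                # `self.checkpoint` holds the parameters copied during the last
                # validation run (see `_run_valid`); roll back to them when the
                # training cost turns NaN.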
                logging.info("NaN detected in costs, rolling back to the last parameters")
                self.set_params(*self.checkpoint)
            else:
                self._epoch += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost for the given data points.
        """

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (epoch=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)
        self.last_run_costs = costs

    def _run_train(self, epoch, train_set, train_size=None):
        """
        Run one training iteration.
        """
        self.network.train_logger.record_epoch(epoch + 1)
        costs = self.train_step(train_set, train_size)
        if not epoch % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (epoch=%i) %s" % (epoch + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        self.last_run_costs = costs
        return costs

    def _run_valid(self, epoch, valid_set, dry_run=False, save_path=None):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # improvement test, equivalent to: (J_best - J) / J_best > min_improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_epoch = epoch

            save_path = save_path if save_path else self.config.auto_save
            if save_path and self._skip_batches == 0:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(save_path, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        epoch_str = "epoch=%d" % (epoch + 1)
        if dry_run:
            epoch_str = "dryrun" + " " * (len(epoch_str) - 6)
        message = "valid   (%s) %s%s" % (epoch_str, info, marker)
        logging.info(message)
        self.last_run_costs = costs
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return epoch - self.best_epoch < self.patience

    def test_step(self, test_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        exec_count = 0
        start_time = time.time()
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
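                # After a MemoryError, feed each batch to `learn` in two
                # halves for the next `dirty_trick_times` batches (see the
                # handler below).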
                if dirty_trick_times > 0:
                    cost_x = self.learn(*[t[:(t.shape[0]/2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learn(*[t[(t.shape[0]/2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("memory error detected, performing the dirty trick for the next 30 batches")
                        dirty_trick_times = 30
                        # Dirty trick: split this batch in half as well
                        cost_x = self.learn(*[t[:(t.shape[0]/2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learn(*[t[(t.shape[0]/2):] for t in x])
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                exec_count += 1
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                spd = float(exec_count) / (time.time() - start_time)
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f | spd=%.2f batch/s" % (self._progress * 100 / train_size, self.last_cost, spd))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run the whole training until completion, reporting the elapsed time afterwards.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()
        if controllers:
            for controller in controllers:
                controller.bind(self)
        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        if self._report_time:
            timer.report()
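
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a concrete trainer only has
# to implement `learn`. The subclass name and the compiled `update_fn` below
# are hypothetical; only `NeuralTrainer`, `Dataset`, and `run` come from the
# code above.
#
#     class SGDTrainer(NeuralTrainer):
#         def learn(self, *variables):
#             # One parameter update; returns [cost] + training monitors.
#             return self.update_fn(*variables)
#
#     trainer = SGDTrainer(network, {"validation_frequency": 1, "patience": 10})
#     trainer.run(dataset)  # dataset: a deepy.dataset.Dataset instance
# ---------------------------------------------------------------------------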