Completed
Push — master (27a82d...e2ab7f)
by Raphael
58s

deepy.trainers.NeuralTrainer.learn() (grade: A)

Complexity
    Conditions: 1

Size
    Total lines: 7

Duplication
    Lines: 0
    Ratio: 0%

Metric                          Value
cc (cyclomatic complexity)      1
dl (duplicated lines)           0
loc (lines of code)             7
rs                              9.4285
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import numpy as np
import theano

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.utils import Timer

from abc import ABCMeta, abstractmethod

from logging import getLogger
logging = getLogger(__name__)

class NeuralTrainer(object):
    """
    A base class for all trainers.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(NeuralTrainer, self).__init__()

        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # `model` and `network` both refer to the same computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        logging.info("compile evaluation function")
        self.evaluation_func = theano.function(
            network.input_variables + network.target_variables, self.evaluation_variables,
            updates=network.updates, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None))

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()
        # Rollback checkpoint for NaN recovery in train(); initialized here so it
        # exists even if a NaN cost appears before the first validation run, and
        # refreshed after each validation run.
        self.checkpoint = self.copy_params()
        self._skip_batches = 0
        self._progress = 0
        self.last_cost = 0

    def skip(self, n_batches):
        """
        Skip N batches in the training.
        """
        logging.info("Skip %d batches" % n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

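    # Schematically, _add_regularization() returns
    #     J' = J + weight_l1 * sum(|W|)
    #            + hidden_l1 * sum(mean_batch(|h|))
    #            + hidden_l2 * sum(mean_batch(h^2))
    # summing over parameters W and hidden-layer outputs h; each term is added
    # only when its coefficient in the config is positive.
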
    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0:
            self.skip(self.network.train_logger.progress())

    def copy_params(self):
        # Use list comprehensions (rather than map) so the checkpoint holds
        # concrete copies under both Python 2 and Python 3.
        checkpoint = ([p.get_value().copy() for p in self.network.parameters],
                      [p.get_value().copy() for p in self.network.free_parameters])
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (it receives the trainer as its argument).
        :return:
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model; yields a dict of costs after each epoch.
        """
        iteration = 0
        while True:
            # Test
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self._run_test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(iteration, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rolling back to the last checkpoint")
                self.set_params(*self.checkpoint)
            else:
                iteration += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

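    # Note: train() is a generator, so callers drive it with a loop such as
    #     for costs in trainer.train(train_set, valid_set=valid_set):
    #         print(costs['J'])
    # which is exactly what run() below does (minus the printing).
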
    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost for the given data points.
        :param variables:
        :return:
        """

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (iter=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.
        """
        costs = self.train_step(train_set, train_size)

        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (iter=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # improvement test, equivalent to: (J_best - J) / J_best > min_improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        iter_str = "iter=%d" % (iteration + 1)
        if dry_run:
            iter_str = "dryrun" + " " * (len(iter_str) - 6)
        message = "valid   (%s) %s%s" % (iter_str, info, marker)
        logging.info(message)
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        # Keep training while the best result is fewer than `patience` iterations old.
        return iteration - self.best_iter < self.patience

    def test_step(self, test_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    # still recovering from a MemoryError: learn on each half of the batch
                    cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("MemoryError detected, performing dirty trick for 30 batches")
                        dirty_trick_times = 30
                        # Dirty trick: split the batch in half
                        cost_x = self.learn(*[t[:(t.shape[0] // 2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learn(*[t[(t.shape[0] // 2):] for t in x])
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (100 * self._progress // train_size, self.last_cost))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

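    # Rationale for the "dirty trick" above: running learn() on half a batch
    # roughly halves the peak memory of a single update, so after a MemoryError
    # the trainer learns on each half separately for the next 30 batches before
    # returning to full batches.
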
    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()

        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        timer.report()
        return
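
For reference, here is a minimal sketch of how a concrete trainer can plug into this base class. The SGDTrainer below is hypothetical (deepy ships its own trainer subclasses) and the learning-rate handling is an assumption; the only hard requirement is that learn() updates the parameters and returns the monitored values in the order of self.training_names.

import theano
import theano.tensor as T

class SGDTrainer(NeuralTrainer):
    """Hypothetical plain-SGD trainer, for illustration only."""

    def __init__(self, network, config=None, learning_rate=0.01):
        super(SGDTrainer, self).__init__(network, config)
        # One gradient-descent step per call. self.cost and
        # self.training_variables were built by _setup_costs() in the base class.
        updates = [(p, p - learning_rate * T.grad(self.cost, p))
                   for p in network.parameters]
        # A production trainer would also merge network.updates here.
        self.learning_func = theano.function(
            network.input_variables + network.target_variables,
            self.training_variables,
            updates=updates,
            allow_input_downcast=True)

    def learn(self, *variables):
        # Returns [J, monitor values ...], matching self.training_names.
        return self.learning_func(*variables)

Typical driver code, assuming network is a deepy.NeuralNetwork and dataset is a deepy.dataset.Dataset:

    trainer = SGDTrainer(network)
    trainer.run(dataset)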