Completed
Push to master (eef17b...3c911b) by Raphael, created 01:37

NeuralTrainer.run() (rated F)

Complexity
  Conditions: 11

Size
  Total Lines: 24

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 2
  Bugs: 0
  Features: 0

Metric   Value
cc       11
dl       0
loc      24
rs       3.3409
c        2
b        0
f        0

How to fix: Complexity

Complex classes and methods like NeuralTrainer.run() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
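For the method flagged here, one such cohesive component is the controller handling inside run(): binding each controller to the trainer, polling invoke() after every epoch, and deciding when to stop. The sketch below shows how that responsibility could be extracted. ControllerSet is a hypothetical name, not part of deepy, and it relies only on the bind()/invoke() protocol visible in the listing that follows.

class ControllerSet(object):
    """
    Hypothetical helper extracted from NeuralTrainer.run() (an Extract Class
    sketch, not deepy API). It owns the controller-related branches so that
    run() only has to call bind_all() once and should_stop() once per epoch.
    """

    def __init__(self, controllers=None):
        self.controllers = controllers or []

    def bind_all(self, trainer):
        # Mirrors the bind() loop at the top of run()
        for controller in self.controllers:
            controller.bind(trainer)

    def should_stop(self):
        # Mirrors the per-epoch check in run(): every controller's invoke()
        # is polled, and a truthy result from any of them ends training
        ending = False
        for controller in self.controllers:
            if hasattr(controller, 'invoke') and controller.invoke():
                ending = True
        return ending

With a helper along these lines, run() would shrink to unpacking the Dataset, one bind_all() call, and a loop that breaks when should_stop() returns True, removing most of the conditions counted above.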

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import numpy as np
import theano

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.utils import Timer

from abc import ABCMeta, abstractmethod

from logging import getLogger
logging = getLogger("trainer")

class NeuralTrainer(object):
    """
    A base class for all trainers.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(NeuralTrainer, self).__init__()

        self.config = None
        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # Model and network both refer to the computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()
        self._skip_batches = 0
        self._progress = 0
        self.last_cost = 0
        self.last_run_costs = None
        self._report_time = True

    def _compile_evaluation_func(self):
        if not self.evaluation_func:
            logging.info("compile evaluation function")
            self.evaluation_func = theano.function(
                self.network.input_variables + self.network.target_variables,
                self.evaluation_variables,
                updates=self.network.updates,
                allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def skip(self, n_batches):
        """
        Skip N batches in the training.
        """
        logging.info("Skip %d batches" % n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0:
            self.skip(self.network.train_logger.progress())

    def copy_params(self):
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (the trainer is passed as its argument).
        :return:
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and return costs.
        """
        epoch = 0
        while True:
            # Test
            if not epoch % self.config.test_frequency and test_set:
                try:
                    self._run_test(epoch, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not epoch % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(epoch, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(epoch, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                epoch += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost with given data points.
        :param variables:
        :return:
        """

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (epoch=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)
        self.last_run_costs = costs

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.
        """
        costs = self.train_step(train_set, train_size)
        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (epoch=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        self.last_run_costs = costs
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        epoch = "epoch=%d" % (iteration + 1)
        if dry_run:
            epoch = "dryrun" + " " * (len(epoch) - 6)
        message = "valid   (%s) %s%s" % (epoch, info, marker)
        logging.info(message)
        self.last_run_costs = costs
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return iteration - self.best_iter < self.patience

    def test_step(self, test_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:

                # While recovering from a recent MemoryError, feed each batch in two halves
                if dirty_trick_times > 0:
                    cost_x = self.learn(*[t[:(t.shape[0]/2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learn(*[t[(t.shape[0]/2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick
                        cost_x = self.learn(*[t[:(t.shape[0]/2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learn(*[t[(t.shape[0]/2):] for t in x])
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (self._progress * 100 / train_size, self.last_cost))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.
        """
        # A Dataset object bundles all splits; unpack it into the individual sets
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()
        # Give every controller a handle to this trainer before training starts
        if controllers:
            for controller in controllers:
                controller.bind(self)
        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            # After each epoch, stop as soon as any controller asks to end training
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        if self._report_time:
            timer.report()
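
As the listing shows, run() expects very little from the objects in its controllers argument: a bind(trainer) method called once before training, and optionally an invoke() method polled after every epoch whose truthy return value stops training. A minimal controller satisfying that protocol might look like the sketch below; the class name and the epoch-limit behaviour are illustrative only and not part of deepy.

class MaxEpochController(object):
    """
    Illustrative controller: ends training after a fixed number of epochs.
    It relies only on the bind()/invoke() protocol that NeuralTrainer.run()
    uses above.
    """

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.trainer = None
        self.epochs_seen = 0

    def bind(self, trainer):
        # run() calls bind(self) once for every controller before the loop
        self.trainer = trainer

    def invoke(self):
        # run() polls invoke() after each epoch; a truthy return value
        # makes it break out of the training loop
        self.epochs_seen += 1
        return self.epochs_seen >= self.max_epochs

A concrete trainer (any subclass that implements learn()) could then be driven with something like trainer.run(dataset, controllers=[MaxEpochController(50)]), where dataset is a deepy Dataset instance.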