Completed
Push — master (f73e69...91b7c0) by Raphael, created 01:35

NeuralTrainer   F

Complexity

Total Complexity 79

Size/Duplication

Total Lines 308
Duplicated Lines 0%

Importance

Changes 12
Bugs 3 Features 0
Metric Value
dl (duplicated lines) 0
loc (lines of code) 308
rs 2.0547
c (changes) 12
b (bugs) 3
f (features) 0
wmc (weighted methods per class) 79

19 Methods

Rating   Name   Duplication   Size   Complexity  
A _run_test() 0 9 2
A load_params() 0 10 2
A set_params() 0 6 4
B _run_valid() 0 29 6
A add_iter_callback() 0 6 1
A valid_step() 0 6 2
A copy_params() 0 4 3
C run() 0 22 8
F train_step() 0 45 14
F train() 0 43 14
A save_params() 0 3 1
A learn() 0 7 1
B _add_regularization() 0 12 7
A test_step() 0 6 2
A _compile_evaluation_func() 0 8 2
A skip() 0 6 1
A _run_train() 0 12 3
A _setup_costs() 0 15 3
B __init__() 0 35 3
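The Total Complexity of 79 reported above is the sum of these 19 per-method complexities (2 + 2 + 4 + 6 + 1 + 2 + 3 + 8 + 14 + 14 + 1 + 1 + 7 + 2 + 2 + 1 + 3 + 3 + 3 = 79). train_step() and train() alone account for 28 of that total and are the only F-rated methods.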

How to fix: Complexity

Complex Class

Complex classes like NeuralTrainer often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster to apply.
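In NeuralTrainer, for example, the best_cost / best_iter / best_params fields together with the copy_params() / set_params() bookkeeping form one such cohesive component. The sketch below shows what Extract Class could look like here; the TrainingCheckpoint name and its methods are hypothetical and not part of deepy, they only mirror the logic already present in the listing further down.

# Hypothetical Extract Class sketch: checkpoint/best-parameter bookkeeping
# pulled out of NeuralTrainer. Names are illustrative, not deepy API.
class TrainingCheckpoint(object):
    """Tracks the best cost seen so far and snapshots of the network parameters."""

    def __init__(self, network, min_improvement):
        self.network = network
        self.min_improvement = min_improvement
        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()

    def copy_params(self):
        # Snapshot both regular and free parameters of the network.
        return ([p.get_value().copy() for p in self.network.parameters],
                [p.get_value().copy() for p in self.network.free_parameters])

    def restore(self, params):
        # Roll the network back to a previously taken snapshot.
        regular, free = params
        for param, value in zip(self.network.parameters, regular):
            param.set_value(value)
        for param, value in zip(self.network.free_parameters, free):
            param.set_value(value)

    def update(self, iteration, cost):
        # Record a new best if the relative improvement is large enough;
        # return True so the trainer can mark the iteration as improved.
        if self.best_cost - cost > self.best_cost * self.min_improvement:
            self.best_cost = cost
            self.best_iter = iteration
            self.best_params = self.copy_params()
            return True
        return False

_run_valid(), train() and save_params() could then delegate their rollback and best-parameter handling to this object, removing several fields and branches from the class and lowering its WMC. The full source as analyzed follows.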

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import numpy as np
import theano

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.utils import Timer

from abc import ABCMeta, abstractmethod

from logging import getLogger
logging = getLogger(__name__)

class NeuralTrainer(object):
    """
    A base class for all trainers.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(NeuralTrainer, self).__init__()

        self.config = None
        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        # Model and network all refer to the computational graph
        self.model = self.network = network

        self.network.prepare_training()
        self._setup_costs()

        self.evaluation_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()
        self._skip_batches = 0
        self._progress = 0
        self.last_cost = 0

    def _compile_evaluation_func(self):
        if not self.evaluation_func:
            logging.info("compile evaluation function")
            self.evaluation_func = theano.function(
                self.network.input_variables + self.network.target_variables,
                self.evaluation_variables,
                updates=self.network.updates,
                allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def skip(self, n_batches):
        """
        Skip N batches in the training.
        """
        logging.info("Skip %d batches" % n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0:
            self.skip(self.network.train_logger.progress())

    def copy_params(self):
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (receives the trainer as its argument).
        :return:
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and return costs.
        """
        iteration = 0
        while True:
            # Test
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self._run_test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not iteration % self.validation_frequency and valid_set:
                try:

                    if not self._run_valid(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(iteration, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                iteration += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    @abstractmethod
    def learn(self, *variables):
        """
        Update the parameters and return the cost with given data points.
        :param variables:
        :return:
        """

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (iter=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.
        """
        costs = self.train_step(train_set, train_size)

        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (iter=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        iter_str = "iter=%d" % (iteration + 1)
        if dry_run:
            iter_str = "dryrun" + " " * (len(iter_str) - 6)
        message = "valid   (%s) %s%s" % (iter_str, info, marker)
        logging.info(message)
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return iteration - self.best_iter < self.patience

    def test_step(self, test_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        self._compile_evaluation_func()
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:

                if dirty_trick_times > 0:
                    cost_x = self.learn(*[t[:(t.shape[0]/2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learn(*[t[(t.shape[0]/2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learn(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick
                        cost_x = self.learn(*[t[:(t.shape[0]/2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learn(*[t[(t.shape[0]/2):] for t in x])
                cost_matrix.append(cost_x)
                self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (self._progress * 100 / train_size, self.last_cost))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()

        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        timer.report()
        return
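
NeuralTrainer itself is abstract (learn() has no body), so it is always used through a subclass that compiles an update function. The following is a minimal, hypothetical sketch of the learn() contract that train_step() relies on; PlainSGDTrainer and its plain-SGD update rule are illustrative only and not part of deepy, which ships its own trainer classes.

# Hypothetical usage sketch, assuming NeuralTrainer from the listing above
# is importable and `network` is a prepared deepy NeuralNetwork.
import theano
import theano.tensor as T


class PlainSGDTrainer(NeuralTrainer):
    """Illustrative trainer: plain SGD on the regularized cost set up by the base class."""

    def __init__(self, network, config=None, learning_rate=0.01):
        super(PlainSGDTrainer, self).__init__(network, config)
        # One (param, new_value) pair per parameter; real deepy trainers
        # would also merge self.network.updates here.
        updates = [(p, p - learning_rate * T.grad(self.cost, p))
                   for p in self.network.parameters]
        self._learn_func = theano.function(
            self.network.input_variables + self.network.target_variables,
            self.training_variables,
            updates=updates,
            allow_input_downcast=True)

    def learn(self, *variables):
        # train_step() calls learn() once per batch and expects the list of
        # monitored values matching self.training_names, with the cost 'J' first.
        return self._learn_func(*variables)


# trainer = PlainSGDTrainer(network)
# trainer.run(dataset)  # run() accepts a deepy Dataset or an iterable of batches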