deepy.trainers.CustomizeTrainer.train() (rating: F)

Complexity: 14 conditions
Size: 39 total lines
Duplication: 0 lines (0 %)

Metric   Value    Meaning
cc       14       cyclomatic complexity
dl       0        duplicated lines
loc      39       lines of code
rs       2.7582
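The cc value counts the independent paths through train(). One way to sanity-check it is to count decision points with Python's standard ast module. The sketch below assumes McCabe-style counting rules: start at 1, add one per if, conditional expression, for, while and except clause, and add one per extra operand of and/or. These rules are an assumption, not necessarily the exact ones the analyzer uses, but applied to the train() method from the listing they give the same 14.

```python
import ast
import textwrap

# train() as it appears in the listing (docstring omitted; it adds no branches).
TRAIN_SOURCE = '''
def train(self, train_set, valid_set=None, test_set=None, train_size=None):
    iteration = 0
    while True:
        if not iteration % self.config.test_frequency and test_set:
            try:
                self.test(iteration, test_set)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break

        if not iteration % self.validation_frequency and valid_set:
            try:
                if not self.evaluate(iteration, valid_set):
                    logging.info('patience elapsed, bailing out')
                    break
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break

        train_message = ""
        try:
            train_message = self.train_func(train_set)
        except KeyboardInterrupt:
            logging.info('interrupted!')
            break
        if not iteration % self.config.monitor_frequency:
            logging.info('monitor (iter=%i) %s', iteration + 1, train_message)

        iteration += 1
        if hasattr(self.network, "iteration_callback"):
            self.network.iteration_callback()

        yield train_message

    if valid_set:
        self.set_params(self.best_params)
    if test_set:
        self.test(0, test_set)
'''

# Node types that each add one independent path (assumed McCabe-style rules).
BRANCH_NODES = (ast.If, ast.IfExp, ast.For, ast.While, ast.ExceptHandler)


def cyclomatic_complexity(source):
    """Return 1 + the number of decision points found in `source`."""
    tree = ast.parse(textwrap.dedent(source))
    complexity = 1  # straight-line code has exactly one path
    for node in ast.walk(tree):
        if isinstance(node, BRANCH_NODES):
            complexity += 1
        elif isinstance(node, ast.BoolOp):
            # `a and b` adds one decision, `a and b and c` adds two, etc.
            complexity += len(node.values) - 1
    return complexity


print(cyclomatic_complexity(TRAIN_SOURCE))  # 14
```

By these rules, train() has seven if statements, one while loop, three except clauses and two and operators, which gives 1 + 13 = 14.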

How to fix: Complexity

Complex methods like deepy.trainers.CustomizeTrainer.train() often do a lot of different things. To break such a method down, we need to identify cohesive pieces of logic within it. A common approach is to look for statements that serve the same concern. In train(), the periodic test run, the validation and early-stopping check, the monitoring log, and the final restore-and-test step are such groups. The three identical KeyboardInterrupt handlers also suggest that interruption could be handled in one place.

Once you have determined which statements belong together, you can apply the Extract Method refactoring and move each group into a helper method. If a group turns out to need its own state, Extract Class is also a candidate.

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers
from abc import ABCMeta, abstractmethod

from deepy.trainers import NeuralTrainer


logging = loggers.getLogger(__name__)
THEANO_LINKER = 'cvm'

class CustomizeTrainer(NeuralTrainer):
    '''This is a base class for all trainers.'''
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(CustomizeTrainer, self).__init__(network, config)


    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        '''We train over mini-batches and evaluate periodically.'''
        iteration = 0
        while True:
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self.test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break

            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self.evaluate(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break

            train_message = ""
            try:
                train_message = self.train_func(train_set)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            if not iteration % self.config.monitor_frequency:
                logging.info('monitor (iter=%i) %s', iteration + 1, train_message)

            iteration += 1
            if hasattr(self.network, "iteration_callback"):
                self.network.iteration_callback()

            yield train_message

        if valid_set:
            self.set_params(self.best_params)
        if test_set:
            self.test(0, test_set)

    @abstractmethod
    def train_func(self, train_set):
        return ""
```
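In the spirit of the advice above, one way to reduce the complexity of train() is Extract Method: move the periodic checks, the per-iteration bookkeeping and the end-of-training cleanup into helpers, and let a single try block handle KeyboardInterrupt. The following is a minimal sketch, not deepy's code. RefactoredTrainerMixin, _StopTraining, the helper method names and DemoTrainer are all hypothetical. The mixin assumes that its host class supplies the same attributes the original relies on: config.test_frequency, config.monitor_frequency, validation_frequency, network, best_params, test(), evaluate(), set_params() and train_func().

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)


class _StopTraining(Exception):
    """Hypothetical internal signal: validation patience has run out."""


class RefactoredTrainerMixin:
    """Sketch of CustomizeTrainer.train() after Extract Method (hypothetical)."""

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        # train_size is unused, kept only to match the original signature.
        iteration = 0
        while True:
            try:
                self._run_periodic_checks(iteration, valid_set, test_set)
                train_message = self.train_func(train_set)
            except KeyboardInterrupt:
                logger.info('interrupted!')
                break
            except _StopTraining:
                break
            self._after_iteration(iteration, train_message)
            iteration += 1
            yield train_message
        self._finish(valid_set, test_set)

    def _run_periodic_checks(self, iteration, valid_set, test_set):
        # Same conditions, in the same order, as the original loop.
        if not iteration % self.config.test_frequency and test_set:
            self.test(iteration, test_set)
        if not iteration % self.validation_frequency and valid_set:
            if not self.evaluate(iteration, valid_set):
                logger.info('patience elapsed, bailing out')
                raise _StopTraining()

    def _after_iteration(self, iteration, train_message):
        # Periodic progress log plus the optional per-iteration network hook.
        if not iteration % self.config.monitor_frequency:
            logger.info('monitor (iter=%i) %s', iteration + 1, train_message)
        if hasattr(self.network, "iteration_callback"):
            self.network.iteration_callback()

    def _finish(self, valid_set, test_set):
        # Restore the best parameters and run a final test, as before.
        if valid_set:
            self.set_params(self.best_params)
        if test_set:
            self.test(0, test_set)


class DemoTrainer(RefactoredTrainerMixin):
    """Hypothetical stand-in providing just the attributes the mixin needs."""

    def __init__(self, patience=3):
        self.config = SimpleNamespace(test_frequency=2, monitor_frequency=1)
        self.validation_frequency = 1
        self.network = SimpleNamespace()  # no iteration_callback attribute
        self.best_params = "best"
        self.patience = patience
        self.steps = 0
        self.calls = []  # records test/evaluate/set_params calls in order

    def test(self, iteration, test_set):
        self.calls.append(("test", iteration))

    def evaluate(self, iteration, valid_set):
        self.calls.append(("evaluate", iteration))
        return iteration < self.patience  # False once patience is exhausted

    def set_params(self, params):
        self.calls.append(("set_params", params))

    def train_func(self, train_set):
        self.steps += 1
        return "step %d" % self.steps


trainer = DemoTrainer()
messages = list(trainer.train(train_set=[0], valid_set=[0], test_set=[0]))
```

train() still yields one message per training step. When patience runs out, it still restores best_params and runs a final test. Interrupts raised inside test(), evaluate() or train_func() still end the loop, but a single except clause now handles them instead of three. With the demo settings, validation fails at iteration 3, so three messages are yielded before the cleanup runs.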