CustomizeTrainer.train() (rating: F)

Complexity

Conditions 14

Size

Total Lines 39

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      14
dl      0
loc     39
rs      2.7581
c       0
b       0
f       0


Complexity

Complex methods like CustomizeTrainer.train() often do a lot of different things. To break such a method down, we need to identify the cohesive blocks within it. In train(), each guarded step of the loop (testing, validation, and the training step) is such a block, and each one repeats its own KeyboardInterrupt handling.

Each cohesive block can be moved into its own helper with the Extract Method refactoring. The same idea applies at the class level: look for fields/methods that share the same prefixes or suffixes, and once you have determined the ones that belong together, apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
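For a single method such as train(), the Extract Method refactoring might look like the sketch below. It is a hypothetical restructuring, not deepy's code: `StubConfig`, `RefactoredTrainer`, the helper names (`_due`, `_maybe_test`, `_maybe_validate`, `_train_step`) and the stub `test`/`evaluate`/`set_params`/`train_func` methods are assumptions that stand in for `NeuralTrainer` so the example runs on its own. It also assumes the validation frequency lives on the config, and it omits the network's `iteration_callback` hook and the unused `train_size` argument for brevity.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class StubConfig:
    # Hypothetical stand-in for deepy.conf.TrainerConfig.
    test_frequency: int = 3
    validation_frequency: int = 2
    monitor_frequency: int = 1


class RefactoredTrainer:
    """Hypothetical Extract Method version of CustomizeTrainer.train()."""

    def __init__(self, config, patience=3):
        self.config = config
        self.patience = patience  # assumed: validations allowed before stopping
        self.best_params = None
        self.log = []

    # --- stubs standing in for NeuralTrainer behaviour (assumptions) ---
    def test(self, iteration, test_set):
        self.log.append(("test", iteration))

    def evaluate(self, iteration, valid_set):
        self.log.append(("valid", iteration))
        self.patience -= 1
        return self.patience > 0  # False means "stop training"

    def set_params(self, params):
        self.log.append(("restore", params))

    def train_func(self, train_set):
        return "cost=0.0"

    # --- extracted helpers: each one makes a single decision ---
    def _due(self, iteration, frequency):
        return iteration % frequency == 0

    def _maybe_test(self, iteration, test_set):
        if test_set and self._due(iteration, self.config.test_frequency):
            self.test(iteration, test_set)

    def _maybe_validate(self, iteration, valid_set):
        """Return False when validation says training should stop."""
        if valid_set and self._due(iteration, self.config.validation_frequency):
            if not self.evaluate(iteration, valid_set):
                logger.info("patience elapsed, bailing out")
                return False
        return True

    def _train_step(self, iteration, train_set):
        message = self.train_func(train_set)
        if self._due(iteration, self.config.monitor_frequency):
            logger.info("monitor (iter=%i) %s", iteration + 1, message)
        return message

    def train(self, train_set, valid_set=None, test_set=None):
        iteration = 0
        try:
            while True:
                self._maybe_test(iteration, test_set)
                if not self._maybe_validate(iteration, valid_set):
                    break
                message = self._train_step(iteration, train_set)
                iteration += 1
                yield message
        except KeyboardInterrupt:
            # One handler replaces the three copies in the original loop.
            logger.info("interrupted!")
        # Post-loop cleanup, as in the original.
        if valid_set:
            self.set_params(self.best_params)
        if test_set:
            self.test(0, test_set)
```

The loop body now reads as a sequence of named steps, and the single `try` also covers the logging and bookkeeping between them, which is a small behavioural widening compared with the original's three narrow handlers.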

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers
from abc import ABCMeta, abstractmethod
from deepy.trainers.base import NeuralTrainer

logging = loggers.getLogger(__name__)
THEANO_LINKER = 'cvm'


class CustomizeTrainer(NeuralTrainer):
    """
    DEPRECATED !!!
    A customized trainer.
    """
    # Python 2 spelling of the metaclass; Python 3 ignores this attribute.
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(CustomizeTrainer, self).__init__(network, config)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        '''We train over mini-batches and evaluate periodically.'''
        # train() is a generator: each loop pass trains one step and yields
        # its message. train_size is accepted but unused.
        iteration = 0
        while True:
            # Periodic test pass.
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self.test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break

            # Periodic validation; evaluate() returning False stops training.
            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self.evaluate(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break

            # One training step, delegated to the subclass.
            train_message = ""
            try:
                train_message = self.train_func(train_set)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            if not iteration % self.config.monitor_frequency:
                logging.info('monitor (iter=%i) %s', iteration + 1, train_message)

            iteration += 1
            if hasattr(self.network, "iteration_callback"):
                self.network.iteration_callback()

            yield train_message

        # Runs only when the loop breaks (patience or interrupt),
        # not when the caller simply stops iterating.
        if valid_set:
            self.set_params(self.best_params)
        if test_set:
            self.test(0, test_set)

    @abstractmethod
    def train_func(self, train_set):
        return ""
```
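Two properties of this listing are easy to miss. Because train() contains `yield`, calling it only creates a generator: no batch is trained until the caller iterates, and the final `set_params`/`test` calls run only if the loop breaks on its own, not if the caller stops iterating early. Separately, `__metaclass__ = ABCMeta` is Python 2 syntax, so under Python 3 that attribute alone does not make the class abstract. The self-contained sketch below illustrates both; `LazyTrainer`, `Py2StyleBase` and `Py3StyleBase` are hypothetical stand-ins, not part of deepy, and the code assumes Python 3.

```python
from abc import ABCMeta, abstractmethod


class Py2StyleBase(object):
    __metaclass__ = ABCMeta  # Python 2 spelling: just a plain attribute in Python 3

    @abstractmethod
    def train_func(self, train_set):
        return ""


class Py3StyleBase(metaclass=ABCMeta):  # Python 3 spelling: abstract methods enforced
    @abstractmethod
    def train_func(self, train_set):
        return ""


class LazyTrainer:
    """Hypothetical stand-in reproducing the generator shape of train()."""

    def __init__(self, max_iters):
        self.max_iters = max_iters  # plays the role of "patience elapsed"
        self.calls = 0
        self.finalized = False

    def train(self, train_set):
        iteration = 0
        while True:
            if iteration >= self.max_iters:
                break  # loop ends on its own, so the cleanup below runs
            self.calls += 1
            iteration += 1
            yield "iter %d" % iteration
        self.finalized = True  # like set_params(best_params) and the final test
```

In practice this means a caller should exhaust the generator (for example with a `for` loop) when it relies on the best parameters being restored, and a Python 3 port would need `metaclass=ABCMeta` (or inheriting from `abc.ABC`) for `train_func` to be enforced as abstract.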