deepy.trainers.CustomizeTrainer   A
last analyzed

Complexity

Total Complexity 16

Size/Duplication

Total Lines 57
Duplicated Lines 0 %
Metric Value
dl 0
loc 57
rs 10
wmc 16

3 Methods

Rating   Name   Duplication   Size   Complexity  
A train_func() 0 3 1
F train() 0 39 14
A __init__() 0 8 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import logging as loggers
5
from abc import ABCMeta, abstractmethod
6
7
from deepy.trainers import NeuralTrainer
8
9
10
# Module-level logger for this module. It is bound to the name ``logging``
# because the standard-library module itself is imported as ``loggers``.
logging = loggers.getLogger(__name__)
11
# NOTE(review): not referenced anywhere in the visible source; presumably the
# Theano linker name used elsewhere in the project -- confirm before removing.
THEANO_LINKER = 'cvm'
12
13
class CustomizeTrainer(NeuralTrainer):
    """Abstract trainer whose per-iteration training step is pluggable.

    The training loop in :meth:`train` is fixed; subclasses supply the actual
    work of one iteration by implementing :meth:`train_func`.
    """

    # NOTE: the ``__metaclass__`` attribute is only honoured by Python 2.
    # Under Python 3 it is ignored, so ``@abstractmethod`` below would not
    # prevent instantiation there.
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.

        All arguments are forwarded unchanged to ``NeuralTrainer.__init__``.

        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(CustomizeTrainer, self).__init__(network, config)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train iteration by iteration, testing and validating periodically.

        This is a generator: one value is yielded per completed iteration, so
        the caller must iterate over the result to drive training.

        Each iteration runs, in order:

        1. ``self.test`` when ``iteration`` is a multiple of
           ``self.config.test_frequency`` and ``test_set`` is truthy;
        2. ``self.evaluate`` when ``iteration`` is a multiple of
           ``self.validation_frequency`` and ``valid_set`` is truthy; a falsy
           result ends the loop ("patience elapsed");
        3. ``self.train_func(train_set)``;
        4. a monitor log line when ``iteration`` is a multiple of
           ``self.config.monitor_frequency``;
        5. ``self.network.iteration_callback()`` if the network defines it.

        A ``KeyboardInterrupt`` raised during steps 1-3 is logged and ends the
        loop instead of propagating.

        Once the loop has ended, ``self.set_params(self.best_params)`` is
        called if ``valid_set`` is truthy, and ``self.test(0, test_set)`` is
        called if ``test_set`` is truthy.

        :param train_set: passed through to :meth:`train_func`
        :param valid_set: passed to ``self.evaluate``; falsy disables validation
        :param test_set: passed to ``self.test``; falsy disables testing
        :param train_size: unused here; kept for signature compatibility
        :return: generator yielding the value returned by :meth:`train_func`
                 for each iteration
        """
        iteration = 0
        while True:
            # One handler covers the three interruptible phases; each phase
            # used to carry its own identical ``except KeyboardInterrupt``.
            try:
                if not iteration % self.config.test_frequency and test_set:
                    self.test(iteration, test_set)

                if not iteration % self.validation_frequency and valid_set:
                    if not self.evaluate(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break

                train_message = self.train_func(train_set)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break

            if not iteration % self.config.monitor_frequency:
                # Logged 1-based, while the frequency checks use the 0-based
                # counter.
                logging.info('monitor (iter=%i) %s', iteration + 1, train_message)

            iteration += 1
            if hasattr(self.network, "iteration_callback"):
                self.network.iteration_callback()

            yield train_message

        # Runs however the loop ended (patience elapsed or interrupt).
        if valid_set:
            self.set_params(self.best_params)
        if test_set:
            self.test(0, test_set)

    @abstractmethod
    def train_func(self, train_set):
        """
        Perform one training iteration; must be overridden by subclasses.

        :param train_set: the training data handed to :meth:`train`
        :return: a message describing the iteration; :meth:`train` logs it
                 with ``%s`` and yields it to the caller
        """
        return ""
70