CustomizeTrainer   A
last analyzed

Complexity

Total Complexity 16

Size/Duplication

Total Lines 60
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
dl 0
loc 60
rs 10
c 0
b 0
f 0
wmc 16

3 Methods

Rating   Name   Duplication   Size   Complexity  
F train() 0 39 14
A train_func() 0 3 1
A __init__() 0 8 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import logging as loggers
5
from abc import ABCMeta, abstractmethod
6
from deepy.trainers.base import NeuralTrainer
7
8
# Module-level logger. It is deliberately bound to the name `logging`, which
# shadows the stdlib module name (the stdlib module is imported as `loggers`).
logging = loggers.getLogger(__name__)
9
# Theano linker name; it is not referenced anywhere in the code visible here.
THEANO_LINKER = 'cvm'
10
11
class CustomizeTrainer(NeuralTrainer):
    """
    DEPRECATED !!!

    Trainer skeleton: subclasses provide the per-step work via
    ``train_func``. This class only drives the loop, the periodic test and
    validation passes, and the monitor logging.

    NOTE(review): ``__metaclass__`` is only honoured by Python 2. Under
    Python 3 the attribute is ignored, so ``train_func`` is not enforced as
    abstract there.
    """
    __metaclass__ = ABCMeta

    def __init__(self, network, config=None):
        """
        Build the trainer. All setup is delegated to ``NeuralTrainer``.

        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        """
        super(CustomizeTrainer, self).__init__(network, config)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Generator that runs the training loop and yields one message per step.

        Every step does the following, in order:

        * run ``self.test`` when the step is a multiple of
          ``config.test_frequency`` and ``test_set`` is truthy;
        * run ``self.evaluate`` when the step is a multiple of
          ``self.validation_frequency`` and ``valid_set`` is truthy, stopping
          if it returns a falsy value;
        * call ``train_func`` once;
        * log a monitor line (1-based step number) when the step is a
          multiple of ``config.monitor_frequency``;
        * call ``self.network.iteration_callback()`` if the network has it;
        * yield the value returned by ``train_func``.

        A KeyboardInterrupt raised inside the test, validation or training
        call ends the loop. Once the loop ends, ``self.best_params`` is handed
        to ``set_params`` (when ``valid_set`` is truthy), and a final
        ``self.test(0, test_set)`` runs (when ``test_set`` is truthy).

        ``train_size`` is accepted for interface compatibility but unused.
        """
        step = 0
        while True:
            # Periodic test pass; Ctrl-C here ends training.
            if not step % self.config.test_frequency and test_set:
                try:
                    self.test(step, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break

            # Periodic validation; a falsy result from evaluate() stops training.
            if not step % self.validation_frequency and valid_set:
                try:
                    improved = self.evaluate(step, valid_set)
                    if not improved:
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break

            # One training step supplied by the subclass.
            try:
                message = self.train_func(train_set)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break

            if not step % self.config.monitor_frequency:
                logging.info('monitor (iter=%i) %s', step + 1, message)

            step += 1
            if hasattr(self.network, "iteration_callback"):
                self.network.iteration_callback()

            yield message

        # Post-training: restore the tracked parameters, then a final test pass.
        if valid_set:
            self.set_params(self.best_params)
        if test_set:
            self.test(0, test_set)

    @abstractmethod
    def train_func(self, train_set):
        """
        Run one training step on ``train_set``.

        Whatever this returns is logged on monitor steps and yielded by
        ``train``. Subclasses are expected to override it; this default
        returns an empty string.
        """
        return ""
71