Code Duplication    Length = 21-24 lines in 2 locations

deepy/trainers/delayed_trainers.py 1 location

@@ 25-48 (lines=24) @@
22
    Update parameters after N iterations.
23
    """
24
25
    def __init__(self, network, config=None, batch_size=20):
        """
        Create a delayed-batch SGD trainer and compile its learning function.

        :param network: the network to train; its ``inputs``, ``updates`` and
            ``_learning_updates`` attributes are read here.
        :type config: deepy.conf.TrainerConfig
        :param config: trainer configuration, forwarded to the base-class
            constructor; ``learning_rate`` is then read from ``self.config``.
        :param batch_size: stored as ``self.batch_size``; presumably the N in
            "update parameters after N iterations" (its use is not visible
            here — confirm against ``learning_updates``).
        """
        super(DelayedBatchSGDTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate
        self.batch_size = batch_size

        logging.info('compiling %s learning function', self.__class__.__name__)

        # The network's own updates and the trainer's parameter updates are
        # merged into a single update list for one compiled function.
        # NOTE(review): trainers.py reads ``network.training_updates`` rather
        # than the private ``_learning_updates`` — confirm they are equivalent.
        network_updates = list(network.updates) + list(network._learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates

        # Lazy %-args, consistent with the 'compiling ...' message above.
        logging.info("network updates: %s", " ".join(str(x[0]) for x in network_updates))
        logging.info("learning updates: %s", " ".join(str(x[0]) for x in learning_updates))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list,
            allow_input_downcast=True,
            mode=theano.Mode(linker=THEANO_LINKER))
49
50
51
    def learning_updates(self):

deepy/trainers/trainers.py 1 location

@@ 72-92 (lines=21) @@
69
        updates, free_parameters = optimize_updates(params, gradients, self.config)
70
        self.network.free_parameters.extend(free_parameters)
71
        logging.info("Added %d free parameters for optimization" % len(free_parameters))
72
        return updates
73
74
    def learning_function(self):
        """
        Compile and return the Theano function that performs one training step.

        The compiled function:

        - takes ``self.network.input_variables`` followed by
          ``self.network.target_variables`` as inputs;
        - returns the values of ``self.training_variables``, wrapped in
          ``theano.Out(..., borrow=True)``;
        - applies the network's ``updates`` and ``training_updates`` together
          with this trainer's ``_learning_updates()``.

        :return: the compiled ``theano.function``.
        """
        network_updates = list(self.network.updates) + list(self.network.training_updates)
        learning_updates = list(self._learning_updates())
        update_list = network_updates + learning_updates

        # Lazy %-args instead of pre-formatting the message with ``%``.
        logging.info("network updates: %s", " ".join(str(x[0]) for x in network_updates))
        logging.info("learning updates: %s", " ".join(str(x[0]) for x in learning_updates))

        variables = self.network.input_variables + self.network.target_variables
        # Build a real list: on Python 3 ``map`` returns a lazy iterator, which
        # theano.function does not accept as a list of outputs.
        outputs = [theano.Out(v, borrow=True) for v in self.training_variables]
        return theano.function(
            variables,
            outputs,
            updates=update_list,
            allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=None)
95