| @@ 25-48 (lines=24) @@ | ||
| 22 | Update parameters after N iterations. |
|
| 23 | """ |
|
| 24 | ||
def __init__(self, network, config=None, batch_size=20):
    """
    Build the delayed-batch SGD trainer and compile its Theano learning function.

    :param network: the network being trained; its ``inputs``, ``updates`` and
        ``_learning_updates`` attributes are read here.
    :param config: trainer configuration, forwarded unchanged to the base-class
        constructor; ``self.config.learning_rate`` is read afterwards.
    :type config: deepy.conf.TrainerConfig
    :param batch_size: stored on ``self.batch_size`` (default 20).
    """
    super(DelayedBatchSGDTrainer, self).__init__(network, config)

    self.learning_rate = self.config.learning_rate
    self.batch_size = batch_size

    logging.info('compiling %s learning function', self.__class__.__name__)

    def _targets(pairs):
        # Each update is a (shared_variable, new_value) pair; list the variables.
        return " ".join(str(pair[0]) for pair in pairs)

    # Order matters: the network's own updates are collected before the
    # optimizer updates, because learning_updates() has side effects
    # (it registers free parameters on the network and logs).
    net_updates = list(network.updates)
    net_updates.extend(network._learning_updates)
    opt_updates = list(self.learning_updates())
    all_updates = net_updates + opt_updates

    logging.info("network updates: %s" % _targets(net_updates))
    logging.info("learning updates: %s" % _targets(opt_updates))

    self.learning_func = theano.function(
        network.inputs,
        self.training_variables,
        updates=all_updates,
        allow_input_downcast=True,
        mode=theano.Mode(linker=THEANO_LINKER),
    )
|
| 49 | ||
| 50 | ||
| 51 | def learning_updates(self): |
|
| @@ 72-92 (lines=21) @@ | ||
| 69 | updates, free_parameters = optimize_updates(params, gradients, self.config) |
|
| 70 | self.network.free_parameters.extend(free_parameters) |
|
| 71 | logging.info("Added %d free parameters for optimization" % len(free_parameters)) |
|
| 72 | return updates |
|
| 73 | ||
def learning_function(self):
    """
    Compile and return the Theano training function.

    The compiled function takes the network's input variables followed by its
    target variables. It returns ``self.training_variables``, each wrapped in
    ``theano.Out(..., borrow=True)``. On every call it applies the network's
    ``updates`` and ``training_updates`` followed by the optimizer updates
    produced by ``self._learning_updates()``.

    :return: the compiled ``theano.function``.
    """
    network_updates = list(self.network.updates) + list(self.network.training_updates)
    learning_updates = list(self._learning_updates())
    update_list = network_updates + learning_updates

    # Each update is a (shared_variable, new_value) pair; log the variables.
    logging.info("network updates: %s", " ".join(str(x[0]) for x in network_updates))
    logging.info("learning updates: %s", " ".join(str(x[0]) for x in learning_updates))

    variables = self.network.input_variables + self.network.target_variables
    # Materialize the outputs as a list. Under Python 3, ``map`` yields a
    # one-shot iterator, which theano.function does not accept as a list of
    # outputs. The list form behaves identically under Python 2.
    outputs = [theano.Out(v, borrow=True) for v in self.training_variables]
    return theano.function(
        variables,
        outputs,
        updates=update_list,
        allow_input_downcast=True,
        mode=self.config.get("theano_mode", None),
        givens=None,
    )
|
| 95 | ||