Completed: push to master ( 5bbe2a...9d73f5 ) by Raphael, created at 01:33

deepy.layers.RNN.compute_new_state()   Rating: A

Complexity
    Conditions: 2

Size
    Total Lines: 8

Duplication
    Lines: 0
    Ratio: 0 %

Metric  Value
cc      2         (cyclomatic complexity)
dl      0         (duplicated lines)
loc     8         (lines of code)
rs      9.4286
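
For reference, the conditions/size figures can be reproduced with a radon-style
analyzer (assuming that is the metric source; the commands below use the real
radon CLI, and the file path is illustrative):

    $ radon cc -s deepy/layers/recurrent.py
    $ radon raw deepy/layers/recurrent.py

compute_new_state() contains a single if branch, which gives a cyclomatic
complexity of 2 (one plus the number of decision points).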
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from .var import NeuralVar
from deepy.utils import build_activation, FLOATX, XavierGlorotInitializer, OrthogonalInitializer, Scanner
import numpy as np
import theano
import theano.tensor as T
from collections import OrderedDict
from abc import ABCMeta, abstractmethod

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]


class RecurrentLayer(NeuralLayer):
    __metaclass__ = ABCMeta

    def __init__(self, name, state_names, hidden_size=100, input_type="sequence", output_type="sequence",
                 inner_init=None, outer_init=None,
                 gate_activation='sigmoid', activation='tanh',
                 steps=None, backward=False, mask=None,
                 additional_input_dims=None):
        super(RecurrentLayer, self).__init__(name)
        self.state_names = state_names
        self.main_state = state_names[0]
        self.hidden_size = hidden_size
        self._gate_activation = gate_activation
        self._activation = activation
        self.gate_activate = build_activation(self._gate_activation)
        self.activate = build_activation(self._activation)
        self._input_type = input_type
        self._output_type = output_type
        # Recurrent (hidden-to-hidden) weights default to orthogonal
        # initialization, input-to-hidden weights to Xavier-Glorot.
        self.inner_init = inner_init if inner_init else OrthogonalInitializer()
        self.outer_init = outer_init if outer_init else XavierGlorotInitializer()
        self._steps = steps
        self._mask = mask.tensor if isinstance(mask, NeuralVar) else mask
        self._go_backwards = backward
        self.additional_input_dims = additional_input_dims if additional_input_dims else []

        if input_type not in INPUT_TYPES:
            raise Exception("Invalid input type for {}: {}".format(name, input_type))
        if output_type not in OUTPUT_TYPES:
            raise Exception("Invalid output type for {}: {}".format(name, output_type))

    def step(self, step_inputs):
        new_states = self.compute_new_state(step_inputs)

        # When only the final state is returned, apply the mask at every step:
        # padded positions (mask == 0) keep the previous state, so the final
        # state of a shorter sequence is the state at its true last step.
        if self._output_type == "one" and step_inputs.get("mask") is not None:
            mask = step_inputs["mask"].dimshuffle(0, 'x')
            for state_name in new_states:
                new_states[state_name] = new_states[state_name] * mask + step_inputs[state_name] * (1 - mask)

        return new_states

    @abstractmethod
    def compute_new_state(self, step_inputs):
        """
        Compute the new state map for a single time step.
        :type step_inputs: dict
        :rtype: dict
        """

    @abstractmethod
    def merge_inputs(self, input_var, additional_inputs=None):
        """
        Merge inputs and return a map, which will be passed to the step function.
        :type input_var: T.var
        :param additional_inputs: list
        :rtype: dict
        """

    @abstractmethod
    def prepare(self):
        pass

    def get_initial_states(self, input_var):
        """
        Create a zero-filled initial state for each state name.
        :type input_var: T.var
        :rtype: dict
        """
        initial_states = {}
        for state in self.state_names:
            initial_states[state] = T.alloc(np.cast[FLOATX](0.), input_var.shape[0], self.hidden_size)
        return initial_states

    def get_step_inputs(self, input_var, states=None, mask=None, additional_inputs=None):
        """
        Build the map of sequences that the scan step consumes.
        :type input_var: T.var
        :rtype: dict
        """
        step_inputs = {}
        if self._input_type == "sequence":
            if not additional_inputs:
                additional_inputs = []
            if mask is not None:
                # The mask arrives as (batch, time); scan slices along the
                # first axis, so flip it to (time, batch).
                step_inputs["mask"] = mask.dimshuffle((1, 0))
            step_inputs.update(self.merge_inputs(input_var, additional_inputs=additional_inputs))
        if states:
            for name in self.state_names:
                step_inputs[name] = states[name]

        return step_inputs

    def compute(self, input_var, mask=None, additional_inputs=None, steps=None, backward=False):
        if additional_inputs and not self.additional_input_dims:
            self.additional_input_dims = [var.dim() for var in additional_inputs]
        return super(RecurrentLayer, self).compute(input_var, mask=mask, additional_inputs=additional_inputs, steps=steps, backward=backward)

    def output(self, input_var, mask=None, additional_inputs=None, steps=None, backward=False):
        # Prepare parameters, falling back to the values given at construction time
        backward = backward if backward else self._go_backwards
        steps = steps if steps else self._steps
        mask = mask if mask is not None else self._mask
        if mask is not None and self._input_type == "one":
            raise Exception("Mask only works with sequence input")
        # Get initial states
        init_state_map = self.get_initial_states(input_var)
        # Get the map of input sequences
        if self._input_type == "sequence":
            # (batch, time, value) -> (time, batch, value), since scan
            # iterates over the left-most axis
            input_var = input_var.dimshuffle((1, 0, 2))
            seq_map = self.get_step_inputs(input_var, mask=mask, additional_inputs=additional_inputs)
        else:
            # A single input vector serves as the initial main state
            init_state_map[self.main_state] = input_var
            seq_map = self.get_step_inputs(None)
        # Scan over the time axis
        retval_map, _ = Scanner(
            self.step,
            sequences=seq_map,
            outputs_info=init_state_map,
            n_steps=steps,
            go_backwards=backward
        ).compute()
        # Return the main states
        main_states = retval_map[self.main_state]
        if self._output_type == "one":
            # Only the state after the last step
            return main_states[-1]
        elif self._output_type == "sequence":
            if mask is not None:  # mask ~ (batch, time)
                main_states = main_states.dimshuffle((1, 0, 2))  # ~ (batch, time, size)
                main_states *= mask.dimshuffle((0, 1, 'x'))
            # Note: without a mask the sequence output stays time-major,
            # i.e. (time, batch, size).
            return main_states
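
# The Scanner call above presumably wraps theano.scan over the dict maps; a
# minimal sketch of the equivalent raw scan for a single-state recurrence
# (the shapes and the 16-unit hidden size here are illustrative assumptions):
#
#     import numpy as np
#     import theano
#     import theano.tensor as T
#
#     x = T.tensor3("x")  # (time, batch, input_dim), already dimshuffled
#     W_x = theano.shared(np.random.randn(8, 16).astype("float32"))
#     W_h = theano.shared(np.random.randn(16, 16).astype("float32"))
#     b_h = theano.shared(np.zeros(16, dtype="float32"))
#
#     def step(x_t, h_tm1):
#         # One step of the recurrence: h_t = tanh(x_t W_x + h_tm1 W_h + b)
#         return T.tanh(T.dot(x_t, W_x) + T.dot(h_tm1, W_h) + b_h)
#
#     h0 = T.alloc(np.cast["float32"](0.), x.shape[1], 16)
#     h_seq, _ = theano.scan(step, sequences=x, outputs_info=h0)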


class RNN(RecurrentLayer):

    def __init__(self, hidden_size, **kwargs):
        kwargs["hidden_size"] = hidden_size
        super(RNN, self).__init__("RNN", ["state"], **kwargs)

    def compute_new_state(self, step_inputs):
        xh_t, h_tm1 = map(step_inputs.get, ["xh_t", "state"])
        if xh_t is None:
            # No sequence input at this step (input_type == "one")
            xh_t = 0

        # Plain Elman recurrence: h_t = activate(x_t W_i + h_{t-1} W_h + b_h)
        h_t = self.activate(xh_t + T.dot(h_tm1, self.W_h) + self.b_h)

        return {"state": h_t}

    def merge_inputs(self, input_var, additional_inputs=None):
        if not additional_inputs:
            additional_inputs = []
        all_inputs = [input_var] + additional_inputs
        h_inputs = []
        # Project each input (main and additional) into the hidden space
        for x, weights in zip(all_inputs, self.input_weights):
            wi, = weights
            h_inputs.append(T.dot(x, wi))
        merged_inputs = {
            "xh_t": sum(h_inputs)
        }
        return merged_inputs

    def prepare(self):
        self.output_dim = self.hidden_size

        self.W_h = self.create_weight(self.hidden_size, self.hidden_size, "h", initializer=self.inner_init)
        self.b_h = self.create_bias(self.hidden_size, "h")

        self.register_parameters(self.W_h, self.b_h)

        # One projection weight per input (the main input plus any additional inputs)
        self.input_weights = []
        if self._input_type == "sequence":
            all_input_dims = [self.input_dim] + self.additional_input_dims
            for i, input_dim in enumerate(all_input_dims):
                wi = self.create_weight(input_dim, self.hidden_size, "wi_{}".format(i + 1), initializer=self.outer_init)
                weights = [wi]
                self.input_weights.append(weights)
                self.register_parameters(*weights)
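
A hypothetical usage sketch (the surrounding framework normally wires input_dim
before prepare() is called; the sizes and the manual wiring below are
assumptions for illustration, not confirmed deepy API):

    import theano.tensor as T
    from deepy.layers import RNN

    x = T.tensor3("x")  # (batch, time, features)
    rnn = RNN(hidden_size=100, input_type="sequence", output_type="sequence")
    rnn.input_dim = 50  # normally set by the framework that connects layers
    rnn.prepare()       # creates W_h, b_h and one projection weight per input
    y = rnn.output(x)   # time-major (time, batch, 100) when no mask is given

Passing mask=some_matrix (a (batch, time) 0/1 matrix) zeroes out padded
positions and flips the result to batch-major (batch, time, 100).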