Completed
Push — master (bf2b0c...394368) by Raphael
created 01:28

RecurrentLayer.get_initial_states()   grade: B

Complexity
    Conditions: 6
Size
    Total lines: 15
Duplication
    Lines: 0
    Ratio: 0 %
Importance
    Changes: 1
    Bugs: 0
    Features: 0

Metric   Value
cc       6
c        1
b        0
f        0
dl       0
loc      15
rs       8

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from neural_var import NeuralVariable
from deepy.utils import build_activation, FLOATX, XavierGlorotInitializer, OrthogonalInitializer, Scanner, neural_computation
import numpy as np
import theano.tensor as T
from abc import ABCMeta, abstractmethod

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]


class RecurrentLayer(NeuralLayer):
    __metaclass__ = ABCMeta

    def __init__(self, name, state_names, hidden_size=100, input_type="sequence", output_type="sequence",
                 inner_init=None, outer_init=None,
                 gate_activation='sigmoid', activation='tanh',
                 steps=None, backward=False, mask=None,
                 additional_input_dims=None):
        super(RecurrentLayer, self).__init__(name)
        self.state_names = state_names
        self.main_state = state_names[0]
        self.hidden_size = hidden_size
        self._gate_activation = gate_activation
        self._activation = activation
        self.gate_activate = build_activation(self._gate_activation)
        self.activate = build_activation(self._activation)
        self._input_type = input_type
        self._output_type = output_type
        self.inner_init = inner_init if inner_init else OrthogonalInitializer()
        self.outer_init = outer_init if outer_init else XavierGlorotInitializer()
        self._steps = steps
        # Unwrap NeuralVariable so that only a raw Theano tensor is stored.
        self._mask = mask.tensor if isinstance(mask, NeuralVariable) else mask
        self._go_backwards = backward
        self.additional_input_dims = additional_input_dims if additional_input_dims else []

        if input_type not in INPUT_TYPES:
            raise Exception("Input type of {} is wrong: {}".format(name, input_type))
        if output_type not in OUTPUT_TYPES:
            raise Exception("Output type of {} is wrong: {}".format(name, output_type))

    @neural_computation
    def step(self, step_inputs):
        new_states = self.compute_new_state(step_inputs)

        # Apply the mask at each step: where the mask is 0 (padding), keep
        # the previous state instead of the freshly computed one.
        if step_inputs.get("mask") is not None:
            mask = step_inputs["mask"].dimshuffle(0, 'x')
            for state_name in new_states:
                new_states[state_name] = new_states[state_name] * mask + step_inputs[state_name] * (1 - mask)

        return new_states

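    # Illustration with made-up values (not from the original file): for a
    # mask column of [1, 0] over a batch of two sequences, the first example
    # adopts the freshly computed state while the second, padded at this
    # step, carries its previous state forward:
    #   new * 1 + old * (1 - 1)  ->  new state kept
    #   new * 0 + old * (1 - 0)  ->  old state carried over
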
    @abstractmethod
    def compute_new_state(self, step_inputs):
        """
        :type step_inputs: dict
        :rtype: dict
        """

    @abstractmethod
    def merge_inputs(self, input_var, additional_inputs=None):
        """
        Merge the inputs and return a map, which is then passed to compute_new_state.
        :type input_var: T.var
        :param additional_inputs: list
        :rtype: dict
        """

    @abstractmethod
    def prepare(self):
        pass

    @neural_computation
    def compute_step(self, state, lstm_cell=None, input=None, additional_inputs=None):
        """
        Compute one step in the RNN.
        :return: one variable for RNN and GRU, multiple variables for LSTM
        """
        input_map = self.merge_inputs(input, additional_inputs=additional_inputs)
        input_map.update({"state": state, "lstm_cell": lstm_cell})
        output_map = self.compute_new_state(input_map)
        outputs = [output_map.pop("state")]
        outputs += output_map.values()
        if len(outputs) == 1:
            return outputs[0]
        else:
            return outputs

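    # Assumed usage (a sketch, not from this file): compute_step supports
    # manual, step-by-step unrolling, e.g. in a decoding loop that feeds the
    # returned state back in as the next `state` argument:
    #   state = layer.get_initial_states(x)["state"]
    #   for x_t in input_steps:
    #       state = layer.compute_step(state, input=x_t)
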
    @neural_computation
    def get_initial_states(self, input_var, init_state=None):
        """
        :type input_var: T.var
        :rtype: dict
        """
        initial_states = {}
        for state in self.state_names:
            # Use the provided init_state only for the "state" slot; all
            # other states start from zeros.
            if state != "state" or init_state is None:
                if self._input_type == 'sequence' and input_var.ndim == 2:
                    init_state = T.alloc(np.cast[FLOATX](0.), self.hidden_size)
                else:
                    init_state = T.alloc(np.cast[FLOATX](0.), input_var.shape[0], self.hidden_size)
            initial_states[state] = init_state
        return initial_states

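    # Shape note, following the branch above: for an unbatched 2-D sequence
    # input of shape (time, input_dim), each state starts as zeros of shape
    # (hidden_size,); otherwise input_var.shape[0] is the batch size and each
    # state starts as zeros of shape (batch_size, hidden_size).
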
    @neural_computation
    def get_step_inputs(self, input_var, states=None, mask=None, additional_inputs=None):
        """
        :type input_var: T.var
        :rtype: dict
        """
        step_inputs = {}
        if self._input_type == "sequence":
            if not additional_inputs:
                additional_inputs = []
            if mask is not None:
                # (batch, time) -> (time, batch), matching the scan order
                step_inputs['mask'] = mask.dimshuffle(1, 0)
            step_inputs.update(self.merge_inputs(input_var, additional_inputs=additional_inputs))
        else:
            if additional_inputs:
                step_inputs.update(self.merge_inputs(None, additional_inputs=additional_inputs))
        if states:
            for name in self.state_names:
                step_inputs[name] = states[name]

        return step_inputs

    def compute(self, input_var, mask=None, additional_inputs=None, steps=None, backward=False, init_states=None, return_all_states=False):
        if additional_inputs and not self.additional_input_dims:
            self.additional_input_dims = [var.dim() for var in additional_inputs]
        result_var = super(RecurrentLayer, self).compute(input_var,
                                                         mask=mask, additional_inputs=additional_inputs,
                                                         steps=steps, backward=backward, init_states=init_states,
                                                         return_all_states=return_all_states)
        if return_all_states:
            state_map = {}
            for k in result_var.tensor:
                state_map[k] = NeuralVariable(result_var.tensor[k], result_var.test_tensor[k], self.output_dim)
            return state_map
        else:
            return result_var

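    # Note: with return_all_states=True the caller receives a dict keyed by
    # the layer's state names (just "state" for the plain RNN below), each
    # value wrapped back into a NeuralVariable.
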
    def compute_tensor(self, input_var, mask=None, additional_inputs=None, steps=None, backward=False, init_states=None, return_all_states=False):
        # prepare parameters
        backward = backward if backward else self._go_backwards
        steps = steps if steps else self._steps
        mask = mask if mask is not None else self._mask
        if mask is not None and self._input_type == "one":
            raise Exception("Mask only works with sequence input")
        # get initial states
        init_state_map = self.get_initial_states(input_var)
        if init_states:
            for name, val in init_states.items():
                if name in init_state_map:
                    init_state_map[name] = val
        # get input sequence map
        if self._input_type == "sequence":
            # Move the time dimension to the front:
            # (batch, sequence, value) -> (sequence, batch, value)
            if input_var.ndim == 3:
                input_var = input_var.dimshuffle((1, 0, 2))

            seq_map = self.get_step_inputs(input_var, mask=mask, additional_inputs=additional_inputs)
        else:
            init_state_map[self.main_state] = input_var
            seq_map = self.get_step_inputs(None, mask=mask, additional_inputs=additional_inputs)
        # scan over the sequence
        retval_map, _ = Scanner(
            self.step,
            sequences=seq_map,
            outputs_info=init_state_map,
            n_steps=steps,
            go_backwards=backward
        ).compute()
        # return main states
        main_states = retval_map[self.main_state]
        if self._output_type == "one":
            if return_all_states:
                return_map = {}
                for name, val in retval_map.items():
                    return_map[name] = val[-1]
                return return_map
            else:
                return main_states[-1]
        elif self._output_type == "sequence":
            if return_all_states:
                return_map = {}
                for name, val in retval_map.items():
                    return_map[name] = val.dimshuffle((1, 0, 2))
                return return_map
            else:
                # (time, batch, size) -> (batch, time, size)
                main_states = main_states.dimshuffle((1, 0, 2))
                return main_states


class RNN(RecurrentLayer):

    def __init__(self, hidden_size, **kwargs):
        kwargs["hidden_size"] = hidden_size
        super(RNN, self).__init__("RNN", ["state"], **kwargs)

    @neural_computation
    def compute_new_state(self, step_inputs):
        xh_t, h_tm1 = map(step_inputs.get, ["xh_t", "state"])
        if xh_t is None:
            xh_t = 0

        # Classic Elman update: h_t = f(x_t W_i + h_{t-1} W_h + b_h)
        h_t = self.activate(xh_t + T.dot(h_tm1, self.W_h) + self.b_h)

        return {"state": h_t}

    @neural_computation
    def merge_inputs(self, input_var, additional_inputs=None):
        if not additional_inputs:
            additional_inputs = []
        all_inputs = ([input_var] if input_var is not None else []) + additional_inputs
        h_inputs = []
        for x, weights in zip(all_inputs, self.input_weights):
            wi, = weights
            h_inputs.append(T.dot(x, wi))
        merged_inputs = {
            "xh_t": sum(h_inputs)
        }
        return merged_inputs

    def prepare(self):
        self.output_dim = self.hidden_size

        self.W_h = self.create_weight(self.hidden_size, self.hidden_size, "h", initializer=self.outer_init)
        self.b_h = self.create_bias(self.hidden_size, "h")

        self.register_parameters(self.W_h, self.b_h)

        self.input_weights = []
        if self._input_type == "sequence":
            normal_input_dims = [self.input_dim]
        else:
            normal_input_dims = []

        all_input_dims = normal_input_dims + self.additional_input_dims
        for i, input_dim in enumerate(all_input_dims):
            wi = self.create_weight(input_dim, self.hidden_size, "wi_{}".format(i + 1), initializer=self.outer_init)
            weights = [wi]
            self.input_weights.append(weights)
            self.register_parameters(*weights)
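
Usage sketch (not part of the reviewed file): only the RNN constructor and the
compute() signature are taken from the listing above; how the input
NeuralVariable `x` is built and how the framework calls prepare() are
assumptions about the surrounding deepy API.

    # x: NeuralVariable wrapping a batched input of shape (batch, time, input_dim)
    layer = RNN(hidden_size=100, input_type="sequence", output_type="sequence")
    out_seq = layer.compute(x)            # (batch, time, hidden_size)

    # output_type="one" returns only the final state: (batch, hidden_size)
    last = RNN(hidden_size=100, output_type="one").compute(x)

    # All state sequences, keyed by state name ("state" for the plain RNN)
    states = layer.compute(x, return_all_states=True)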