Completed: Push — master ( 394368...090fba ) by Raphael, created 01:33

RecurrentLayer.compute_step()   B

Complexity:  Conditions 6
Size:        Total Lines 23
Duplication: Lines 0, Ratio 0 %
Importance:  Changes 0

Metric   Value
cc       6
c        0
b        0
f        0
dl       0
loc      23
rs       7.6949
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from abc import ABCMeta, abstractmethod

import numpy as np
import theano.tensor as T

from deepy.core.tensor_conversion import neural_computation
from deepy.utils import XavierGlorotInitializer, OrthogonalInitializer, Scanner
from deepy.core import env
import deepy.tensor as DT
from . import NeuralLayer

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]


class RecurrentLayer(NeuralLayer):
    __metaclass__ = ABCMeta

    def __init__(self, name, state_names, hidden_size=100, input_type="sequence", output_type="sequence",
                 inner_init=None, outer_init=None,
                 gate_activation='sigmoid', activation='tanh',
                 steps=None, backward=False, mask=None,
                 additional_input_dims=None):
        from deepy.core.neural_var import NeuralVariable
        from deepy.tensor.activations import get_activation
        super(RecurrentLayer, self).__init__(name)
        self.state_names = state_names
        self.main_state = state_names[0]
        self.hidden_size = hidden_size
        self._gate_activation = gate_activation
        self._activation = activation
        self.gate_activate = get_activation(self._gate_activation)
        self.activate = get_activation(self._activation)
        self._input_type = input_type
        self._output_type = output_type
        self.inner_init = inner_init if inner_init else OrthogonalInitializer()
        self.outer_init = outer_init if outer_init else XavierGlorotInitializer()
        self._steps = steps
        self._mask = mask.tensor if type(mask) == NeuralVariable else mask
        self._go_backwards = backward
        self.additional_input_dims = additional_input_dims if additional_input_dims else []

        if input_type not in INPUT_TYPES:
            raise Exception("Input type of {} is wrong: {}".format(name, input_type))
        if output_type not in OUTPUT_TYPES:
            raise Exception("Output type of {} is wrong: {}".format(name, output_type))

    @neural_computation
    def step(self, step_inputs):
        new_states = self.compute_new_state(step_inputs)

        # apply the mask at each step: positions where the mask is 0 keep the previous state
        if step_inputs.get("mask"):
            mask = step_inputs["mask"].dimshuffle(0, 'x')
            for state_name in new_states:
                new_states[state_name] = new_states[state_name] * mask + step_inputs[state_name] * (1 - mask)

        return new_states

    @abstractmethod
    def compute_new_state(self, step_inputs):
        """
        :type step_inputs: dict
        :rtype: dict
        """

    @abstractmethod
    def merge_inputs(self, input_var, additional_inputs=None):
        """
        Merge inputs and return a map, which will be passed to the step computation.
        :type input_var: T.var
        :param additional_inputs: list
        :rtype: dict
        """

    @abstractmethod
    def prepare(self):
        pass

    @neural_computation
    def compute_step(self, state, lstm_cell=None, input=None, additional_inputs=None):
        """
        Compute one step in the RNN.
        :return: one variable for RNN and GRU, multiple variables for LSTM
        """
        if not self.initialized:
            input_dim = None
            if input and hasattr(input.tag, 'last_dim'):
                input_dim = input.tag.last_dim
            self.init(input_dim)

        input_map = self.merge_inputs(input, additional_inputs=additional_inputs)
        input_map.update({"state": state, "lstm_cell": lstm_cell})
        output_map = self.compute_new_state(input_map)
        outputs = [output_map.pop("state")]
        outputs += output_map.values()
        for tensor in outputs:
            tensor.tag.last_dim = self.hidden_size
        if len(outputs) == 1:
            return outputs[0]
        else:
            return outputs

    @neural_computation
    def get_initial_states(self, input_var, init_state=None):
        """
        :type input_var: T.var
        :rtype: dict
        """
        initial_states = {}
        for state in self.state_names:
            if state != "state" or not init_state:
                if self._input_type == 'sequence' and input_var.ndim == 2:
                    init_state = T.alloc(np.cast[env.FLOATX](0.), self.hidden_size)
                else:
                    init_state = T.alloc(np.cast[env.FLOATX](0.), input_var.shape[0], self.hidden_size)
            initial_states[state] = init_state
        return initial_states

    @neural_computation
    def get_step_inputs(self, input_var, states=None, mask=None, additional_inputs=None):
        """
        :type input_var: T.var
        :rtype: dict
        """
        step_inputs = {}
        if self._input_type == "sequence":
            if not additional_inputs:
                additional_inputs = []
            if mask:
                step_inputs['mask'] = mask.dimshuffle(1, 0)
            step_inputs.update(self.merge_inputs(input_var, additional_inputs=additional_inputs))
        else:
            # step_inputs["mask"] = mask.dimshuffle((1,0)) if mask else None
            if additional_inputs:
                step_inputs.update(self.merge_inputs(None, additional_inputs=additional_inputs))
        if states:
            for name in self.state_names:
                step_inputs[name] = states[name]

        return step_inputs

    def compute(self, input_var, mask=None, additional_inputs=None, steps=None, backward=False, init_states=None, return_all_states=False):
        from deepy.core.neural_var import NeuralVariable
        if additional_inputs and not self.additional_input_dims:
            self.additional_input_dims = map(lambda var: var.dim(), additional_inputs)
        result_var = super(RecurrentLayer, self).compute(input_var,
                                                   mask=mask, additional_inputs=additional_inputs, steps=steps, backward=backward, init_states=init_states, return_all_states=return_all_states)
        if return_all_states:
            state_map = {}
            for k in result_var.tensor:
                state_map[k] = NeuralVariable(result_var.tensor[k], self.output_dim)
            return state_map
        else:
            return result_var

    def compute_tensor(self, input_var, mask=None, additional_inputs=None, steps=None, backward=False, init_states=None, return_all_states=False):
        # prepare parameters
        backward = backward if backward else self._go_backwards
        steps = steps if steps else self._steps
        mask = mask if mask else self._mask
        if mask and self._input_type == "one":
            raise Exception("Mask only works with sequence input")
        # get initial states
        init_state_map = self.get_initial_states(input_var)
        if init_states:
            for name, val in init_states.items():
                if name in init_state_map:
                    init_state_map[name] = val
        # get input sequence map
        if self._input_type == "sequence":
            # Move middle dimension to left-most position
            # (sequence, batch, value)
            if input_var.ndim == 3:
                input_var = input_var.dimshuffle((1,0,2))

            seq_map = self.get_step_inputs(input_var, mask=mask, additional_inputs=additional_inputs)
        else:
            init_state_map[self.main_state] = input_var
            seq_map = self.get_step_inputs(None, mask=mask, additional_inputs=additional_inputs)
        # scan
        retval_map, _ = Scanner(
            self.step,
            sequences=seq_map,
            outputs_info=init_state_map,
            n_steps=steps,
            go_backwards=backward
        ).compute()
        # return main states
        main_states = retval_map[self.main_state]
        if self._output_type == "one":
            if return_all_states:
                return_map = {}
                for name, val in retval_map.items():
                    return_map[name] = val[-1]
                return return_map
            else:
                return main_states[-1]
        elif self._output_type == "sequence":
            if return_all_states:
                return_map = {}
                for name, val in retval_map.items():
                    return_map[name] = val.dimshuffle((1,0,2))
                return return_map
            else:
                main_states = main_states.dimshuffle((1,0,2)) # ~ batch, time, size
                # if mask: # ~ batch, time
                #     main_states *= mask.dimshuffle((0, 1, 'x'))
                return main_states
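A note on the masked update in RecurrentLayer.step(): each freshly computed state is blended with the carried-over state, so time steps whose mask entry is 0 (padding) keep the previous state unchanged. A tiny numpy illustration of that arithmetic (not part of the library; the array values are invented for the example):

    import numpy as np

    prev_state = np.array([[1.0, 1.0], [2.0, 2.0]])   # (batch=2, hidden=2)
    new_state  = np.array([[5.0, 5.0], [6.0, 6.0]])   # states proposed at this step
    mask       = np.array([[1.0], [0.0]])             # second sample is padding at this step

    blended = new_state * mask + prev_state * (1 - mask)
    # blended -> [[5., 5.], [2., 2.]]: sample 1 advances, sample 2 keeps its old state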

class RNN(RecurrentLayer):

    def __init__(self, hidden_size, **kwargs):
        kwargs["hidden_size"] = hidden_size
        super(RNN, self).__init__("RNN", ["state"], **kwargs)

    @neural_computation
    def compute_new_state(self, step_inputs):
        xh_t, h_tm1 = map(step_inputs.get, ["xh_t", "state"])
        if not xh_t:
            xh_t = 0

        h_t = self.activate(xh_t + T.dot(h_tm1, self.W_h) + self.b_h)

        return {"state": h_t}

    @neural_computation
    def merge_inputs(self, input_var, additional_inputs=None):
        if not additional_inputs:
            additional_inputs = []
        all_inputs = ([input_var] if input_var else []) + additional_inputs
        h_inputs = []
        for x, weights in zip(all_inputs, self.input_weights):
            wi, = weights
            h_inputs.append(T.dot(x, wi))
        merged_inputs = {
            "xh_t": sum(h_inputs)
        }
        return merged_inputs

    def prepare(self):
        self.output_dim = self.hidden_size

        self.W_h = self.create_weight(self.hidden_size, self.hidden_size, "h", initializer=self.outer_init)
        self.b_h = self.create_bias(self.hidden_size, "h")

        self.register_parameters(self.W_h, self.b_h)

        self.input_weights = []
        if self._input_type == "sequence":
            normal_input_dims = [self.input_dim]
        else:
            normal_input_dims = []

        all_input_dims = normal_input_dims + self.additional_input_dims
        for i, input_dim in enumerate(all_input_dims):
            wi = self.create_weight(input_dim, self.hidden_size, "wi_{}".format(i+1), initializer=self.outer_init)
            weights = [wi]
            self.input_weights.append(weights)
            self.register_parameters(*weights)
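
For context, a minimal usage sketch of the RNN layer above. It relies only on the constructor and compute() signatures shown in this file; the import path, the variable x (a NeuralVariable holding a batch of sequences with a known last dimension), and the hidden size are assumptions made for the example. The layer itself realizes the plain Elman recurrence h_t = activate(x_t W_i + h_{t-1} W_h + b_h).

    # Hypothetical usage sketch: `x` is assumed to be a NeuralVariable of shape
    # (batch, time, input_dim), and the import path is a guess at the module location.
    from deepy.layers import RNN

    rnn_seq = RNN(hidden_size=128)                       # defaults: sequence in, sequence out
    hidden_seq = rnn_seq.compute(x)                      # ~ (batch, time, 128)

    rnn_last = RNN(hidden_size=128, output_type="one")   # keep only the final hidden state
    hidden_last = rnn_last.compute(x)                    # ~ (batch, 128)

With output_type="one", compute_tensor() returns main_states[-1] after the scan; with the default "sequence" it dimshuffles the scan output back to (batch, time, hidden_size), matching the comments in the code above.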