Completed
Push — master ( f73e69...91b7c0 )
by Raphael
created at 01:35

LSTM (rating B)

Complexity

Total Complexity 38

Size/Duplication

Total Lines 209
Duplicated Lines 0%

Importance

Changes 7
Bugs 0 Features 0
Metric  Value
dl      0
loc     209
rs      8.3999
c       7
b       0
f       0
wmc     38

9 Methods

Rating  Name                       Duplication  Size  Complexity
D       produce_input_sequences()  0            28    8
A       _auto_reset_memories()     0            5     1
B       _setup_params()            0            48    5
A       produce_initial_states()   0            7     2
B       compute_tensor()           0            28    5
F       __init__()                 0            41    10
B       step()                     0            32    5
A       _setup_functions()         0            3     1
A       prepare()                  0            3     1
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from .var import NeuralVariable
from deepy.utils import build_activation, FLOATX
import numpy as np
import theano
import theano.tensor as T
from collections import OrderedDict

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]

class LSTM(NeuralLayer):
    """
    Long short-term memory layer.
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="sequence",
                 inner_activation="sigmoid", outer_activation="tanh",
                 inner_init=None, outer_init=None, steps=None,
                 go_backwards=False,
                 persistent_state=False, batch_size=0,
                 reset_state_for_input=None, forget_bias=1,
                 mask=None,
                 second_input=None, second_input_size=None):
        super(LSTM, self).__init__("lstm")
        self._hidden_size = hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._inner_activation = inner_activation
        self._outer_activation = outer_activation
        self._inner_init = inner_init
        self._outer_init = outer_init
        self._steps = steps
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self.go_backwards = go_backwards
        # Mask: unwrap a NeuralVariable and make the mask time-major.
        # Compare symbolic tensors against None explicitly: `if mask:` on a
        # Theano variable would raise when its truth value is evaluated.
        mask = mask.tensor if isinstance(mask, NeuralVariable) else mask
        self.mask = mask.dimshuffle((1, 0)) if mask is not None else None
        self._sequence_map = OrderedDict()
        # Second input: unwrap a NeuralVariable and infer its size
        if isinstance(second_input, NeuralVariable):
            second_input_size = second_input.dim()
            second_input = second_input.tensor

        self.second_input = second_input
        self.second_input_size = second_input_size
        self.forget_bias = forget_bias
        if input_type not in INPUT_TYPES:
            raise Exception("Invalid input type for LSTM: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise Exception("Invalid output type for LSTM: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise Exception("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise Exception("A mask only works with sequence input")

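    # NOTE (added comment): with the defaults, compute_tensor() maps a
    # (batch, time, input_dim) tensor to (batch, time, hidden_size);
    # output_type="one" instead returns only the final hidden state.
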
    def _auto_reset_memories(self, x, h, m):
        # Zero the recurrent state of samples whose reset-marker input
        # (dimension self.reset_state_for_input) equals 1 at this step.
        reset_matrix = T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
        h = h * reset_matrix
        m = m * reset_matrix
        return h, m

    def step(self, *args):
        # Parse the scan arguments: sequence slices first, then the
        # recurrent outputs (hidden state and memory cell) last
        sequence_map = dict(zip(self._sequence_map.keys(), args[:len(self._sequence_map)]))
        h_tm1, c_tm1 = args[-2:]
        # Reset state
        if self.reset_state_for_input is not None:
            h_tm1, c_tm1 = self._auto_reset_memories(sequence_map["x"], h_tm1, c_tm1)

        if self._input_type == "sequence":
            xi_t, xf_t, xo_t, xc_t = map(sequence_map.get, ["xi", "xf", "xo", "xc"])
        else:
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # Add second input
        if "xi2" in sequence_map:
            xi2, xf2, xo2, xc2 = map(sequence_map.get, ["xi2", "xf2", "xo2", "xc2"])
            xi_t += xi2
            xf_t += xf2
            xo_t += xo2
            xc_t += xc2
        # LSTM core step
        i_t = self._inner_act(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self._inner_act(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self._outer_act(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self._inner_act(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self._outer_act(c_t)
        # Apply mask: carry the previous state through masked-out steps
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            h_t = h_t * mask + h_tm1 * (1 - mask)
            c_t = c_t * mask + c_tm1 * (1 - mask)
        return h_t, c_t

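    # For reference (added comment), the update above in standard LSTM
    # notation, with the default sigmoid inner and tanh outer activations:
    #   i_t = sigmoid(x_i + U_i h_{t-1} + b_i)    input gate
    #   f_t = sigmoid(x_f + U_f h_{t-1} + b_f)    forget gate
    #   c_t = f_t * c_{t-1} + i_t * tanh(x_c + U_c h_{t-1} + b_c)
    #   o_t = sigmoid(x_o + U_o h_{t-1} + b_o)    output gate
    #   h_t = o_t * tanh(c_t)
    # where x_i, x_f, x_c, x_o are the precomputed input projections
    # W_* x_t produced by produce_input_sequences() below.
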
    def produce_input_sequences(self, x, mask=None, second_input=None):
        # Create sequence map
        self._sequence_map.clear()
        if self._input_type == "sequence":
            # Input vars
            xi = T.dot(x, self.W_i)
            xf = T.dot(x, self.W_f)
            xc = T.dot(x, self.W_c)
            xo = T.dot(x, self.W_o)
            self._sequence_map.update([("xi", xi), ("xf", xf), ("xc", xc), ("xo", xo)])
        # Reset state
        if self.reset_state_for_input is not None:
            self._sequence_map["x"] = x
        # Add mask (explicit None checks: these are symbolic tensors)
        if mask is not None:
            self._sequence_map["mask"] = mask
        elif self.mask is not None:
            self._sequence_map["mask"] = self.mask
        # Add second input
        if second_input is None and self.second_input is not None:
            second_input = self.second_input
        if second_input is not None:
            xi2 = T.dot(second_input, self.W_i2)
            xf2 = T.dot(second_input, self.W_f2)
            xc2 = T.dot(second_input, self.W_c2)
            xo2 = T.dot(second_input, self.W_o2)
            self._sequence_map.update([("xi2", xi2), ("xf2", xf2), ("xc2", xc2), ("xo2", xo2)])
        return list(self._sequence_map.values())

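    # NOTE (added comment): projecting the whole input sequence through the
    # W_* matrices here, outside of theano.scan, replaces four matrix
    # products per time step inside the loop with four batched products over
    # the full sequence, which scan then streams as precomputed slices.
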
    def produce_initial_states(self, x):
        if self.persistent_state:
            return self.state_h, self.state_m
        else:
            h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
            m0 = h0
            return h0, m0

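    # NOTE (added comment): in persistent mode the initial states are shared
    # variables, so the final states of one batch (saved in compute_tensor()
    # via register_updates) become the initial states of the next batch.
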
    def compute_tensor(self, x):
        h0, m0 = self.produce_initial_states(x)
        if self._input_type == "sequence":
            # Make the input time-major for scan: (time, batch, value)
            x = x.dimshuffle((1, 0, 2))
            sequences = self.produce_input_sequences(x)
        else:
            # A single input vector serves as the initial hidden state
            h0 = x
            sequences = self.produce_input_sequences(None)

        [hiddens, memories], _ = theano.scan(
            self.step,
            sequences=sequences,
            outputs_info=[h0, m0],
            n_steps=self._steps,
            go_backwards=self.go_backwards
        )

        # Save persistent state
        if self.persistent_state:
            self.register_updates((self.state_h, hiddens[-1]))
            self.register_updates((self.state_m, memories[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            return hiddens.dimshuffle((1, 0, 2))

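    # NOTE (added comment): theano.scan returns time-major outputs of shape
    # (time, batch, hidden), so hiddens[-1] is the state after the last
    # processed step and the final dimshuffle restores the batch-major
    # (batch, time, hidden) layout.
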
    def prepare(self):
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        self._inner_act = build_activation(self._inner_activation)
        self._outer_act = build_activation(self._outer_activation)

    def _setup_params(self):
        self.output_dim = self._hidden_size

        self.W_i = self.create_weight(self.input_dim, self._hidden_size, "wi", initializer=self._outer_init)
        self.U_i = self.create_weight(self._hidden_size, self._hidden_size, "ui", initializer=self._inner_init)
        self.b_i = self.create_bias(self._hidden_size, "i")

        self.W_f = self.create_weight(self.input_dim, self._hidden_size, "wf", initializer=self._outer_init)
        self.U_f = self.create_weight(self._hidden_size, self._hidden_size, "uf", initializer=self._inner_init)
        self.b_f = self.create_bias(self._hidden_size, "f")
        if self.forget_bias > 0:
            # Start with an open forget gate: bias initialized to forget_bias
            self.b_f.set_value(np.full((self._hidden_size,), self.forget_bias, dtype=FLOATX))

        self.W_c = self.create_weight(self.input_dim, self._hidden_size, "wc", initializer=self._outer_init)
        self.U_c = self.create_weight(self._hidden_size, self._hidden_size, "uc", initializer=self._inner_init)
        self.b_c = self.create_bias(self._hidden_size, "c")

        self.W_o = self.create_weight(self.input_dim, self._hidden_size, "wo", initializer=self._outer_init)
        self.U_o = self.create_weight(self._hidden_size, self._hidden_size, "uo", initializer=self._inner_init)
        self.b_o = self.create_bias(self._hidden_size, suffix="o")

        if self._input_type == "sequence":
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
        # Second input
        if self.second_input_size:
            self.W_i2 = self.create_weight(self.second_input_size, self._hidden_size, "wi2", initializer=self._outer_init)
            self.W_f2 = self.create_weight(self.second_input_size, self._hidden_size, "wf2", initializer=self._outer_init)
            self.W_c2 = self.create_weight(self.second_input_size, self._hidden_size, "wc2", initializer=self._outer_init)
            self.W_o2 = self.create_weight(self.second_input_size, self._hidden_size, "wo2", initializer=self._outer_init)
            self.register_parameters(self.W_i2, self.W_f2, self.W_c2, self.W_o2)

        # Create persistent state
        if self.persistent_state:
            self.state_h = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_h")
            self.state_m = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_m")
            self.register_free_parameters(self.state_h, self.state_m)
        else:
            self.state_h = None
            self.state_m = None
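
A minimal usage sketch (not part of the reviewed file): exercising the layer standalone with Theano. The manual input_dim assignment stands in for the wiring deepy's network classes normally perform when stacking layers, so treat it as an assumption rather than the library's actual API; it relies on create_weight() producing Theano shared variables.

import numpy as np
import theano
import theano.tensor as T

# Hypothetical standalone use of the LSTM layer defined above
lstm = LSTM(hidden_size=64)   # sequence in, sequence out by default
lstm.input_dim = 32           # assumption: normally set by the framework
lstm.prepare()                # allocates W_*/U_*/b_* and activation functions

x = T.tensor3("x")            # (batch, time, features)
y = lstm.compute_tensor(x)    # (batch, time, hidden_size)
forward = theano.function([x], y)

out = forward(np.random.randn(8, 20, 32).astype("float32"))
print(out.shape)              # expected: (8, 20, 64)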