Completed
Push — master ( 394368...090fba )
by Raphael
01:33
created

LSTM.__init__()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 6
Bugs 0 Features 0
Metric Value
cc 1
c 6
b 0
f 0
dl 0
loc 4
rs 10
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import numpy as np
5
import theano.tensor as T
6
from recurrent import RecurrentLayer
7
from deepy.core.tensor_conversion import neural_computation
8
from deepy.core.env import env
9
10
11
class LSTM(RecurrentLayer):
    """
    Long short-term memory layer.

    One recurrent step (see ``compute_new_state``) computes::

        i_t = gate_activate(x_i + h_{t-1} . U_i + b_i)
        f_t = gate_activate(x_f + h_{t-1} . U_f + b_f)
        c_t = f_t * c_{t-1} + i_t * activate(x_c + h_{t-1} . U_c + b_c)
        o_t = gate_activate(x_o + h_{t-1} . U_o + b_o)
        h_t = o_t * activate(c_t)

    where ``x_g`` is the input projection for gate ``g`` produced by
    ``merge_inputs``. ``gate_activate`` and ``activate`` come from
    ``RecurrentLayer`` (NOTE(review): presumably sigmoid / tanh by default --
    confirm in the base class).
    """

    def __init__(self, hidden_size, init_forget_bias=1, **kwargs):
        """
        :param hidden_size: size of the hidden state and of the memory cell
        :param init_forget_bias: value every element of the forget-gate bias
            ``b_f`` is set to in ``prepare``
        :param kwargs: forwarded to ``RecurrentLayer.__init__``
        """
        kwargs["hidden_size"] = hidden_size
        # "state" and "lstm_cell" are the keys of the recurrent values
        # returned by compute_new_state and read back on the next step.
        super(LSTM, self).__init__("LSTM", ["state", "lstm_cell"], **kwargs)
        # Honor the caller's value (it used to be hard-coded to 1).
        self._init_forget_bias = init_forget_bias

    @neural_computation
    def compute_new_state(self, step_inputs):
        """
        Run a single LSTM step.

        :param step_inputs: dict holding this step's input projections under
            "xi", "xf", "xo", "xc" (as produced by ``merge_inputs``; absent when
            the layer has no inputs) and the previous hidden state and memory
            cell under "state" and "lstm_cell"
        :return: dict with the new hidden state ("state") and memory cell
            ("lstm_cell")
        """
        # Look each projection up by its own key. The previous map() call
        # paired xo_t with "xc" and xc_t with "xo", swapping the roles of W_o
        # and W_c; parameters trained before this fix have those two matrices
        # exchanged.
        xi_t = step_inputs.get("xi")
        xf_t = step_inputs.get("xf")
        xo_t = step_inputs.get("xo")
        xc_t = step_inputs.get("xc")
        h_tm1 = step_inputs.get("state")
        c_tm1 = step_inputs.get("lstm_cell")
        if xi_t is None:
            # No external input (merge_inputs returned {}): the step is driven
            # only by the recurrent connections and biases.
            xi_t = xf_t = xo_t = xc_t = 0

        # LSTM core step
        i_t = self.gate_activate(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self.gate_activate(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self.activate(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self.gate_activate(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self.activate(c_t)

        return {"state": h_t, "lstm_cell": c_t}

    @neural_computation
    def merge_inputs(self, input_var, additional_inputs=None):
        """
        Concatenate all inputs and project them onto the four LSTM gates.

        :param input_var: main input tensor, or None when the layer has none
        :param additional_inputs: optional sequence of extra input tensors,
            concatenated with ``input_var`` along the last axis
        :return: dict mapping "xi", "xf", "xc", "xo" to the input projections,
            or an empty dict when there is no input at all
        """
        if not additional_inputs:
            additional_inputs = []
        # Materialize a real list: on Python 3 filter() returns a lazy iterator
        # that is always truthy and cannot be indexed. list() also accepts a
        # tuple of additional inputs.
        all_inputs = [x for x in [input_var] + list(additional_inputs) if x is not None]
        if not all_inputs:
            return {}
        last_dim_id = all_inputs[0].ndim - 1
        merged_input = T.concatenate(all_inputs, axis=last_dim_id)
        return {
            "xi": T.dot(merged_input, self.W_i),
            "xf": T.dot(merged_input, self.W_f),
            "xc": T.dot(merged_input, self.W_c),
            "xo": T.dot(merged_input, self.W_o),
        }

    def prepare(self):
        """
        Create and register the LSTM parameters.

        For each gate g in (i, f, c, o) this creates an input projection
        ``W_g = create_weight(summed_input_dim, hidden_size)``, a recurrent
        matrix ``U_g = create_weight(hidden_size, hidden_size)`` and a bias
        ``b_g = create_bias(hidden_size)``. The forget-gate bias is filled with
        ``init_forget_bias``. The W_* matrices are registered only when the
        layer has at least one input dimension.
        """
        if self._input_type == "sequence":
            all_input_dims = [self.input_dim] + self.additional_input_dims
        else:
            all_input_dims = self.additional_input_dims
        summed_input_dim = sum(all_input_dims, 0)
        self.output_dim = self.hidden_size

        self.W_i = self.create_weight(summed_input_dim, self.hidden_size, "wi", initializer=self.outer_init)
        self.U_i = self.create_weight(self.hidden_size, self.hidden_size, "ui", initializer=self.inner_init)
        self.b_i = self.create_bias(self.hidden_size, "bi")

        self.W_f = self.create_weight(summed_input_dim, self.hidden_size, "wf", initializer=self.outer_init)
        self.U_f = self.create_weight(self.hidden_size, self.hidden_size, "uf", initializer=self.inner_init)
        self.b_f = self.create_bias(self.hidden_size, "bf")
        # Fill the bias vector with the configured value. The old code
        # multiplied the shape tuple by the bias, which only yielded a vector
        # of ones when the bias was exactly 1.
        self.b_f.set_value(np.full((self.hidden_size,), self._init_forget_bias, dtype=env.FLOATX))

        self.W_c = self.create_weight(summed_input_dim, self.hidden_size, "wc", initializer=self.outer_init)
        self.U_c = self.create_weight(self.hidden_size, self.hidden_size, "uc", initializer=self.inner_init)
        self.b_c = self.create_bias(self.hidden_size, "bc")

        self.W_o = self.create_weight(summed_input_dim, self.hidden_size, "wo", initializer=self.outer_init)
        self.U_o = self.create_weight(self.hidden_size, self.hidden_size, "uo", initializer=self.inner_init)
        self.b_o = self.create_bias(self.hidden_size, "bo")

        if summed_input_dim > 0:
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            # Without inputs the W_* projections are never used, so they are
            # not registered as trainable parameters.
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
90
91