Completed
Push — master ( 6bb71e...eef17b )
by Raphael
01:44
created

LSTM.merge_inputs()   A

Complexity

Conditions 3

Size

Total Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
| Metric | Value  |
|--------|--------|
| cc     | 3      |
| c      | 1      |
| b      | 0      |
| f      | 0      |
| dl     | 0      |
| loc    | 16     |
| rs     | 9.4285 |
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import numpy as np
5
import theano.tensor as T
6
from recurrent import RecurrentLayer
7
from deepy.utils import neural_computation, FLOATX
8
9
10
class LSTM(RecurrentLayer):
    """
    Long short-term memory layer.

    Carries two recurrent values between steps: ``"state"`` (the hidden
    output h_t) and ``"lstm_cell"`` (the memory cell c_t). Each step uses an
    input gate (i), a forget gate (f), an output gate (o) and a candidate
    cell value (c), each with its own input weights W_*, recurrent weights
    U_* and bias b_*.

    :param hidden_size: number of hidden units; also used as the layer's
        output dimension.
    :param init_forget_bias: value written into every element of the
        forget-gate bias b_f when the layer is prepared (default 1).
    """

    def __init__(self, hidden_size, init_forget_bias=1, **kwargs):
        kwargs["hidden_size"] = hidden_size
        super(LSTM, self).__init__("LSTM", ["state", "lstm_cell"], **kwargs)
        # Honour the caller's value; this was previously hard-coded to 1.
        self._init_forget_bias = init_forget_bias

    @neural_computation
    def compute_new_state(self, step_inputs):
        """
        Compute one LSTM time step.

        :param step_inputs: dict holding the previous ``"state"`` (h_{t-1})
            and ``"lstm_cell"`` (c_{t-1}), and optionally the pre-projected
            inputs ``"xi"``, ``"xf"``, ``"xo"``, ``"xc"`` as produced by
            ``merge_inputs``. When ``"xi"`` is absent, all input
            projections are treated as 0.
        :return: dict with the new ``"state"`` (h_t) and ``"lstm_cell"`` (c_t).
        """
        # The key order must match the unpacking order. "xc" and "xo" were
        # previously swapped, so W_c fed the output gate and W_o the
        # candidate cell.
        xi_t, xf_t, xo_t, xc_t, h_tm1, c_tm1 = [
            step_inputs.get(key)
            for key in ("xi", "xf", "xo", "xc", "state", "lstm_cell")
        ]
        if xi_t is None:
            # No external input: the step is driven by recurrence only.
            xi_t = xf_t = xo_t = xc_t = 0

        # LSTM core step
        i_t = self.gate_activate(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self.gate_activate(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self.activate(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self.gate_activate(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self.activate(c_t)

        return {"state": h_t, "lstm_cell": c_t}

    @neural_computation
    def merge_inputs(self, input_var, additional_inputs=None):
        """
        Concatenate all present inputs and project them for each gate.

        :param input_var: the main input, or None when the layer has no
            main sequence input.
        :param additional_inputs: optional iterable of extra inputs,
            concatenated with ``input_var`` along the last axis.
        :return: dict with the projections ``"xi"``, ``"xf"``, ``"xc"``,
            ``"xo"``, or an empty dict when no input is present.
        """
        candidates = [input_var] + list(additional_inputs or [])
        # Materialise a list: on Python 3 a bare filter() is a lazy,
        # always-truthy iterator, which would defeat the emptiness check.
        all_inputs = [x for x in candidates if x is not None]
        if not all_inputs:
            return {}
        # Take the rank from the first present input; input_var itself may
        # be None, so input_var.ndim cannot be relied on.
        last_dim_id = all_inputs[0].ndim - 1
        merged_input = T.concatenate(all_inputs, axis=last_dim_id)
        return {
            "xi": T.dot(merged_input, self.W_i),
            "xf": T.dot(merged_input, self.W_f),
            "xc": T.dot(merged_input, self.W_c),
            "xo": T.dot(merged_input, self.W_o),
        }

    def prepare(self):
        """
        Create the gate parameters and register them with the layer.

        The input weights W_* map the concatenation of all layer inputs
        (whose width is the sum of the input dims) to ``hidden_size``. They
        are registered only when that width is non-zero. The recurrent
        weights U_* and biases b_* are always registered.
        """
        if self._input_type == "sequence":
            all_input_dims = [self.input_dim] + self.additional_input_dims
        else:
            all_input_dims = self.additional_input_dims
        summed_input_dim = sum(all_input_dims, 0)
        self.output_dim = self.hidden_size

        self.W_i = self.create_weight(summed_input_dim, self.hidden_size, "wi", initializer=self.outer_init)
        self.U_i = self.create_weight(self.hidden_size, self.hidden_size, "ui", initializer=self.inner_init)
        self.b_i = self.create_bias(self.hidden_size, "i")

        self.W_f = self.create_weight(summed_input_dim, self.hidden_size, "wf", initializer=self.outer_init)
        self.U_f = self.create_weight(self.hidden_size, self.hidden_size, "uf", initializer=self.inner_init)
        self.b_f = self.create_bias(self.hidden_size, "f")
        # Fill every element with init_forget_bias. The old expression
        # np.ones((h,) * bias) multiplied the *shape tuple*, giving a wrong
        # shape for any bias other than 1.
        self.b_f.set_value(np.full((self.hidden_size,), self._init_forget_bias, dtype=FLOATX))

        self.W_c = self.create_weight(summed_input_dim, self.hidden_size, "wc", initializer=self.outer_init)
        self.U_c = self.create_weight(self.hidden_size, self.hidden_size, "uc", initializer=self.inner_init)
        self.b_c = self.create_bias(self.hidden_size, "c")

        self.W_o = self.create_weight(summed_input_dim, self.hidden_size, "wo", initializer=self.outer_init)
        self.U_o = self.create_weight(self.hidden_size, self.hidden_size, "uo", initializer=self.inner_init)
        self.b_o = self.create_bias(self.hidden_size, suffix="o")

        if summed_input_dim > 0:
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            # No input width: W_* would be empty, so only the recurrent
            # weights and biases are trainable.
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
89
90