Completed
Push — master ( 5bbe2a...9d73f5 )
by Raphael
01:33
created

deepy.layers.GRU   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 63
Duplicated Lines 0 %
Metric Value
dl 0
loc 63
rs 10
wmc 9

4 Methods

Rating   Name   Duplication   Size   Complexity  
B prepare() 0 26 3
A merge_inputs() 0 18 3
A compute_new_state() 0 11 2
A __init__() 0 3 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import theano.tensor as T
5
from recurrent import RecurrentLayer
6
7
# Connection modes supported by this layer: emit/consume a full time
# sequence, or a single vector ("one").
OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]
class GRU(RecurrentLayer):
    """Gated Recurrent Unit recurrent layer.

    Implements the standard GRU step (update gate z, reset gate r,
    candidate state h~).  Recurrent weights ``U_z``/``U_r``/``U_h`` and
    their biases are created in :meth:`prepare`; one projection-weight
    triple per input stream is created there as well when the layer is
    fed a sequence.
    """

    def __init__(self, hidden_size, **kwargs):
        """Create a GRU layer with ``hidden_size`` recurrent units.

        :param hidden_size: dimensionality of the hidden state.
        :param kwargs: forwarded to :class:`RecurrentLayer`.
        """
        kwargs["hidden_size"] = hidden_size
        # "state" names the single recurrent variable carried across steps.
        super(GRU, self).__init__("GRU", ["state"], **kwargs)

    def compute_new_state(self, step_inputs):
        """Perform one GRU recurrence step.

        :param step_inputs: dict that may contain the pre-projected
            inputs ``"xz_t"``, ``"xr_t"``, ``"xh_t"`` and must contain
            the previous hidden state under ``"state"``.
        :return: dict with the new hidden state under ``"state"``.
        """
        xz_t, xr_t, xh_t, h_tm1 = map(step_inputs.get, ["xz_t", "xr_t", "xh_t", "state"])
        # ``dict.get`` returns None for missing keys, so test identity
        # explicitly.  The previous ``if not xz_t:`` form coerced the
        # value to bool, which is unreliable (and may raise) for
        # symbolic Theano tensors when inputs ARE provided.
        if xz_t is None:
            # No external input this step: gates see only the recurrent term.
            xz_t, xr_t, xh_t = 0, 0, 0

        z_t = self.gate_activate(xz_t + T.dot(h_tm1, self.U_z) + self.b_z)       # update gate
        r_t = self.gate_activate(xr_t + T.dot(h_tm1, self.U_r) + self.b_r)       # reset gate
        h_t_pre = self.activate(xh_t + T.dot(r_t * h_tm1, self.U_h) + self.b_h)  # candidate state
        # Interpolate between the old state and the candidate.
        h_t = z_t * h_tm1 + (1 - z_t) * h_t_pre

        return {"state": h_t}

    def merge_inputs(self, input_var, additional_inputs=None):
        """Project all input streams and sum them per gate.

        :param input_var: the main input variable.
        :param additional_inputs: optional list of extra input variables;
            each is paired positionally with a weight triple from
            ``self.input_weights``.
        :return: dict of summed projections keyed "xz_t", "xr_t", "xh_t".
        """
        if not additional_inputs:
            additional_inputs = []
        all_inputs = [input_var] + additional_inputs
        z_inputs = []
        r_inputs = []
        h_inputs = []
        # Each input stream has its own (wz, wr, wh) triple, created in prepare().
        for x, weights in zip(all_inputs, self.input_weights):
            wz, wr, wh = weights
            z_inputs.append(T.dot(x, wz))
            r_inputs.append(T.dot(x, wr))
            h_inputs.append(T.dot(x, wh))
        # NOTE: if input_weights is empty (non-sequence input), sum([]) == 0,
        # which compute_new_state treats as "no external input".
        merged_inputs = {
            "xz_t": sum(z_inputs),
            "xr_t": sum(r_inputs),
            "xh_t": sum(h_inputs)
        }
        return merged_inputs

    def prepare(self):
        """Allocate and register all parameters of the layer.

        Creates the recurrent weight matrices and biases for the three
        gates, and — when the layer consumes a sequence — one projection
        weight triple per input stream (main input first, then each
        additional input).
        """
        self.output_dim = self.hidden_size

        # Recurrent (hidden-to-hidden) weights and biases, one per gate.
        self.U_z = self.create_weight(self.hidden_size, self.hidden_size, "uz", initializer=self.inner_init)
        self.b_z = self.create_bias(self.hidden_size, "z")

        self.U_r = self.create_weight(self.hidden_size, self.hidden_size, "ur", initializer=self.inner_init)
        self.b_r = self.create_bias(self.hidden_size, "r")

        self.U_h = self.create_weight(self.hidden_size, self.hidden_size, "uh", initializer=self.inner_init)
        self.b_h = self.create_bias(self.hidden_size, "h")

        self.register_parameters(self.U_z, self.b_z,
                                 self.U_r, self.b_r,
                                 self.U_h, self.b_h)

        # Input projection weights, one (wz, wr, wh) triple per input stream.
        self.input_weights = []
        if self._input_type == "sequence":
            all_input_dims = [self.input_dim] + self.additional_input_dims
            for i, input_dim in enumerate(all_input_dims):
                wz = self.create_weight(input_dim, self.hidden_size, "wz_{}".format(i+1), initializer=self.outer_init)
                wr = self.create_weight(input_dim, self.hidden_size, "wr_{}".format(i+1), initializer=self.outer_init)
                wh = self.create_weight(input_dim, self.hidden_size, "wh_{}".format(i+1), initializer=self.outer_init)
                weights = [wz, wr, wh]
                self.input_weights.append(weights)
                self.register_parameters(*weights)