Completed
Push — master ( c8682f...5bbe2a )
by Raphael
01:33
created

deepy.layers.RNN.output()   B

Complexity

Conditions 5

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 5
dl 0
loc 20
rs 8.5455
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from . import NeuralLayer
5
from variable import NeuralVar
6
from deepy.utils import build_activation, FLOATX
7
import numpy as np
8
import theano
9
import theano.tensor as T
10
from collections import OrderedDict
11
12
# Accepted modes for the layer's input and output: either a whole
# time sequence or a single vector.
OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]
class RNN(NeuralLayer):
    """
    Plain recurrent neural network layer (Elman-style).

    Computes h_t = act(x_t W_i + preact(h_{t-1}) + b_h) with theano.scan
    over the time axis. Input and output can each be a full sequence or a
    single vector; an optional mask supports variable-length batches, and
    the hidden state can optionally persist across calls.
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="sequence", vector_core=None,
                 hidden_activation="tanh", hidden_init=None, input_init=None, steps=None,
                 persistent_state=False, reset_state_for_input=None, batch_size=None,
                 go_backwards=False, mask=None, second_input_size=None, second_input=None):
        """
        :param hidden_size: dimension of the recurrent state
        :param input_type: "sequence" (3D input) or "one" (single vector)
        :param output_type: "sequence" (all states) or "one" (last state)
        :param vector_core: if set, use an element-wise recurrent weight
            vector offset by this value instead of a full matrix
        :param steps: fixed number of scan steps (None lets scan infer it)
        :param persistent_state: keep the hidden state in a shared variable
            across calls; requires batch_size
        :param reset_state_for_input: input column index; when that column
            equals 1 the state is zeroed for that sample before the step
        :param mask: (batch, time) mask tensor or NeuralVar for
            variable-length sequences
        :param second_input: optional extra input added at every step
        """
        super(RNN, self).__init__("rnn")
        self._hidden_size = hidden_size
        self.output_dim = self._hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._hidden_activation = hidden_activation
        self._hidden_init = hidden_init
        self._vector_core = vector_core
        self._input_init = input_init
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self._steps = steps
        self._go_backwards = go_backwards
        # Mask: unwrap a NeuralVar down to the raw Theano tensor.
        # NOTE: symbolic Theano variables cannot be truth-tested ("cannot
        # evaluate the truth value of a Variable"), so every check below
        # compares against None explicitly instead of `if mask:`.
        if isinstance(mask, NeuralVar):
            mask = mask.tensor
        # Mask arrives as (batch, time); scan iterates the leading axis,
        # so store it transposed to (time, batch).
        self._mask = mask.dimshuffle((1, 0)) if mask is not None else None
        # Second input: read the size from the wrapper BEFORE unwrapping —
        # the raw tensor has no dim() method.
        if isinstance(second_input, NeuralVar):
            second_input_size = second_input.dim()
            second_input = second_input.tensor
        self._second_input_size = second_input_size
        self._second_input = second_input
        # Ordered so scan's positional sequence arguments can be re-keyed.
        self._sequence_map = OrderedDict()
        if input_type not in INPUT_TYPES:
            raise Exception("Input type of RNN is wrong: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise Exception("Output type of RNN is wrong: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise Exception("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise Exception("Mask only works with sequence input")

    def _hidden_preact(self, h):
        """Recurrent pre-activation: full matrix product, or element-wise
        product when a vector core is used."""
        return T.dot(h, self.W_h) if not self._vector_core else h * self.W_h

    def step(self, *step_inputs):
        """One scan step: combine the per-step sequence inputs with the
        previous hidden state and return the new state.

        scan passes the sequences first (in `self._sequence_map` order) and
        the previous state last.
        """
        sequence_map = dict(zip(self._sequence_map.keys(), step_inputs[:len(self._sequence_map)]))
        h = step_inputs[-1]
        if self._input_type == "sequence":
            x = sequence_map["x"]
            # Reset part of the state on condition (flag column == 1).
            if self.reset_state_for_input is not None:
                h = h * T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
            # RNN core step; x was pre-multiplied by W_i outside the loop.
            z = x + self._hidden_preact(h) + self.B_h
        else:
            z = self._hidden_preact(h) + self.B_h
        # Optional extra input added at every step.
        if "second_input" in sequence_map:
            z += sequence_map["second_input"]

        new_h = self._hidden_act(z)
        # Masked-out steps keep the previous state instead of the new one.
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            new_h = mask * new_h + (1 - mask) * h
        return new_h

    def produce_input_sequences(self, x, mask=None, second_input=None):
        """Build the ordered list of per-step sequences for theano.scan.

        Also records their names in `self._sequence_map` so `step` can
        recover them from scan's positional arguments.
        """
        self._sequence_map.clear()
        if self._input_type == "sequence":
            # Hoist the input projection out of the scan loop: one big dot
            # over the whole sequence instead of one per step.
            self._sequence_map["x"] = T.dot(x, self.W_i)
            # Mask; explicit None checks — symbolic tensors are not truth-testable.
            if mask is not None:
                # (batch)
                self._sequence_map["mask"] = mask
            elif self._mask is not None:
                # (time, batch)
                self._sequence_map["mask"] = self._mask
        # Second input
        if second_input is not None:
            self._sequence_map["second_input"] = T.dot(second_input, self.W_i2)
        elif self._second_input is not None:
            self._sequence_map["second_input"] = T.dot(self._second_input, self.W_i2)
        # list() keeps this a real list under both Python 2 and 3.
        return list(self._sequence_map.values())

    def produce_initial_states(self, x):
        """Initial hidden state: zeros by default, the persistent shared
        state when enabled, or x itself for single-vector input."""
        h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
        if self._input_type == "sequence":
            if self.persistent_state:
                h0 = self.state
        else:
            # A single-vector input is consumed directly as the first state.
            h0 = x
        return [h0]

    def output(self, x):
        """
        Run the recurrence over x and return the hidden state(s).

        For sequence input, x is (batch, time, dim) and is transposed to
        (time, batch, dim) because scan iterates the leading axis. Returns
        (batch, time, hidden) for sequence output, or the final state
        (batch, hidden) for "one" (the constructor guarantees one of the two).
        """
        if self._input_type == "sequence":
            # Move middle dimension to left-most position:
            # (sequence, batch, value)
            sequences = self.produce_input_sequences(x.dimshuffle((1, 0, 2)))
        else:
            sequences = self.produce_input_sequences(None)

        init_states = self.produce_initial_states(x)
        hiddens, _ = theano.scan(self.step, sequences=sequences, outputs_info=init_states,
                                 n_steps=self._steps, go_backwards=self._go_backwards)

        # Carry the final state over to the next call in persistent mode.
        if self.persistent_state:
            self.register_updates((self.state, hiddens[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            return hiddens.dimshuffle((1, 0, 2))

    def prepare(self):
        """Validate the configuration, then build parameters and activation."""
        if self._input_type == "one" and self.input_dim != self._hidden_size:
            raise Exception("For RNN receives one vector as input, "
                            "the hidden size should be same as last output dimension.")
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        # Resolve the activation name (e.g. "tanh") to a callable once.
        self._hidden_act = build_activation(self._hidden_activation)

    def _setup_params(self):
        """Create and register the layer's weights and optional state."""
        if not self._vector_core:
            # Full hidden-to-hidden recurrent matrix.
            self.W_h = self.create_weight(self._hidden_size, self._hidden_size, suffix="h", initializer=self._hidden_init)
        else:
            # Element-wise recurrent vector, offset by the given core value.
            self.W_h = self.create_bias(self._hidden_size, suffix="h")
            self.W_h.set_value(self.W_h.get_value() + self._vector_core)
        self.B_h = self.create_bias(self._hidden_size, suffix="h")

        self.register_parameters(self.W_h, self.B_h)

        if self.persistent_state:
            # Shared variable that survives across compiled function calls.
            self.state = self.create_matrix(self.batch_size, self._hidden_size, "rnn_state")
            self.register_free_parameters(self.state)
        else:
            self.state = None

        if self._input_type == "sequence":
            self.W_i = self.create_weight(self.input_dim, self._hidden_size, suffix="i", initializer=self._input_init)
            self.register_parameters(self.W_i)
        if self._second_input_size:
            self.W_i2 = self.create_weight(self._second_input_size, self._hidden_size, suffix="i2", initializer=self._input_init)
            self.register_parameters(self.W_i2)
165