#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from .variable import NeuralVar
from deepy.utils import build_activation, FLOATX
import numpy as np
import theano
import theano.tensor as T
from collections import OrderedDict

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]

class RNN(NeuralLayer):
    """
    Recurrent neural network layer.
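
    At each step the layer computes

        h_t = act(x_t W_i + pre(h_{t-1}) + B_h [+ c W_i2])

    where pre() is a matrix product with W_h by default, or an element-wise
    product when ``vector_core`` is set, and c is the optional second input.
    A binary mask, when given, keeps the previous state wherever it is zero.

    With input_type="sequence" the layer expects a (batch, time, value)
    tensor; with output_type="one" only the final hidden state is returned.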
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="sequence", vector_core=None,
                 hidden_activation="tanh", hidden_init=None, input_init=None, steps=None,
                 persistent_state=False, reset_state_for_input=None, batch_size=None,
                 go_backwards=False, mask=None, second_input_size=None, second_input=None):
        super(RNN, self).__init__("rnn")
        self._hidden_size = hidden_size
        self.output_dim = self._hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._hidden_activation = hidden_activation
        self._hidden_init = hidden_init
        self._vector_core = vector_core
        self._input_init = input_init
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self._steps = steps
        self._go_backwards = go_backwards
        # Mask: unwrap a NeuralVar and reorder to (time, batch) for scan
        mask = mask.tensor if isinstance(mask, NeuralVar) else mask
        self._mask = mask.dimshuffle((1, 0)) if mask is not None else None
        # Second input: read the size off the wrapper before unwrapping it
        if isinstance(second_input, NeuralVar):
            second_input_size = second_input.dim()
            second_input = second_input.tensor
        self._second_input_size = second_input_size
        self._second_input = second_input
        # Named per-timestep sequences for theano.scan; an OrderedDict so the
        # arguments of step() arrive in a deterministic order
        self._sequence_map = OrderedDict()
        if input_type not in INPUT_TYPES:
            raise ValueError("Input type of RNN is wrong: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise ValueError("Output type of RNN is wrong: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise ValueError("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise ValueError("Mask only works with sequence input")

    def _hidden_preact(self, h):
        # Matrix recurrence by default; element-wise (diagonal) recurrence
        # when a vector core is used
        return T.dot(h, self.W_h) if not self._vector_core else h * self.W_h

    def step(self, *args):
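        """
        One step of the recurrence, called by theano.scan: the named
        sequences arrive first, in self._sequence_map order, and the
        previous hidden state arrives as the last argument.
        """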
        # Parse sequences
        sequence_map = dict(zip(self._sequence_map.keys(), args[:len(self._sequence_map)]))
        h = args[-1]
        if self._input_type == "sequence":
            x = sequence_map["x"]
            # Reset part of the state on condition
            if self.reset_state_for_input is not None:
                h = h * T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
            # RNN core step
            z = x + self._hidden_preact(h) + self.B_h
        else:
            z = self._hidden_preact(h) + self.B_h
        # Second input
        if "second_input" in sequence_map:
            z += sequence_map["second_input"]

        new_h = self._hidden_act(z)
        # Apply mask: keep the previous state wherever the mask is zero
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            new_h = mask * new_h + (1 - mask) * h
        return new_h

    def produce_input_sequences(self, x, mask=None, second_input=None):
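        """
        Collect the named sequences passed to theano.scan, projecting each
        input by its weight matrix once here instead of inside the loop.
        """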
        self._sequence_map.clear()
        if self._input_type == "sequence":
            self._sequence_map["x"] = T.dot(x, self.W_i)
        # Mask
        if mask is not None:
            # (batch)
            self._sequence_map["mask"] = mask
        elif self._mask is not None:
            # (time, batch)
            self._sequence_map["mask"] = self._mask
        # Second input
        if second_input is not None:
            self._sequence_map["second_input"] = T.dot(second_input, self.W_i2)
        elif self._second_input is not None:
            self._sequence_map["second_input"] = T.dot(self._second_input, self.W_i2)
        return list(self._sequence_map.values())

    def produce_initial_states(self, x):
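        """
        Initial hidden state: zeros by default, the persistent state buffer
        when persistent_state is on, or the input itself for "one"-type input.
        """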
        h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
        if self._input_type == "sequence":
            if self.persistent_state:
                h0 = self.state
        else:
            h0 = x
        return [h0]

    def output(self, x):
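        """
        Build the symbolic output; sequence input is expected as a
        (batch, time, value) tensor.
        """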
        if self._input_type == "sequence":
            # Reorder to scan's layout: (batch, time, value) -> (time, batch, value)
            sequences = self.produce_input_sequences(x.dimshuffle((1, 0, 2)))
        else:
            sequences = self.produce_input_sequences(None)

        step_outputs = self.produce_initial_states(x)
        hiddens, _ = theano.scan(self.step, sequences=sequences, outputs_info=step_outputs,
                                 n_steps=self._steps, go_backwards=self._go_backwards)

        # Save persistent state
        if self.persistent_state:
            self.register_updates((self.state, hiddens[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            # Back to (batch, time, value)
            return hiddens.dimshuffle((1, 0, 2))

    def prepare(self):
        if self._input_type == "one" and self.input_dim != self._hidden_size:
            raise ValueError("When the RNN receives one vector as input, "
                             "the hidden size must equal the last output dimension.")
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        self._hidden_act = build_activation(self._hidden_activation)

    def _setup_params(self):
        if not self._vector_core:
            self.W_h = self.create_weight(self._hidden_size, self._hidden_size, suffix="h", initializer=self._hidden_init)
        else:
            # Vector core: the recurrent weight is a vector (element-wise
            # recurrence), initialized around the given constant
            self.W_h = self.create_bias(self._hidden_size, suffix="h")
            self.W_h.set_value(self.W_h.get_value() + self._vector_core)
        self.B_h = self.create_bias(self._hidden_size, suffix="h")

        self.register_parameters(self.W_h, self.B_h)

        if self.persistent_state:
            # A shared buffer that carries the hidden state across batches
            self.state = self.create_matrix(self.batch_size, self._hidden_size, "rnn_state")
            self.register_free_parameters(self.state)
        else:
            self.state = None

        if self._input_type == "sequence":
            self.W_i = self.create_weight(self.input_dim, self._hidden_size, suffix="i", initializer=self._input_init)
            self.register_parameters(self.W_i)
        if self._second_input_size:
            self.W_i2 = self.create_weight(self._second_input_size, self._hidden_size, suffix="i2", initializer=self._input_init)
            self.register_parameters(self.W_i2)
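
# Illustrative usage (a sketch; `model` stands for any deepy network object
# with the usual `stack` API, which is not defined in this module):
#
#     model.stack(RNN(hidden_size=100, input_type="sequence", output_type="one"))
#
# With input_type="sequence", the layer expects a (batch, time, value) tensor;
# with output_type="one" it returns only the final hidden state.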