#!/usr/bin/env python
# -*- coding: utf-8 -*-
|
3
|
|
|
|
|
4
|
|
|
from . import RNN |
|
5
|
|
|
from deepy.utils import GaussianInitializer, IdentityInitializer, FLOATX |
|
6
|
|
|
|
|
7
|
|
|
# Bounds applied by IRNN.training_callback to every entry of the W_h matrix:
# entries above MAX_IDENTITY_VALUE are set to it, entries below
# MIN_IDENTITY_VALUE are set to it.
MAX_IDENTITY_VALUE = 0.99
MIN_IDENTITY_VALUE = 0.0
|
9
|
|
|
|
|
10
|
|
|
class IRNN(RNN):
    """
    The implementation of http://arxiv.org/abs/1504.00941 .
    RNN with weight initialization using identity matrix.

    The layer is an ``RNN`` configured with a ``"relu"`` hidden activation,
    ``IdentityInitializer(scale=weight_scale)`` as the hidden initializer and
    ``GaussianInitializer(deviation=0.001)`` as the input initializer.
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="one", steps=None, go_backwards=False,
                 weight_scale=0.9, bound_recurrent_weight=True, mask=None,
                 persistent_state=False, reset_state_for_input=None, batch_size=None,
                 second_input=None, second_input_size=None):
        """
        Configure the underlying ``RNN`` for identity initialization.

        :param weight_scale: scale handed to ``IdentityInitializer`` for the
            hidden initializer.
        :param bound_recurrent_weight: when true, ``training_callback`` is
            registered through ``register_training_callbacks`` so that the
            values of ``W_h`` are kept inside
            ``[MIN_IDENTITY_VALUE, MAX_IDENTITY_VALUE]``.

        Every other argument is forwarded unchanged to ``RNN.__init__``; see
        that class for its meaning.
        """
        super(IRNN, self).__init__(hidden_size,
                                   input_type=input_type, output_type=output_type,
                                   hidden_activation="relu", steps=steps,
                                   hidden_init=IdentityInitializer(scale=weight_scale),
                                   input_init=GaussianInitializer(deviation=0.001),
                                   persistent_state=persistent_state, reset_state_for_input=reset_state_for_input,
                                   batch_size=batch_size, go_backwards=go_backwards,
                                   mask=mask, second_input=second_input, second_input_size=second_input_size)
        self.name = "irnn"
        if bound_recurrent_weight:
            self.register_training_callbacks(self.training_callback)

    def training_callback(self):
        """
        Clamp ``W_h`` into ``[MIN_IDENTITY_VALUE, MAX_IDENTITY_VALUE]``.

        ``W_h`` is only written back when at least one entry is out of range.
        NOTE(review): presumably invoked by the trainer between updates, since
        it is handed to ``register_training_callbacks`` -- confirm in the base
        class.
        """
        # borrow=True avoids a copy; the array is only read here, the clamped
        # result below is a new array.
        w_value = self.W_h.get_value(borrow=True)
        if w_value.max() > MAX_IDENTITY_VALUE or w_value.min() < MIN_IDENTITY_VALUE:
            # clip() instead of mask arithmetic: multiplying an infinite entry
            # by a zero mask yields NaN, whereas clip() maps +/-inf onto the
            # bounds and gives the same result for finite entries.
            bounded = w_value.clip(MIN_IDENTITY_VALUE, MAX_IDENTITY_VALUE)
            self.W_h.set_value(bounded.astype(FLOATX))
|
43
|
|
|
|