deepy.layers.IRNN   A
last analyzed

Complexity

Total Complexity 6

Size/Duplication

Total Lines 33
Duplicated Lines 0 %
Metric Value
dl 0
loc 33
rs 10
wmc 6

2 Methods

Rating   Name   Duplication   Size   Complexity  
A training_callback() 0 11 4
A __init__() 0 15 2
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from . import RNN
5
from deepy.utils import GaussianInitializer, IdentityInitializer, FLOATX
6
7
# Inclusive bounds that IRNN.training_callback clips every recurrent weight
# (W_h) into when the layer is built with bound_recurrent_weight=True.
MIN_IDENTITY_VALUE = 0.0  # lower bound: recurrent weights are kept non-negative
MAX_IDENTITY_VALUE = 0.99  # upper bound: kept strictly below 1.0
9
10
class IRNN(RNN):
    """
    The implementation of http://arxiv.org/abs/1504.00941 .
    RNN with weight initialization using identity matrix.

    Configuration fixed by this class:

    - the hidden activation is always ``"relu"``;
    - ``hidden_init`` is an ``IdentityInitializer`` scaled by ``weight_scale``;
    - ``input_init`` is a ``GaussianInitializer`` with deviation 0.001.

    When ``bound_recurrent_weight`` is true, ``training_callback`` is registered
    so the recurrent weights ``W_h`` are clipped into
    [MIN_IDENTITY_VALUE, MAX_IDENTITY_VALUE].
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="one", steps=None, go_backwards=False,
                 weight_scale=0.9, bound_recurrent_weight=True, mask=None,
                 persistent_state=False, reset_state_for_input=None, batch_size=None,
                 second_input=None, second_input_size=None):
        """
        :param hidden_size: size of the hidden state, forwarded to RNN.
        :param weight_scale: scale given to IdentityInitializer for ``hidden_init``.
        :param bound_recurrent_weight: if true, register ``training_callback`` via
            ``register_training_callbacks``. When it fires is decided by the RNN
            base class; presumably after parameter updates (TODO confirm).

        All remaining arguments are forwarded unchanged to ``RNN.__init__``.
        """
        super(IRNN, self).__init__(hidden_size,
                                   input_type=input_type, output_type=output_type,
                                   hidden_activation="relu", steps=steps,
                                   hidden_init=IdentityInitializer(scale=weight_scale),
                                   input_init=GaussianInitializer(deviation=0.001),
                                   persistent_state=persistent_state, reset_state_for_input=reset_state_for_input,
                                   batch_size=batch_size, go_backwards=go_backwards,
                                   mask=mask, second_input=second_input, second_input_size=second_input_size)
        self.name = "irnn"
        if bound_recurrent_weight:
            self.register_training_callbacks(self.training_callback)

    def training_callback(self):
        """
        Clip every entry of ``W_h`` into [MIN_IDENTITY_VALUE, MAX_IDENTITY_VALUE].

        The shared variable is only rewritten (as FLOATX) when at least one entry
        lies outside the range. In the common in-range case the cost is a single
        max/min scan.
        """
        w_value = self.W_h.get_value(borrow=True)
        if w_value.max() > MAX_IDENTITY_VALUE or w_value.min() < MIN_IDENTITY_VALUE:
            # ndarray.clip returns a new array, so the borrowed buffer is never mutated.
            # Clipping is used instead of mask arithmetic such as w * (w <= hi) + hi * (w > hi),
            # which turns +/-inf into NaN (inf * 0 == nan) rather than clamping it to the bound.
            clipped = w_value.clip(MIN_IDENTITY_VALUE, MAX_IDENTITY_VALUE)
            self.W_h.set_value(clipped.astype(FLOATX))
43