FirstGlimpseLayer.init()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 9
rs 9.6666
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
import os
4
5
import numpy as np
6
from numpy import linalg as LA
7
from theano import tensor as T
8
import theano
9
from theano.tensor.shared_randomstreams import RandomStreams
10
11
from deepy import NeuralClassifier
12
from deepy.utils import get_activation
13
from deepy.core.disconnected_grad import disconnected_grad
14
from deepy.utils.functions import FLOATX
15
from deepy.layers import NeuralLayer
16
from examples.attention_models.gaussian_sampler import SampleMultivariateGaussian
17
import theano.tensor.signal.downsample
18
19
20
class FirstGlimpseLayer(NeuralLayer):
    """
    Recurrent attention layer whose first glimpse position is predicted
    from a down-sampled copy of the input image.

    The input is reshaped to a 28x28 image.  A first focus position is
    computed from a 4x4 max-pooled version of that image, then a recurrent
    core network takes five 7x7 glimpses, each step producing the next
    focus position, a new hidden state and a 10-way softmax decision.

    After `init` has run, the following attributes are available:
        positions     - focus positions produced by the five scan steps
        last_decision - argmax of the last step's class probabilities
        wl_grad       - per-step gradient w.r.t. W_l, averaged over steps
        wf_grad       - gradient w.r.t. W_f from the first glimpse
    """

    def __init__(self, activation='tanh', std=0.1, disable_reinforce=False, random_glimpse=False):
        """
        Parameters:
            activation - activation name forwarded to NeuralLayer
            std - value placed on the diagonal of the 2x2 covariance used
                  for sampling focus positions (see _setup_params)
            disable_reinforce - if True, no position is sampled and the
                  weight matrices themselves are returned in place of
                  the log-density gradients
            random_glimpse - if True and disable_reinforce is also True,
                  focus positions are drawn uniformly from [-1.7, 1.7]
        """
        self.disable_reinforce = disable_reinforce
        self.random_glimpse = random_glimpse
        self.gaussian_std = std
        # NOTE(review): 10 is presumably the layer's output size (one unit
        # per class) -- confirm against NeuralLayer.__init__.
        super(FirstGlimpseLayer, self).__init__(10, activation)

    def init(self, config, vars, x, input_n, id="UNKNOWN"):
        """
        Bind the layer to its input and build its parameters and output.

        Parameters:
            config - stored as self._config
            vars - stored as self._vars
            x - input variable; reshaped to 28x28 when the output is built
            input_n - stored as self.input_n
            id - stored as self.id
        """
        self._config = config
        self._vars = vars
        self.input_n = input_n
        self.id = id
        self.x = x
        # Parameters must exist before the output expression is built,
        # because _setup_functions() calls _output_func().
        self._setup_params()
        self._setup_functions()
        self.connected = True

    def _reinforce_sample(self, mean, weight):
        """
        Sample a focus position around `mean` and differentiate its
        log-density.

        Parameters:
            mean - 2-element focus vector used as the Gaussian mean
            weight - shared weight matrix to differentiate against
        Returns:
            sampled position, gradient of log pdf(sample) w.r.t. weight
        """
        sampled = self._sample_gaussian(mean, self.cov)
        # The sample is wrapped in disconnected_grad so that only `mean`
        # contributes to the gradient.
        pdf = self._multi_gaussian_pdf(disconnected_grad(sampled), mean)
        return sampled, T.grad(T.log(pdf), weight)

    def _first_glimpse_sensor(self, x_t):
        """
        Compute first glimpse position using down-sampled image.

        Parameters:
            x_t - 28x28 image
        Returns:
            first_l - 2-element focus vector
            wf_grad - gradient of the sampled position's log-density
                      w.r.t. W_f, or W_f itself when REINFORCE is disabled
        """
        # 4x4 max-pooling of a 28x28 image gives 7x7 = 49 values,
        # matching the first dimension of W_f.
        downsampled_img = theano.tensor.signal.downsample.max_pool_2d(x_t, (4, 4))
        downsampled_img = downsampled_img.flatten()
        first_l = T.dot(downsampled_img, self.W_f)

        if not self.disable_reinforce:
            return self._reinforce_sample(first_l, self.W_f)

        if self.random_glimpse:
            first_l = self.srng.uniform((2,), low=-1.7, high=1.7)
        return first_l, self.W_f

    def _refined_glimpse_sensor(self, x_t, l_p):
        """
        Parameters:
            x_t - 28x28 image
            l_p - 2x1 focus vector
        Returns:
            7x7 matrix
        """
        # Turn l_p to the left-top point of rectangle
        l_p = l_p * 6.67 + 14 - 4
        l_p = T.cast(T.round(l_p), "int32")

        # Negative coordinates become 0; coordinates >= 21 become 20, so
        # the 7x7 window always lies inside the 28x28 image.
        l_p = l_p * (l_p >= 0)
        l_p = l_p * (l_p < 21) + (l_p >= 21) * 20
        glimpse_1 = x_t[l_p[0]: l_p[0] + 7][:, l_p[1]: l_p[1] + 7]
        return glimpse_1

    def _multi_gaussian_pdf(self, vec, mean):
        """
        Density at `vec` of the 2-D Gaussian centred on `mean`, using the
        precomputed inverse and determinant of the layer's covariance.
        """
        norm2d_var = ((1.0 / T.sqrt((2 * np.pi) ** 2 * self.cov_det_var)) *
                      T.exp(-0.5 * ((vec - mean).T.dot(self.cov_inv_var).dot(vec - mean))))
        return norm2d_var

    def _glimpse_network(self, x_t, l_p):
        """
        Combine what is seen at l_p with where it was seen.

        Parameters:
            x_t - 28x28 image
            l_p - 2x1 focus vector
        Returns:
            g - 256-element glimpse feature vector
        """
        sensor_output = self._refined_glimpse_sensor(x_t, l_p)
        sensor_output = T.flatten(sensor_output)
        h_g = self._relu(T.dot(sensor_output, self.W_g0))
        h_l = self._relu(T.dot(l_p, self.W_g1))
        g = self._relu(T.dot(h_g, self.W_g2_hg) + T.dot(h_l, self.W_g2_hl))
        return g

    def _location_network(self, h_t):
        """
        Parameters:
            h_t - 256x1 vector
        Returns:
            2x1 focus vector
        """
        return T.dot(h_t, self.W_l)

    def _action_network(self, h_t):
        """
        Parameters:
            h_t - 256x1 vector
        Returns:
            10x1 vector
        """
        z = self._relu(T.dot(h_t, self.W_a) + self.B_a)
        return self._softmax(z)

    def _core_network(self, l_p, h_p, x_t):
        """
        One recurrent step; used as the step function of theano.scan.

        Parameters:
            x_t - 28x28 image
            l_p - 2x1 focus vector
            h_p - 256x1 vector
        Returns:
            sampled_l_t - focus vector for the next step
            h_t - 256x1 vector
            a_t - 10x1 vector of class probabilities
            wl_grad - gradient of the sampled position's log-density
                      w.r.t. W_l, or W_l itself when REINFORCE is disabled
        """
        g_t = self._glimpse_network(x_t, l_p)
        h_t = self._tanh(T.dot(g_t, self.W_h_g) + T.dot(h_p, self.W_h) + self.B_h)
        l_t = self._location_network(h_t)

        if self.disable_reinforce:
            sampled_l_t, wl_grad = l_t, self.W_l
            if self.random_glimpse:
                sampled_l_t = self.srng.uniform((2,), low=-1.7, high=1.7)
        else:
            sampled_l_t, wl_grad = self._reinforce_sample(l_t, self.W_l)

        a_t = self._action_network(h_t)

        return sampled_l_t, h_t, a_t, wl_grad

    def _output_func(self):
        """
        Build the symbolic output: the class probabilities of the last of
        five glimpse steps, reshaped to (1, 10).

        Also stores positions, last_decision, wl_grad and wf_grad on self.
        """
        self.x = self.x.reshape((28, 28))
        first_l, wf_grad = self._first_glimpse_sensor(self.x)

        # NOTE(review): the scan updates are discarded here.  When the step
        # function draws from self.srng (random_glimpse), Theano normally
        # needs those updates passed to theano.function for the random
        # state to advance -- confirm how the surrounding framework
        # compiles this graph.
        [l_ts, h_ts, a_ts, wl_grads], _ = theano.scan(fn=self._core_network,
                                                      outputs_info=[first_l, self.h0, None, None],
                                                      non_sequences=[self.x],
                                                      n_steps=5)

        self.positions = l_ts
        self.last_decision = T.argmax(a_ts[-1])
        # Average the per-step gradients over the scan steps.
        wl_grad = T.sum(wl_grads, axis=0) / wl_grads.shape[0]
        self.wl_grad = wl_grad
        self.wf_grad = wf_grad
        return a_ts[-1].reshape((1, 10))

    def _setup_functions(self):
        """
        Resolve the activation functions and build the output expression.
        """
        # NOTE: despite its name, `_relu` is bound to the tanh activation.
        # Left as is because changing it would alter the trained model.
        self._relu = get_activation("tanh")
        self._tanh = get_activation("tanh")
        self._softmax = get_activation("softmax")
        self.output_func = self._output_func()

    def _setup_params(self):
        """
        Create the random stream, the sampling covariance and all weights.
        """
        self.srng = RandomStreams(seed=234)
        # large_cov is not referenced anywhere else in this class.
        self.large_cov = np.array([[0.06, 0], [0, 0.06]], dtype=FLOATX)
        # NOTE(review): gaussian_std is placed directly on the covariance
        # diagonal, i.e. it is used as a variance -- confirm this is what
        # SampleMultivariateGaussian expects.
        self.small_cov = np.array([[self.gaussian_std, 0], [0, self.gaussian_std]], dtype=FLOATX)
        self.cov = theano.shared(np.array(self.small_cov, dtype=FLOATX))
        self.cov_inv_var = theano.shared(np.array(LA.inv(self.small_cov), dtype=FLOATX))
        self.cov_det_var = theano.shared(np.array(LA.det(self.small_cov), dtype=FLOATX))
        self._sample_gaussian = SampleMultivariateGaussian()

        # Glimpse network: 7x7 patch and 2-d position -> 256 features.
        self.W_g0 = self.create_weight(7 * 7, 128, label="g0")
        self.W_g1 = self.create_weight(2, 128, label="g1")
        self.W_g2_hg = self.create_weight(128, 256, label="g2_hg")
        self.W_g2_hl = self.create_weight(128, 256, label="g2_hl")

        # Recurrent core and initial states.
        self.W_h_g = self.create_weight(256, 256, label="h_g")
        self.W_h = self.create_weight(256, 256, label="h")
        self.B_h = self.create_bias(256, label="h")
        self.h0 = self.create_vector(256, "h0")
        # l0 is created but not read by any method of this class; the scan
        # starts from the position given by _first_glimpse_sensor instead.
        self.l0 = self.create_vector(2, "l0")
        self.l0.set_value(np.array([-1, -1], dtype=FLOATX))

        # Location network; its initial weights are scaled down by 10.
        self.W_l = self.create_weight(256, 2, label="l")
        self.W_l.set_value(self.W_l.get_value() / 10)
        # B_l is created but _location_network does not add it.
        self.B_l = self.create_bias(2, label="l")
        # Action (classification) network.
        self.W_a = self.create_weight(256, 10, label="a")
        self.B_a = self.create_bias(10, label="a")

        # First-glimpse weights: 7x7 down-sampled image -> 2-d position.
        self.W_f = self.create_weight(7 * 7, 2, label="f")

        self.W = [self.W_g0, self.W_g1, self.W_g2_hg, self.W_g2_hl, self.W_h_g, self.W_h, self.W_a]
        self.B = [self.B_h, self.B_a]
        # NOTE(review): W_l and W_f are kept out of self.W, presumably
        # because they are updated from wl_grad / wf_grad rather than by
        # ordinary back-propagation -- confirm against the trainer.
        self.parameters = [self.W_l, self.W_f]
191
192
def get_network(model=None, std=0.005, disable_reinforce=False, random_glimpse=False):
    """
    Get baseline model.

    Builds a NeuralClassifier over 28 * 28 inputs with a single
    FirstGlimpseLayer stacked on it, then loads saved parameters when a
    model file is given and exists.

    Parameters:
        model - model path; ignored when empty or when the file is missing
        std - forwarded to FirstGlimpseLayer
        disable_reinforce - forwarded to FirstGlimpseLayer
        random_glimpse - forwarded to FirstGlimpseLayer
    Returns:
        network
    """
    network = NeuralClassifier(input_dim=28 * 28)
    glimpse_layer = FirstGlimpseLayer(std=std,
                                      disable_reinforce=disable_reinforce,
                                      random_glimpse=random_glimpse)
    network.stack_layer(glimpse_layer)
    # A missing file is not an error: the network is returned with
    # freshly initialised parameters.
    if model and os.path.exists(model):
        network.load_params(model)
    return network
205