experiments.attention_models.FirstGlimpseTrainer (Rating: B)

Complexity

Total Complexity 37

Size/Duplication

Total Lines 152
Duplicated Lines 0 %

Metric                        Value
dl (duplicated lines)         0
loc (lines of code)           152
rs                            8.6
wmc (weighted method count)   37

3 Methods

Rating   Name                  Duplication   Size   Complexity
F        train_func()          0             100    25
B        __init__()            0             35     5
B        update_parameters()   0             13     7
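
The listing below combines ordinary backprop updates for the network weights with a REINFORCE-style update for the attention weights W_l and W_f: each sample's attention gradient is scaled by the negated difference between its reward and the reward baseline kept in last_average_reward, accumulated, and averaged over the batch before being handed to the optimizer. A minimal NumPy sketch of that accumulation follows; it is illustrative only, with made-up values standing in for the trainer's state, and is not part of the analyzed file.

# Illustrative sketch of the baseline-subtracted REINFORCE accumulation used
# in train_func()/update_parameters(); gradients and rewards here are made up.
import numpy as np

baseline = 0.002                              # plays the role of self.last_average_reward
batch_wl_grad = np.zeros((4, 4))              # plays the role of self.batch_wl_grad
samples = [(np.random.randn(4, 4), 0.005),    # (wl_grad, reward) pairs as returned by grad_func
           (np.random.randn(4, 4), 0.0)]

for wl_grad, reward in samples:
    # Scale each per-sample gradient by the negated, baseline-subtracted reward and accumulate.
    batch_wl_grad += wl_grad * -(reward - baseline)

grad_wl = batch_wl_grad / len(samples)        # averaged gradient handed to rl_opt_func
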
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys

import numpy as np
from numpy import linalg as LA
from theano import tensor as T
import theano

from deepy.utils.functions import FLOATX
from deepy.trainers import CustomizeTrainer
from deepy.trainers.optimize import optimize_function


class FirstGlimpseTrainer(CustomizeTrainer):

    def __init__(self, network, attention_layer, config):
        """
        Parameters:
            network - AttentionNetwork
            config - training config
        :type network: NeuralClassifier
        :type attention_layer: experiments.attention_models.first_glimpse_model.FirstGlimpseLayer
        :type config: TrainerConfig
        """
        super(FirstGlimpseTrainer, self).__init__(network, config)
        self.large_cov_mode = False
        self.batch_size = config.get("batch_size", 20)
        self.disable_backprop = config.get("disable_backprop", False)
        self.disable_reinforce = config.get("disable_reinforce", False)
        self.last_average_reward = 999
        self.turn = 1
        self.layer = attention_layer
        # Backprop gradients for the network parameters (skipped when disabled).
        if self.disable_backprop:
            grads = []
        else:
            grads = [T.grad(self.cost, p) for p in network.weights + network.biases]
        # REINFORCE gradients for the attention weights; when REINFORCE is
        # disabled, the weights themselves are output as unused placeholders.
        if self.disable_reinforce:
            grad_l = self.layer.W_l
            grad_f = self.layer.W_f
        else:
            grad_l = self.layer.wl_grad
            grad_f = self.layer.wf_grad
        # Accumulators for per-batch gradient averaging.
        self.batch_wl_grad = np.zeros(attention_layer.W_l.get_value().shape, dtype=FLOATX)
        self.batch_wf_grad = np.zeros(attention_layer.W_f.get_value().shape, dtype=FLOATX)
        self.batch_grad = [np.zeros(p.get_value().shape, dtype=FLOATX) for p in network.weights + network.biases]
        # grad_func returns [cost, grad W_l, grad W_f, glimpse positions, last decision]
        # followed by the backprop gradients (if enabled).
        self.grad_func = theano.function(network.inputs,
                                         [self.cost, grad_l, grad_f, attention_layer.positions, attention_layer.last_decision] + grads,
                                         allow_input_downcast=True)
        self.opt_func = optimize_function(self.network.weights + self.network.biases, self.config)
        self.rl_opt_func = optimize_function([self.layer.W_l, self.layer.W_f], self.config)

    def update_parameters(self, update_rl):
        # Backprop update: average the accumulated gradients over the batch.
        if not self.disable_backprop:
            grads = [self.batch_grad[i] / self.batch_size for i in range(len(self.network.weights + self.network.biases))]
            self.opt_func(*grads)
        # REINFORCE update for the attention weights W_l and W_f.
        if update_rl and not self.disable_reinforce:
            if np.sum(self.batch_wl_grad) == 0 or np.sum(self.batch_wf_grad) == 0:
                # All-zero accumulated gradients: warn instead of taking an empty step.
                sys.stdout.write("0WRL ")
                sys.stdout.flush()
            else:
                grad_wl = self.batch_wl_grad / self.batch_size
                grad_wf = self.batch_wf_grad / self.batch_size
                self.rl_opt_func(grad_wl, grad_wf)

    def train_func(self, train_set):
        cost_sum = 0.0
        batch_cost = 0.0
        counter = 0
        total = 0
        total_reward = 0
        batch_reward = 0
        total_position_value = 0
        pena_count = 0
        for d in train_set:
            pairs = self.grad_func(*d)
            cost = pairs[0]
            # Skip samples whose cost exploded or became NaN.
            if cost > 10 or np.isnan(cost):
                sys.stdout.write("X")
                sys.stdout.flush()
                continue
            batch_cost += cost

            wl_grad = pairs[1]
            wf_grad = pairs[2]
            max_position_value = np.max(np.absolute(pairs[3]))
            total_position_value += max_position_value
            last_decision = pairs[4]
            target_decision = d[1][0]
            # Compute reward: a small bonus for a correct decision, zeroed when
            # the attention position value grows too large.
            reward = 0.005 if last_decision == target_decision else 0
            if max_position_value > 1.8:
                reward = 0
            # if cost > 5:
            #     cost = 5
            # reward += (5 - cost) / 100
            total_reward += reward
            batch_reward += reward
            # Establish the reward baseline once enough samples have been seen.
            if self.last_average_reward == 999 and total > 2000:
                self.last_average_reward = total_reward / total

            # Accumulate REINFORCE gradients, scaled by the negated baseline-subtracted reward.
            if not self.disable_reinforce:
                self.batch_wl_grad += wl_grad * -(reward - self.last_average_reward)
                self.batch_wf_grad += wf_grad * -(reward - self.last_average_reward)
            # Accumulate backprop gradients.
            if not self.disable_backprop:
                for grad_cache, grad in zip(self.batch_grad, pairs[5:]):
                    grad_cache += grad
            counter += 1
            total += 1
            if counter >= self.batch_size:
                if total == counter: counter -= 1
                self.update_parameters(self.last_average_reward < 999)

                # Clean batch gradients
                if not self.disable_reinforce:
                    self.batch_wl_grad *= 0
                    self.batch_wf_grad *= 0
                if not self.disable_backprop:
                    for grad_cache in self.batch_grad:
                        grad_cache *= 0

                if total % 1000 == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()

                # Covariance switching between the layer's small and large
                # covariance settings, based on the recent batch reward.
                if not self.disable_reinforce:
                    cov_changed = False
                    if batch_reward / self.batch_size < 0.001:
                        # Low batch reward: count it; switch to the large covariance once the count exceeds 20.
                        if not self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.large_cov)
                                print "[LCOV]",
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    else:
                        # Reward recovered: count it; switch back to the small covariance once the count exceeds 20.
                        if self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.small_cov)
                                print "[SCOV]",
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    if cov_changed:
                        # Keep the cached inverse and determinant consistent with the new covariance.
                        self.large_cov_mode = not self.large_cov_mode
                        self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
                        self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))

                # Clean batch cost
                counter = 0
                cost_sum += batch_cost
                batch_cost = 0.0
                batch_reward = 0
        if total == 0:
            return "COST OVERFLOW"

        sys.stdout.write("\n")
        self.last_average_reward = (total_reward / total)
        self.turn += 1
        return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % ((cost_sum / total), self.last_average_reward, (total_position_value / total))
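
For orientation, a minimal usage sketch of this trainer. Only FirstGlimpseTrainer and its train_func() come from the listing above; the helpers that build the network, attention layer, config, and data are hypothetical placeholders, since that surrounding code is not part of this file.

# Hypothetical usage sketch; build_network(), build_trainer_config() and
# build_train_set() are placeholders, not real deepy APIs.
network, attention_layer = build_network()      # a NeuralClassifier plus its FirstGlimpseLayer
config = build_trainer_config(batch_size=20)    # any TrainerConfig-like object supporting .get()
train_set = build_train_set()                   # iterable of (input, target) batches, as train_func() expects

trainer = FirstGlimpseTrainer(network, attention_layer, config)
for epoch in range(10):
    # Each call runs one pass over train_set and returns a summary string,
    # e.g. "J: 1.23, Avg R: 0.0040, Avg P: 1.50".
    print trainer.train_func(train_set)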