FirstGlimpseTrainer.train_func()   F

Complexity

Conditions 25

Size

Total Lines 100

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0
Features 0

Metric                  Value
cc  (conditions)        25
c   (changes)           1
b   (bugs)              0
f   (features)          0
dl  (duplicated lines)  0
loc (total lines)       100
rs                      2

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include Extract Method, sketched below.
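As a minimal, hypothetical sketch (not code from the analyzed file), here is how the "# Compute reward" block inside train_func() could become its own method, with the comment turned into the method name:

import numpy as np

class GlimpseTrainerSketch(object):
    """Hypothetical class used only to illustrate Extract Method."""

    MATCH_REWARD = 0.005  # invented constant, mirroring the 0.005 literal in train_func()

    def _compute_reward(self, last_decision, target_decision, max_position_value):
        # The "# Compute reward" comment became the method name.
        if max_position_value > 1.8:
            return 0.0
        return self.MATCH_REWARD if last_decision == target_decision else 0.0

    def train_step(self, last_decision, target_decision, positions):
        max_position_value = np.max(np.absolute(positions))
        return self._compute_reward(last_decision, target_decision, max_position_value)

The loop body in train_func() shrinks by the same amount, and the comment's intent survives as a name.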

Complexity

Complex methods like FirstGlimpseTrainer.train_func() often do a lot of different things, and the same usually goes for the class around them. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
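In the listing below, for instance, the batch_wl_grad and batch_wf_grad fields and the REINFORCE accumulation on them share a prefix and could move into a class of their own. A hedged sketch with invented names, not a drop-in replacement:

import numpy as np

class ReinforceGradientCache(object):
    """Hypothetical extracted class owning the REINFORCE gradient caches."""

    def __init__(self, wl_shape, wf_shape, dtype="float32"):
        self.wl_grad = np.zeros(wl_shape, dtype=dtype)
        self.wf_grad = np.zeros(wf_shape, dtype=dtype)

    def accumulate(self, wl_grad, wf_grad, reward, baseline):
        # Same arithmetic as train_func(): scale by the negative advantage.
        advantage = reward - baseline
        self.wl_grad += wl_grad * -advantage
        self.wf_grad += wf_grad * -advantage

    def averaged(self, batch_size):
        # Mirrors the division by batch_size in update_parameters().
        return self.wl_grad / batch_size, self.wf_grad / batch_size

    def reset(self):
        self.wl_grad *= 0
        self.wf_grad *= 0

FirstGlimpseTrainer would then hold one ReinforceGradientCache instead of two bare arrays, and update_parameters() would call averaged() and reset().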

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys

import numpy as np
from numpy import linalg as LA
from theano import tensor as T
import theano

from deepy.utils.functions import FLOATX
from deepy.trainers import CustomizeTrainer
from deepy.trainers.optimize import optimize_function


class FirstGlimpseTrainer(CustomizeTrainer):

    def __init__(self, network, attention_layer, config):
        """
        Parameters:
            network - AttentionNetwork
            attention_layer - first glimpse attention layer
            config - training config
        :type network: NeuralClassifier
        :type attention_layer: experiments.attention_models.first_glimpse_model.FirstGlimpseLayer
        :type config: TrainerConfig
        """
        super(FirstGlimpseTrainer, self).__init__(network, config)
        self.large_cov_mode = False
        self.batch_size = config.get("batch_size", 20)
        self.disable_backprop = config.get("disable_backprop", False)
        self.disable_reinforce = config.get("disable_reinforce", False)
        self.last_average_reward = 999  # sentinel: no reward baseline yet
        self.turn = 1
        self.layer = attention_layer
        if self.disable_backprop:
            grads = []
        else:
            grads = [T.grad(self.cost, p) for p in network.weights + network.biases]
        if self.disable_reinforce:
            grad_l = self.layer.W_l
            grad_f = self.layer.W_f
        else:
            grad_l = self.layer.wl_grad
            grad_f = self.layer.wf_grad
        self.batch_wl_grad = np.zeros(attention_layer.W_l.get_value().shape, dtype=FLOATX)
        self.batch_wf_grad = np.zeros(attention_layer.W_f.get_value().shape, dtype=FLOATX)
        self.batch_grad = [np.zeros(p.get_value().shape, dtype=FLOATX) for p in network.weights + network.biases]
        # Outputs: [cost, wl_grad, wf_grad, positions, last_decision] + parameter grads
        self.grad_func = theano.function(network.inputs,
                                         [self.cost, grad_l, grad_f, attention_layer.positions, attention_layer.last_decision] + grads,
                                         allow_input_downcast=True)
        self.opt_func = optimize_function(self.network.weights + self.network.biases, self.config)
        self.rl_opt_func = optimize_function([self.layer.W_l, self.layer.W_f], self.config)

    def update_parameters(self, update_rl):
        # Backprop update
        if not self.disable_backprop:
            grads = [self.batch_grad[i] / self.batch_size for i in range(len(self.network.weights + self.network.biases))]
            self.opt_func(*grads)
        # REINFORCE update
        if update_rl and not self.disable_reinforce:
            if np.sum(self.batch_wl_grad) == 0 or np.sum(self.batch_wf_grad) == 0:
                sys.stdout.write("0WRL ")
                sys.stdout.flush()
            else:
                grad_wl = self.batch_wl_grad / self.batch_size
                grad_wf = self.batch_wf_grad / self.batch_size
                self.rl_opt_func(grad_wl, grad_wf)

    def train_func(self, train_set):
        cost_sum = 0.0
        batch_cost = 0.0
        counter = 0
        total = 0
        total_reward = 0
        batch_reward = 0
        total_position_value = 0
        pena_count = 0
        for d in train_set:
            pairs = self.grad_func(*d)
            cost = pairs[0]
            if cost > 10 or np.isnan(cost):
                # Skip exploding or NaN costs
                sys.stdout.write("X")
                sys.stdout.flush()
                continue
            batch_cost += cost

            wl_grad = pairs[1]
            wf_grad = pairs[2]
            max_position_value = np.max(np.absolute(pairs[3]))
            total_position_value += max_position_value
            last_decision = pairs[4]
            target_decision = d[1][0]
            # Compute reward: a small bonus for a correct decision, nothing if the
            # glimpse position drifted out of range
            reward = 0.005 if last_decision == target_decision else 0
            if max_position_value > 1.8:
                reward = 0
            # if cost > 5:
            #     cost = 5
            # reward += (5 - cost) / 100
            total_reward += reward
            batch_reward += reward
            if self.last_average_reward == 999 and total > 2000:
                self.last_average_reward = total_reward / total

            if not self.disable_reinforce:
                self.batch_wl_grad += wl_grad * -(reward - self.last_average_reward)
                self.batch_wf_grad += wf_grad * -(reward - self.last_average_reward)
            if not self.disable_backprop:
                for grad_cache, grad in zip(self.batch_grad, pairs[5:]):
                    grad_cache += grad
            counter += 1
            total += 1
            if counter >= self.batch_size:
                if total == counter:
                    counter -= 1
                self.update_parameters(self.last_average_reward < 999)

                # Clean batch gradients
                if not self.disable_reinforce:
                    self.batch_wl_grad *= 0
                    self.batch_wf_grad *= 0
                if not self.disable_backprop:
                    for grad_cache in self.batch_grad:
                        grad_cache *= 0

                if total % 1000 == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()

                # Covariance adjustment: after more than 20 penalized batches, switch
                # to the large covariance when the batch reward stalls, and back to
                # the small one once it recovers
                if not self.disable_reinforce:
                    cov_changed = False
                    if batch_reward / self.batch_size < 0.001:
                        if not self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.large_cov)
                                sys.stdout.write("[LCOV] ")
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    else:
                        if self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.small_cov)
                                sys.stdout.write("[SCOV] ")
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    if cov_changed:
                        self.large_cov_mode = not self.large_cov_mode
                        self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
                        self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))

                # Clean batch cost
                counter = 0
                cost_sum += batch_cost
                batch_cost = 0.0
                batch_reward = 0
        if total == 0:
            return "COST OVERFLOW"

        sys.stdout.write("\n")
        self.last_average_reward = (total_reward / total)
        self.turn += 1
        return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % ((cost_sum / total), self.last_average_reward, (total_position_value / total))
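Applying the section's own advice to this method: the covariance-adjustment block above is a natural Extract Method candidate, since it touches only the batch reward and the attention layer's covariance state. A hedged sketch, meant as a method of FirstGlimpseTrainer reusing the file's imports; the name _adjust_covariance and the returned patience counter are inventions for illustration, not a verified drop-in replacement:

    def _adjust_covariance(self, batch_reward, pena_count):
        """Switch between small and large covariance when the batch reward stalls."""
        low_reward = batch_reward / self.batch_size < 0.001
        if low_reward == self.large_cov_mode:
            # Reward agrees with the current mode: reset the patience counter.
            return 0
        if pena_count <= 20:
            # Not enough consecutive evidence yet: keep counting.
            return pena_count + 1
        # Flip the covariance and refresh its cached inverse and determinant.
        sys.stdout.write("[LCOV] " if low_reward else "[SCOV] ")
        self.layer.cov.set_value(self.layer.large_cov if low_reward else self.layer.small_cov)
        self.large_cov_mode = not self.large_cov_mode
        self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
        self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))
        return 0

train_func() would then replace the whole covariance block with a single line, pena_count = self._adjust_covariance(batch_reward, pena_count), cutting both the condition count and the nesting depth that drove the cc metric to 25.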