experiments.attention_models.AttentionTrainer   B

Complexity

Total Complexity  36

Size/Duplication

Total Lines       140
Duplicated Lines  0 %

Metric                             Value
dl (duplicated lines)              0
loc (lines of code)                140
rs                                 8.8
wmc (weighted methods per class)   36

3 Methods

Rating  Name                 Duplication  Size  Complexity
B       update_parameters()  0            12    6
F       train_func()         0            92    25
B       __init__()           0            30    5
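
The table above covers the whole public surface of the class: the constructor compiles the gradient and optimizer functions, train_func() is called once per epoch with an iterable of (input, target) batches, and update_parameters() is invoked internally every batch_size samples. A minimal usage sketch under those assumptions follows; the network construction, the attention-layer lookup, the attribute-style configuration, and the TrainerConfig import path are hypothetical placeholders, not taken from this file.

# Usage sketch (hypothetical setup; only AttentionTrainer's constructor signature and the
# batch_size / disable_backprop / disable_reinforce config keys come from the class below).
from deepy.conf import TrainerConfig               # assumed import path

network = build_attention_network()                # hypothetical helper returning a NeuralClassifier
attention_layer = network.attention_layer          # hypothetical handle to the AttentionLayer

config = TrainerConfig()                           # attribute-style assignment assumed here
config.batch_size = 20                             # gradients are averaged and applied every 20 samples
config.disable_backprop = False                    # keep backprop updates of the classifier weights
config.disable_reinforce = False                   # keep REINFORCE updates of the location weights W_l

trainer = AttentionTrainer(network, attention_layer, config)
for epoch in range(10):
    print trainer.train_func(train_set)            # train_set yields (input, target) batches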
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys

import numpy as np
from numpy import linalg as LA
from theano import tensor as T
import theano

from deepy.utils.functions import FLOATX
from deepy.trainers import CustomizeTrainer
from deepy.trainers.optimize import optimize_function


class AttentionTrainer(CustomizeTrainer):

    def __init__(self, network, attention_layer, config):
        """
        Parameters:
            network - AttentionNetwork
            attention_layer - attention layer whose location weights W_l are trained with REINFORCE
            config - training config
        :type network: NeuralClassifier
        :type attention_layer: experiments.attention_models.baseline_model.AttentionLayer
        :type config: TrainerConfig
        """
        super(AttentionTrainer, self).__init__(network, config)
        self.large_cov_mode = False
        self.batch_size = config.get("batch_size", 20)
        self.disable_backprop = config.get("disable_backprop", False)
        self.disable_reinforce = config.get("disable_reinforce", False)
        # 999 is a sentinel meaning "no reward baseline estimated yet"
        self.last_average_reward = 999
        self.turn = 1
        self.layer = attention_layer
        # Backpropagation gradients for the classifier weights and biases
        if self.disable_backprop:
            grads = []
        else:
            grads = [T.grad(self.cost, p) for p in network.weights + network.biases]
        # REINFORCE gradient for the location weights W_l
        # (W_l itself serves as a placeholder output when REINFORCE is disabled)
        if self.disable_reinforce:
            grad_l = self.layer.W_l
        else:
            grad_l = self.layer.wl_grad
        # Per-batch gradient accumulators, averaged in update_parameters()
        self.batch_wl_grad = np.zeros(attention_layer.W_l.get_value().shape, dtype=FLOATX)
        self.batch_grad = [np.zeros(p.get_value().shape, dtype=FLOATX) for p in network.weights + network.biases]
        self.grad_func = theano.function(network.inputs,
                                         [self.cost, grad_l, attention_layer.positions, attention_layer.last_decision] + grads,
                                         allow_input_downcast=True)
        self.opt_interface = optimize_function(self.network.weights + self.network.biases, self.config)
        self.l_opt_interface = optimize_function([self.layer.W_l], self.config)
        # self.opt_interface = gradient_interface_future(self.network.weights + self.network.biases, config=self.config)
        # self.l_opt_interface = gradient_interface_future([self.layer.W_l], config=self.config)

    def update_parameters(self, update_wl):
        # Average the accumulated backprop gradients over the batch and apply them
        if not self.disable_backprop:
            grads = [self.batch_grad[i] / self.batch_size for i in range(len(self.network.weights + self.network.biases))]
            self.opt_interface(*grads)
        # REINFORCE update of the location weights (skipped when the accumulated gradient sums to zero)
        if update_wl and not self.disable_reinforce:
            if np.sum(self.batch_wl_grad) == 0:
                sys.stdout.write("[0 WLG] ")
                sys.stdout.flush()
            else:
                grad_wl = self.batch_wl_grad / self.batch_size
                self.l_opt_interface(grad_wl)

    def train_func(self, train_set):
        cost_sum = 0.0
        batch_cost = 0.0
        counter = 0
        total = 0
        total_reward = 0
        batch_reward = 0
        total_position_value = 0
        pena_count = 0
        for d in train_set:
            pairs = self.grad_func(*d)
            cost = pairs[0]
            # Skip samples whose cost exploded or became NaN
            if cost > 10 or np.isnan(cost):
                sys.stdout.write("X")
                sys.stdout.flush()
                continue
            batch_cost += cost

            wl_grad = pairs[1]
            max_position_value = np.max(np.absolute(pairs[2]))
            total_position_value += max_position_value
            last_decision = pairs[3]
            target_decision = d[1][0]
            # Reward a correct decision, but zero it if any attention position magnitude exceeds 0.8
            reward = 0.005 if last_decision == target_decision else 0
            if max_position_value > 0.8:
                reward = 0
            total_reward += reward
            batch_reward += reward
            # Estimate the reward baseline once, from the first 2000+ samples
            if self.last_average_reward == 999 and total > 2000:
                self.last_average_reward = total_reward / total
            if not self.disable_reinforce:
                # REINFORCE accumulation: scale the location gradient by the negative advantage
                self.batch_wl_grad += wl_grad * -(reward - self.last_average_reward)
            if not self.disable_backprop:
                for grad_cache, grad in zip(self.batch_grad, pairs[4:]):
                    grad_cache += grad
            counter += 1
            total += 1
            if counter >= self.batch_size:
                if total == counter: counter -= 1
                self.update_parameters(self.last_average_reward < 999)

                # Clean batch gradients
                if not self.disable_reinforce:
                    self.batch_wl_grad *= 0
                if not self.disable_backprop:
                    for grad_cache in self.batch_grad:
                        grad_cache *= 0

                if total % 1000 == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()

                # Covariance switching: widen the attention sampling covariance after
                # ~20 consecutive low-reward batches, and narrow it again once reward recovers
                if not self.disable_reinforce:
                    cov_changed = False
                    if batch_reward / self.batch_size < 0.001:
                        if not self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.large_cov)
                                print "[LCOV]",
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    else:
                        if self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.small_cov)
                                print "[SCOV]",
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    if cov_changed:
                        # Keep the cached inverse and determinant in sync with the new covariance
                        self.large_cov_mode = not self.large_cov_mode
                        self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
                        self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))

                # Clean batch cost
                counter = 0
                cost_sum += batch_cost
                batch_cost = 0.0
                batch_reward = 0
        if total == 0:
            return "COST OVERFLOW"

        sys.stdout.write("\n")
        self.last_average_reward = (total_reward / total)
        self.turn += 1
        return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % ((cost_sum / total), self.last_average_reward, (total_position_value / total))
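
For readers tracing the REINFORCE bookkeeping above, the accumulation that train_func() performs on batch_wl_grad and the averaging done in update_parameters() reduce to the standalone NumPy sketch below; the shapes, sample count, and fixed baseline value are made up for illustration and are not taken from this file.

import numpy as np

# Standalone sketch of the per-batch REINFORCE accumulation (illustrative values only).
baseline = 0.0025                       # stands in for self.last_average_reward
batch_size = 20
batch_wl_grad = np.zeros((64, 2))       # same role as self.batch_wl_grad (shape is made up)

for _ in range(batch_size):
    wl_grad = np.random.randn(64, 2)    # per-sample location gradient, i.e. pairs[1]
    reward = np.random.choice([0.005, 0.0])
    # Scale the gradient by the negative advantage, as in train_func()
    batch_wl_grad += wl_grad * -(reward - baseline)

# update_parameters() divides the accumulator by batch_size before
# passing the averaged gradient to the optimizer for W_l.
grad_wl = batch_wl_grad / batch_size
print grad_wl.shape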