AttentionTrainer (rating: B)
last analyzed

Complexity
    Total Complexity: 36

Size/Duplication
    Total Lines: 140
    Duplicated Lines: 0 %

Importance
    Changes: 1
    Bugs: 0
    Features: 0
Metric                   Value
c   (changes)            1
b   (bugs)               0
f   (features)           0
dl  (duplicated lines)   0
loc (lines of code)      140
rs                       8.8
wmc (total complexity)   36

3 Methods

Rating  Name                 Duplication  Size  Complexity
B       __init__()           0            30    5
F       train_func()         0            92    25
B       update_parameters()  0            12    6
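
(As a consistency check, the per-method complexities sum to the file total:
5 + 25 + 6 = 36, which is the wmc value above.)
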
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys

import numpy as np
from numpy import linalg as LA
from theano import tensor as T
import theano

from deepy.utils.functions import FLOATX
from deepy.trainers import CustomizeTrainer
from deepy.trainers.optimize import optimize_function


class AttentionTrainer(CustomizeTrainer):

    def __init__(self, network, attention_layer, config):
        """
        Parameters:
            network - AttentionNetwork
            attention_layer - the attention layer whose W_l is trained with REINFORCE
            config - training config
        :type network: NeuralClassifier
        :type attention_layer: experiments.attention_models.baseline_model.AttentionLayer
        :type config: TrainerConfig
        """
        super(AttentionTrainer, self).__init__(network, config)
        self.large_cov_mode = False
        self.batch_size = config.get("batch_size", 20)
        self.disable_backprop = config.get("disable_backprop", False)
        self.disable_reinforce = config.get("disable_reinforce", False)
        # 999 is a sentinel meaning "no reward baseline estimated yet".
        self.last_average_reward = 999
        self.turn = 1
        self.layer = attention_layer
        if self.disable_backprop:
            grads = []
        else:
            grads = [T.grad(self.cost, p) for p in network.weights + network.biases]
        if self.disable_reinforce:
            # Placeholder so grad_func keeps the same output signature.
            grad_l = self.layer.W_l
        else:
            grad_l = self.layer.wl_grad
        # Per-batch gradient accumulators for W_l and the backprop parameters.
        self.batch_wl_grad = np.zeros(attention_layer.W_l.get_value().shape, dtype=FLOATX)
        self.batch_grad = [np.zeros(p.get_value().shape, dtype=FLOATX) for p in network.weights + network.biases]
        # A single compiled pass returns the cost, the attention gradient, the
        # attended positions, the final decision, and all backprop gradients.
        self.grad_func = theano.function(
            network.inputs,
            [self.cost, grad_l, attention_layer.positions, attention_layer.last_decision] + grads,
            allow_input_downcast=True)
        self.opt_interface = optimize_function(self.network.weights + self.network.biases, self.config)
        self.l_opt_interface = optimize_function([self.layer.W_l], self.config)
        # self.opt_interface = gradient_interface_future(self.network.weights + self.network.biases, config=self.config)
        # self.l_opt_interface = gradient_interface_future([self.layer.W_l], config=self.config)
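
    # The two update paths below: plain minibatch gradient steps for the
    # backprop parameters, and a REINFORCE step for the attention matrix W_l.
    # Assuming layer.wl_grad carries d(log pi)/d(W_l) for the sampled
    # attention positions, the accumulated buffer holds
    # -(reward - baseline) * d(log pi)/d(W_l), i.e. a policy-gradient
    # estimate with last_average_reward as the baseline.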
    def update_parameters(self, update_wl):
        if not self.disable_backprop:
            # Average the accumulated gradients over the batch before applying.
            grads = [self.batch_grad[i] / self.batch_size for i in range(len(self.network.weights + self.network.biases))]
            self.opt_interface(*grads)
        # REINFORCE update
        if update_wl and not self.disable_reinforce:
            if np.sum(self.batch_wl_grad) == 0:
                # An all-zero buffer gives the optimizer nothing to do.
                sys.stdout.write("[0 WLG] ")
                sys.stdout.flush()
            else:
                grad_wl = self.batch_wl_grad / self.batch_size
                self.l_opt_interface(grad_wl)
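
    # train_func runs one epoch, sample by sample: it skips samples whose
    # cost diverged, pays a reward of 0.005 for a correct final decision
    # (withheld when the attended position saturates past 0.8), estimates the
    # reward baseline from the first 2000 samples, and applies the
    # accumulated updates every batch_size samples. A hysteresis counter
    # (pena_count) widens the attention covariance after more than 20
    # consecutive low-reward batches and shrinks it back once rewards recover.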
    def train_func(self, train_set):
        cost_sum = 0.0
        batch_cost = 0.0
        counter = 0
        total = 0
        total_reward = 0
        batch_reward = 0
        total_position_value = 0
        pena_count = 0
        for d in train_set:
            pairs = self.grad_func(*d)
            cost = pairs[0]
            if cost > 10 or np.isnan(cost):
                # Skip diverged samples.
                sys.stdout.write("X")
                sys.stdout.flush()
                continue
            batch_cost += cost

            wl_grad = pairs[1]
            max_position_value = np.max(np.absolute(pairs[2]))
            total_position_value += max_position_value
            last_decision = pairs[3]
            target_decision = d[1][0]
            # Reward a correct decision; withhold the reward when attention
            # drifts out of range.
            reward = 0.005 if last_decision == target_decision else 0
            if max_position_value > 0.8:
                reward = 0
            total_reward += reward
            batch_reward += reward
            # Initialize the reward baseline once enough samples are seen.
            if self.last_average_reward == 999 and total > 2000:
                self.last_average_reward = total_reward / total
            if not self.disable_reinforce:
                self.batch_wl_grad += wl_grad * -(reward - self.last_average_reward)
            if not self.disable_backprop:
                for grad_cache, grad in zip(self.batch_grad, pairs[4:]):
                    grad_cache += grad
            counter += 1
            total += 1
            if counter >= self.batch_size:
                if total == counter:
                    counter -= 1
                # Only update W_l once the baseline has been initialized.
                self.update_parameters(self.last_average_reward < 999)

                # Clean batch gradients
                if not self.disable_reinforce:
                    self.batch_wl_grad *= 0
                if not self.disable_backprop:
                    for grad_cache in self.batch_grad:
                        grad_cache *= 0

                if total % 1000 == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()

                # Covariance switching with hysteresis: widen the attention
                # covariance after sustained low reward, shrink it back once
                # rewards recover.
                if not self.disable_reinforce:
                    cov_changed = False
                    if batch_reward / self.batch_size < 0.001:
                        if not self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.large_cov)
                                sys.stdout.write("[LCOV] ")
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    else:
                        if self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.small_cov)
                                sys.stdout.write("[SCOV] ")
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    if cov_changed:
                        self.large_cov_mode = not self.large_cov_mode
                        # Keep the cached inverse and determinant in sync.
                        self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
                        self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))

                # Clean batch cost
                counter = 0
                cost_sum += batch_cost
                batch_cost = 0.0
                batch_reward = 0
        if total == 0:
            return "COST OVERFLOW"

        sys.stdout.write("\n")
        self.last_average_reward = total_reward / total
        self.turn += 1
        return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % (cost_sum / total, self.last_average_reward, total_position_value / total)
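
The buffer arithmetic itself can be checked in isolation. A self-contained
numpy sketch of the same accumulation that train_func and update_parameters
perform, with made-up shapes and rewards:

    import numpy as np

    batch_size = 4
    baseline = 0.002                   # plays the role of last_average_reward
    batch_wl_grad = np.zeros((3, 3))   # same shape as W_l

    # Per-sample accumulation, as in train_func: two rewarded samples and
    # two unrewarded ones, each with a made-up wl_grad.
    samples = [(0.005, np.ones((3, 3))), (0.0, np.full((3, 3), 0.5)),
               (0.005, np.ones((3, 3))), (0.0, np.full((3, 3), 0.5))]
    for reward, wl_grad in samples:
        batch_wl_grad += wl_grad * -(reward - baseline)

    # Batch step, as in update_parameters: average, then hand to the optimizer.
    grad_wl = batch_wl_grad / batch_size
    print(grad_wl[0, 0])   # about -0.001: above-baseline samples produce a
                           # negative gradient, so a descent step moves W_l
                           # along +wl_grad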