Completed
Push — master ( 394368...090fba )
by Raphael
01:33
created

Attention.compute_context_vector()   A

Complexity

Conditions 2

Size

Total Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 2
c 1
b 0
f 0
dl 0
loc 8
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from layer import NeuralLayer
5
import deepy.tensor as T
6
7
class Attention(NeuralLayer):
    """Soft (additive / Bahdanau-style) attention layer.

    Scores each encoder position against the previous decoder state with
    score = Va . tanh(Wa*s + Ua*h), softmaxes the scores into alignment
    weights, and returns the weighted sum of the inputs as a context vector.
    """

    def __init__(self, hidden_size, input_dim=None):
        """
        :param hidden_size: size of the attention hidden projection
        :param input_dim: dimensionality of the attended inputs; falls back
            to ``hidden_size`` when not given (parameter made explicitly
            optional -- the body already treated a falsy value as "absent")
        """
        super(Attention, self).__init__("attention")
        self.input_dim = input_dim if input_dim else hidden_size
        self.hidden_size = hidden_size
        # Fix: initialize with the resolved dimension, not the raw argument,
        # otherwise the fallback to hidden_size above is never honored.
        self.init(self.input_dim)

    def prepare(self):
        # Ua projects the encoder inputs, Wa projects the decoder state,
        # Va is the scoring vector that collapses the hidden dim to a scalar.
        self.Ua = self.create_weight(self.input_dim, self.hidden_size, "ua")
        self.Wa = self.create_weight(self.hidden_size, self.hidden_size, "wa")
        self.Va = self.create_weight(label="va", shape=(self.hidden_size,))
        self.register_parameters(self.Va, self.Wa, self.Ua)

    def precompute(self, inputs):
        """
        Precompute partial values in the score function.

        Ua . h does not depend on the decoder state, so it can be computed
        once per input sequence and reused at every decoding step.
        """
        return T.dot(inputs, self.Ua)

    def compute_alignments(self, prev_state, precomputed_values, mask=None):
        """
        Compute the alignment weights based on the previous state.

        :param prev_state: previous decoder state ~ (batch, hidden_size)
        :param precomputed_values: result of :meth:`precompute`
        :param mask: optional 0/1 mask over time steps; zeroed positions are
            driven to a large negative score so softmax assigns them ~0 weight
        """
        WaSp = T.dot(prev_state, self.Wa)
        UaH = precomputed_values
        # For test time the UaH will be (time, output_dim)
        if UaH.ndim == 2:
            preact = WaSp[:, None, :] + UaH[None, :, :]
        else:
            preact = WaSp[:, None, :] + UaH
        act = T.activate(preact, 'tanh')
        align_scores = T.dot(act, self.Va)  # ~ (batch, time)
        # Fix: compare against None explicitly. Truth-testing a symbolic
        # tensor raises at graph-build time, and `if mask:` would also be
        # wrong for a numeric mask that happens to be all zeros.
        if mask is not None:
            mask = (1 - mask) * -99.00
            if align_scores.ndim == 3:
                align_scores += mask[None, :]
            else:
                align_scores += mask
        align_weights = T.nnet.softmax(align_scores)
        return align_weights

    def compute_context_vector(self, prev_state, inputs, precomputed_values=None, mask=None):
        """
        Compute the context vector with soft attention.

        :param prev_state: previous decoder state ~ (batch, hidden_size)
        :param inputs: attended inputs ~ (batch, time, input_dim)
        :param precomputed_values: optional cached :meth:`precompute` output
        :param mask: optional 0/1 time-step mask, forwarded to the scorer
        """
        # Fix: `is None` instead of truthiness -- a symbolic tensor passed as
        # precomputed_values cannot be coerced to bool.
        if precomputed_values is None:
            precomputed_values = self.precompute(inputs)
        align_weights = self.compute_alignments(prev_state, precomputed_values, mask)
        context_vector = T.sum(align_weights[:, :, None] * inputs, axis=1)
        return context_vector