Completed
Pull Request — master (#1030)
by
unknown
04:44
created

LookupTable._initialize()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""Introduces Lookup brick."""
2
from blocks.bricks import Initializable, Feedforward
3
from blocks.bricks.base import application, lazy
4
from blocks.roles import WEIGHT, add_role
5
from blocks.utils import check_theano_variable, shared_floatx_nans
6
7
8
class LookupTable(Initializable, Feedforward):
    """Maps integers from a fixed range to vector representations.

    A typical use is embedding integer ids (for instance word indices)
    into a vector space: row ``i`` of the weight matrix :attr:`W` is the
    representation of the integer ``i``.

    Parameters
    ----------
    length : int
        Number of rows in the table, i.e. one more than the largest index
        that has a representation.
    dim : int
        Size of each representation vector.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    """
    has_bias = False

    @lazy(allocation=['length', 'dim'])
    def __init__(self, length, dim, **kwargs):
        super(LookupTable, self).__init__(**kwargs)
        self.dim = dim
        self.length = length

    @property
    def W(self):
        """The ``(length, dim)`` weight matrix (first registered parameter)."""
        return self.parameters[0]

    def _allocate(self):
        """Create the ``(length, dim)`` weight matrix and tag it as a WEIGHT.

        The matrix comes from ``shared_floatx_nans``. The name suggests its
        entries are placeholders until initialization overwrites them.

        """
        # NOTE(review): no ``_initialize`` is defined in this class; filling
        # W is presumably left to the Initializable base -- confirm, or W may
        # keep its placeholder values.
        table = shared_floatx_nans((self.length, self.dim), name='W')
        self.parameters.append(table)
        add_role(table, WEIGHT)

    @application(inputs=['indices'], outputs=['output'])
    def apply(self, indices):
        """Fetch the representations stored for `indices`.

        Parameters
        ----------
        indices : :class:`~tensor.TensorVariable`
            Tensor of indices into the table; must have an integer dtype.

        Returns
        -------
        output : :class:`~tensor.TensorVariable`
            Tensor with one more dimension than `indices`. The leading
            dimensions match the shape of `indices`; the trailing dimension,
            of size ``dim``, holds the representation of each index.

        """
        check_theano_variable(indices, None, ("int", "uint"))
        # Gather rows for the flattened index vector, then restore the
        # original index shape with one extra trailing axis of size `dim`.
        target_shape = [indices.shape[axis] for axis in range(indices.ndim)]
        target_shape.append(self.dim)
        flat_indices = indices.flatten()
        return self.W[flat_indices].reshape(target_shape)

    def get_dim(self, name):
        """Return the dimension of the named input or output.

        ``'indices'`` reports 0 (the input is integer-valued, not a vector),
        ``'output'`` reports ``dim``, and any other name is delegated to the
        parent classes.

        """
        if name == 'indices':
            return 0
        elif name == 'output':
            return self.dim
        return super(LookupTable, self).get_dim(name)

    @property
    def input_dim(self):
        """Always 0: the input is an integer index rather than a vector."""
        return 0

    @input_dim.setter
    def input_dim(self, value):
        # Only the integer-input convention (0) is accepted.
        if value == 0:
            return
        raise ValueError("LookupTable input must be integer")

    @property
    def output_dim(self):
        """Alias for ``dim``, the size of each representation."""
        return self.dim

    @output_dim.setter
    def output_dim(self, value):
        self.dim = value