Completed
Pull Request — master (#959)
by Dmitry
01:46
created

blocks.bricks.LookupTable.get_dim()   A

Complexity

Conditions 3

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %
| Metric | Value  |
|--------|--------|
| cc     | 3      |
| dl     | 0      |
| loc    | 5      |
| rs     | 9.4285 |
1
"""Introduces Lookup brick."""
2
from blocks.bricks import Initializable
3
from blocks.bricks.base import application, lazy
4
from blocks.roles import WEIGHT, add_role
5
from blocks.utils import check_theano_variable, shared_floatx_nans
6
7
8
class LookupTable(Initializable):
    """Encapsulates representations of a range of integers.

    Parameters
    ----------
    length : int
        The size of the lookup table, or in other words, one plus the
        maximum index for which a representation is contained.
    dim : int
        The dimensionality of representations.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    """
    # NOTE(review): presumably tells :class:`.Initializable` that this brick
    # owns no bias parameter (so no ``biases_init`` is required) -- confirm
    # against Initializable.
    has_bias = False

    @lazy(allocation=['length', 'dim'])
    def __init__(self, length, dim, **kwargs):
        super(LookupTable, self).__init__(**kwargs)
        self.length = length
        self.dim = dim

    @property
    def W(self):
        """The ``(length, dim)`` lookup matrix, one row per index."""
        # ``_allocate`` appends this matrix first, so it lives at index 0.
        return self.parameters[0]

    def _allocate(self):
        """Create the ``(length, dim)`` weight matrix and tag it as WEIGHT."""
        # NOTE(review): ``shared_floatx_nans`` presumably yields a NaN-filled
        # placeholder; its real values are written later by ``_initialize``.
        self.parameters.append(shared_floatx_nans((self.length, self.dim),
                               name='W'))
        add_role(self.parameters[-1], WEIGHT)

    def _initialize(self):
        """Fill :attr:`W` in place using ``weights_init`` and ``rng``."""
        self.weights_init.initialize(self.W, self.rng)

    @application(inputs=['indices'], outputs=['output'])
    def apply(self, indices):
        """Perform lookup.

        Parameters
        ----------
        indices : :class:`~tensor.TensorVariable`
            The indices of interest. The dtype must be integer.

        Returns
        -------
        output : :class:`~tensor.TensorVariable`
            Representations for the indices of the query. Has :math:`k+1`
            dimensions, where :math:`k` is the number of dimensions of the
            `indices` parameter. The last dimension stands for the
            representation element.

        """
        check_theano_variable(indices, None, ("int", "uint"))
        # The input's shape is symbolic, so the target shape is assembled
        # axis by axis and then extended with the representation size.
        output_shape = [indices.shape[i]
                        for i in range(indices.ndim)] + [self.dim]
        # Gather one row of W per flattened index, then restore the layout
        # of ``indices`` with a trailing ``dim`` axis.
        return self.W[indices.flatten()].reshape(output_shape)

    def get_dim(self, name):
        """Return the dimensionality of an input or output variable.

        Parameters
        ----------
        name : str
            Name of the variable, ``'indices'`` or ``'output'``.

        Returns
        -------
        int
            ``self.dim`` for ``'output'`` and ``0`` for ``'indices'``. The
            ``0`` presumably marks each index as a scalar, following the
            Brick ``get_dim`` convention -- confirm.

        Notes
        -----
        Any other name is delegated to the parent class. Previously such
        names fell through and silently returned ``None``. The parent Brick
        implementation is expected to raise ``ValueError`` for unknown
        names -- confirm against Brick.get_dim.

        """
        if name == 'output':
            return self.dim
        if name == 'indices':
            return 0
        return super(LookupTable, self).get_dim(name)
72