LookupTable   A
last analyzed

Complexity

Total Complexity 14

Size/Duplication

Total Lines 85
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
dl 0
loc 85
rs 10
c 0
b 0
f 0
wmc 14

8 Methods

Rating   Name   Duplication   Size   Complexity  
A __init__() 0 5 1
A _initialize() 0 2 1
A apply() 0 22 2
A W() 0 3 1
A _allocate() 0 4 1
A get_dim() 0 6 3
A output_dim() 0 3 1
A input_dim() 0 3 2
1
"""Introduces Lookup brick."""
2
from blocks.bricks import Initializable, Feedforward
3
from blocks.bricks.base import application, lazy
4
from blocks.roles import WEIGHT, add_role
5
from blocks.utils import check_theano_variable, shared_floatx_nans
6
7
8
class LookupTable(Initializable, Feedforward):
    """Brick mapping integer indices to rows of a weight matrix.

    Typical use is embedding integers (for instance word indices) into a
    vector space: each index selects one row of the matrix `W`.

    Parameters
    ----------
    length : int
        Number of rows in the table, i.e. one more than the largest
        index that has a representation.
    dim : int
        Number of elements in each representation (columns of `W`).

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    """
    # The only parameter of this brick is the weight matrix.
    has_bias = False

    @lazy(allocation=['length', 'dim'])
    def __init__(self, length, dim, **kwargs):
        super(LookupTable, self).__init__(**kwargs)
        self.length = length
        self.dim = dim

    @property
    def W(self):
        """The first (and only) entry of ``self.parameters``."""
        return self.parameters[0]

    def _allocate(self):
        """Create the ``(length, dim)`` shared variable and tag it WEIGHT."""
        table_shape = (self.length, self.dim)
        weights = shared_floatx_nans(table_shape, name='W')
        self.parameters.append(weights)
        add_role(weights, WEIGHT)

    def _initialize(self):
        """Fill `W` using the brick's ``weights_init`` scheme and RNG."""
        self.weights_init.initialize(self.W, self.rng)

    @application(inputs=['indices'], outputs=['output'])
    def apply(self, indices):
        """Look up the representations of the given indices.

        Parameters
        ----------
        indices : :class:`~tensor.TensorVariable`
            Indices to look up. Must have an integer dtype.

        Returns
        -------
        output : :class:`~tensor.TensorVariable`
            The selected rows of `W`, arranged in the shape of `indices`
            with one extra trailing axis of size `dim`. If `indices` has
            :math:`k` dimensions, the result has :math:`k+1`.

        """
        check_theano_variable(indices, None, ("int", "uint"))
        # Symbolic sizes of every axis of `indices`, in order.
        leading_dims = [indices.shape[axis] for axis in range(indices.ndim)]
        # Index with the flattened indices, then restore the original
        # layout with the representation axis appended last.
        flat_rows = self.W[indices.flatten()]
        return flat_rows.reshape(leading_dims + [self.dim])

    def get_dim(self, name):
        """Return the dimension associated with the variable `name`.

        ``'indices'`` maps to 0 and ``'output'`` to ``self.dim``; any
        other name is delegated to the parent class.
        """
        if name == 'indices':
            return 0
        if name == 'output':
            return self.dim
        return super(LookupTable, self).get_dim(name)

    @property
    def input_dim(self):
        """Always 0 for this brick."""
        return 0

    @input_dim.setter
    def input_dim(self, dim):
        # Nothing is stored; the setter only rejects values other than 0.
        if dim != 0:
            raise ValueError("LookupTable input must be integer")

    @property
    def output_dim(self):
        """Alias for ``self.dim``."""
        return self.dim

    @output_dim.setter
    def output_dim(self, dim):
        self.dim = dim
93