Completed
Pull Request — master (#1064)
by Dmitry, created 04:46

Model.__init__() (rating: F)

Complexity: 14 conditions
Size: 27 total lines
Duplication: 0 lines, 0 % ratio
Metric   Value
cc       14       (cyclomatic complexity)
dl       0        (duplicated lines)
loc      27       (lines of code)
rs       2.7581

How to fix: Complexity

Complex methods like Model.__init__() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields, methods, or variables that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
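As a minimal sketch of what Extract Class could look like here (not code from Blocks), the top-brick bookkeeping in Model.__init__() (finding bricks that are not anyone's child, de-duplicating them, and rejecting repeated names) can move into a class of its own. `Brick` and `BrickHierarchy` are hypothetical names: `Brick` is a stub that models only the `name` and `children` attributes the constructor relies on.

```python
from collections import Counter


class Brick:
    """Hypothetical stub of a Blocks brick; only ``name`` and ``children``
    are modelled because the top-brick logic uses nothing else."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


class BrickHierarchy:
    """Hypothetical class extracted from Model.__init__(): it owns the
    top-brick bookkeeping that the constructor currently does inline."""

    def __init__(self, bricks):
        self.top_bricks = self._find_top_bricks(bricks)
        self._check_unique_names(self.top_bricks)

    @staticmethod
    def _find_top_bricks(bricks):
        # A brick that is a child of any brick in the graph is not a top brick.
        children = {child for brick in bricks for child in brick.children}
        # dict.fromkeys drops duplicates but keeps first-seen order, replacing
        # the quadratic "brick not in self.top_bricks" list lookup.
        return [brick for brick in dict.fromkeys(bricks)
                if brick not in children]

    @staticmethod
    def _check_unique_names(top_bricks):
        # Top bricks must have unique names, as in the original constructor.
        counts = Counter(brick.name for brick in top_bricks)
        repeated = [name for name, count in counts.items() if count > 1]
        if repeated:
            raise ValueError("top bricks with the same name: "
                             "{}".format(', '.join(repeated)))


# Usage: two linear bricks owned by an MLP-like parent brick.
linear_0 = Brick('linear_0')
linear_1 = Brick('linear_1')
mlp = Brick('mlp', children=[linear_0, linear_1])
hierarchy = BrickHierarchy([linear_0, mlp, linear_1, mlp])
print([brick.name for brick in hierarchy.top_bricks])  # ['mlp']
```

With such a class, the constructor could reduce that part to something like `self.top_bricks = BrickHierarchy(bricks).top_bricks`, moving those loops and conditions out of `__init__`.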

"""Model - heavily annotated computation graph.

A model in Blocks is simply an annotated computation graph.  The class
:class:`Model` extends :class:`blocks.graph.ComputationGraph`,
which is able to handle annotations and roles in general, but is
deliberately made unaware of specific annotations that a Theano graph
created by Blocks typically has, such as bricks and application calls.
:class:`Model` adds this functionality. Using :class:`Model` you can do
things like query all the bricks used to build the computation graph and
request "hierarchical names" of the parameters (a hierarchical name is a
path-like string which in addition to the parameter's name contains names
of the bricks on the path from a root brick to the brick that owns the
parameter, e.g. ``/mlp/linear_0.W``).

For more information, see the :class:`Model` docstring.

"""
import logging
from collections import OrderedDict, Counter
from itertools import chain

from blocks.graph import ComputationGraph
from blocks.select import Selector
from blocks.filter import get_brick

logger = logging.getLogger(__name__)


class Model(ComputationGraph):
    """Handles annotations in Blocks-built computation graphs.

    Use this class to handle your Blocks-created computation graph.

    Examples
    --------
    >>> from theano import tensor
    >>> from blocks.bricks import MLP, Tanh
    >>> x = tensor.matrix('x')
    >>> mlp = MLP([Tanh(), Tanh()], [10, 10, 10])
    >>> y = mlp.apply(x)
    >>> model = Model(y)

    With :class:`Model` you can get access to the brick hierarchy. The
    brick hierarchy is defined by ``children`` attributes that every brick
    has.  The bricks that are not children of other bricks are called top
    bricks.  It is often useful to have access to top bricks of a brick
    hierarchy used to build a computation graph, and here is how you can do
    it:

    >>> model.get_top_bricks() #doctest: +ELLIPSIS
    [<blocks.bricks.sequences.MLP object at ...]

    You can also get "hierarchical" names for the parameters,
    which encode the position of the owning brick in the
    brick hierarchy.

    >>> model.get_parameter_dict() #doctest: +NORMALIZE_WHITESPACE
    OrderedDict([('/mlp/linear_1.b', b), ('/mlp/linear_0.b', b),
    ('/mlp/linear_0.W', W), ('/mlp/linear_1.W', W)])

    """
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs)
        bricks = [get_brick(var) for var
                  in self.variables + self.scan_variables if get_brick(var)]
        children = set(chain(*(brick.children for brick in bricks)))
        # Quadratic complexity: we should not have thousands of
        # top-level bricks.
        self.top_bricks = []
        for brick in bricks:
            if brick not in children and brick not in self.top_bricks:
                self.top_bricks.append(brick)
        names = Counter([brick.name for brick in self.top_bricks])
        repeated_names = [name for name, count in names.items() if count > 1]
        if repeated_names:
            raise ValueError("top bricks with the same name:"
                             " {}".format(', '.join(repeated_names)))
        brick_parameter_names = {
            v: k for k, v in Selector(
                self.top_bricks).get_parameters().items()}
        parameter_list = []
        for parameter in self.parameters:
            if parameter in brick_parameter_names:
                parameter_list.append((brick_parameter_names[parameter],
                                       parameter))
            else:
                parameter_list.append((parameter.name, parameter))
        self._parameter_dict = OrderedDict(parameter_list)

    def get_parameter_dict(self):
        """Return parameters with their hierarchical names.

        The parameter names are formed from positions of their owner bricks
        in the brick hierarchy. The variable names are used for the
        parameters that do not belong to any brick.

        Returns
        -------
        parameter_dict : OrderedDict
            A dictionary of (hierarchical name, shared variable) pairs.

        """
        return self._parameter_dict

    def get_parameter_values(self):
        """Return the values of model parameters.

        The same hierarchical names as in :meth:`get_parameter_dict` are
        used to uniquely identify parameters.

        Returns
        -------
        parameter_values : OrderedDict
            Dictionary of (hierarchical name, :class:`~numpy.ndarray`)
            pairs.

        """
        return OrderedDict(
            (name, parameter.get_value())
            for name, parameter in self.get_parameter_dict().items())

    def set_parameter_values(self, parameter_values):
        """Set the values of model parameters.

        The same hierarchical names as in :meth:`get_parameter_dict` are
        used to uniquely identify parameters.

        Parameters
        ----------
        parameter_values : OrderedDict
            Dictionary of (hierarchical name, :class:`~numpy.ndarray`)
            pairs.

        """
        parameters = self.get_parameter_dict()

        unknown = set(parameter_values) - set(parameters)
        missing = set(parameters) - set(parameter_values)
        if unknown:
            logger.error("unknown parameter names: {}\n".format(unknown))
        if missing:
            logger.error("missing values for parameters: {}\n".format(missing))

        for name, value in parameter_values.items():
            if name in parameters:
                model_shape = parameters[name].container.data.shape
                if model_shape != value.shape:
                    raise ValueError("Shape mismatch for parameter: {}. "
                                     "Expected {}, got {}."
                                     .format(name, model_shape, value.shape))
                parameters[name].set_value(value)

    def get_top_bricks(self):
        """Get the bricks that do not have parents.

        Returns
        -------
        bricks : list of :class:`~blocks.bricks.base.Brick`

        """
        return self.top_bricks
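Model.__init__() splits naturally into two halves: the top-brick bookkeeping and the construction of `_parameter_dict`. Below is a sketch of the second half extracted into a standalone function, assuming the mapping passed in has the same shape as the result of `Selector(...).get_parameters()` in the code above (hierarchical name to parameter). `Parameter` and `build_parameter_dict` are hypothetical names for illustration; `Parameter` stands in for a Theano shared variable and models only its `name`.

```python
from collections import OrderedDict


class Parameter:
    """Hypothetical stand-in for a Theano shared variable (``name`` only)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def build_parameter_dict(parameters, hierarchical_names):
    """Hypothetical helper mirroring the second half of Model.__init__().

    ``hierarchical_names`` maps hierarchical name -> parameter.
    """
    # Invert the mapping so each parameter can be looked up directly.
    name_of = {param: name for name, param in hierarchical_names.items()}
    # Brick-owned parameters get their hierarchical name; any others fall
    # back to the plain variable name, as in the original loop.
    return OrderedDict((name_of.get(param, param.name), param)
                       for param in parameters)


# Usage: two brick-owned parameters plus one that belongs to no brick.
W, b, scale = Parameter('W'), Parameter('b'), Parameter('scale')
params = build_parameter_dict(
    [W, b, scale], {'/mlp/linear_0.W': W, '/mlp/linear_0.b': b})
print(list(params))  # ['/mlp/linear_0.W', '/mlp/linear_0.b', 'scale']
```

Keeping the iteration over `parameters` (rather than over the name mapping) preserves the original ordering of `self.parameters` in the resulting `OrderedDict`.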