Model.__init__()   F

Complexity

Conditions 13

Size

Total Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0

Metric  Value
cc      13
dl      0
loc     25
rs      2.7716
c       0
b       0
f       0

How to fix

Complexity

Complex methods like Model.__init__() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields, methods, or local variables that share the same prefixes or suffixes.

Once you have determined the parts that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster. Within a single method, the corresponding step is Extract Method: move each cohesive group of statements into its own helper.
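For a single method such as Model.__init__(), the method-level counterpart of Extract Class is Extract Method. The sketch below is one hedged illustration of that, not Blocks code: it pulls the three jobs of `__init__` (collecting top bricks, rejecting duplicate top-brick names, and building the name-to-parameter dictionary) into standalone functions, so that no single function carries all of the branching. Everything here is hypothetical: `FakeBrick`, `FakeParameter`, its `owner` attribute (standing in for `get_brick(parameter)`), and the three helper names. The hierarchical name is faked as `/<owning brick name>.<parameter name>` instead of coming from Blocks' `get_hierarchical_name`, so it omits the full path from the root brick. The top-brick helper also swaps the `brick not in self.top_bricks` list scan, which the `# Quadratic complexity` comment in the listing below flags, for a set of already-seen bricks, while keeping first-seen order.

```python
from collections import Counter, OrderedDict
from itertools import chain


class FakeBrick:
    """Hypothetical stand-in for a Blocks brick (name plus children)."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


class FakeParameter:
    """Hypothetical stand-in for a shared variable, optionally owned by a brick."""

    def __init__(self, name, owner=None):
        self.name = name
        self.owner = owner  # hypothetical: plays the role of get_brick(parameter)


def find_top_bricks(bricks):
    """Return bricks that are nobody's child, deduplicated, in first-seen order."""
    children = set(chain.from_iterable(brick.children for brick in bricks))
    seen = set()
    top_bricks = []
    for brick in bricks:
        # Set lookup instead of scanning the top_bricks list on every iteration.
        if brick not in children and brick not in seen:
            seen.add(brick)
            top_bricks.append(brick)
    return top_bricks


def check_unique_names(top_bricks):
    """Raise ValueError if two top bricks share a name, like the original check."""
    names = Counter(brick.name for brick in top_bricks)
    repeated = [name for name, count in names.items() if count > 1]
    if repeated:
        raise ValueError("top bricks with the same name: "
                         "{}".format(', '.join(repeated)))


def build_parameter_dict(parameters):
    """Map a faked hierarchical name (or the plain variable name) to each parameter."""
    return OrderedDict(
        ("/{}.{}".format(p.owner.name, p.name) if p.owner else p.name, p)
        for p in parameters)


if __name__ == "__main__":
    linear = FakeBrick("linear")
    mlp = FakeBrick("mlp", children=[linear])
    top = find_top_bricks([linear, mlp, linear, mlp])
    check_unique_names(top)
    params = build_parameter_dict([FakeParameter("W", owner=linear),
                                   FakeParameter("bias")])
    print([brick.name for brick in top], list(params))
```

With the helpers in place, `__init__` itself would shrink to a few calls, and each helper can be tested on plain objects without building a Theano graph.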

"""Model - heavily annotated computation graph.

A model in Blocks is simply an annotated computation graph.  The class
:class:`Model` extends :class:`blocks.graph.ComputationGraph`,
which is able to handle annotations and roles in general, but is
deliberately made unaware of specific annotations that a Theano graph
created by Blocks typically has, such as bricks and application calls.  The
:class:`Model` adds this functionality. Using :class:`Model` you can do
things like query all the bricks used to build the computation graph and
request "hierarchical names" of the parameters (a hierarchical name is a
path-like string which, in addition to the parameter's name, contains the
names of the bricks on the path from a root brick to the brick that owns
the parameter, e.g. ``/mlp/linear.W``).

For more information, see the :class:`Model` docstring.

"""
import logging
from collections import OrderedDict, Counter
from itertools import chain

from blocks.graph import ComputationGraph
from blocks.filter import get_brick

logger = logging.getLogger(__name__)


class Model(ComputationGraph):
    """Handles annotations in Blocks-built computation graphs.

    Use this class to handle your Blocks-created computation graph.

    Examples
    --------
    >>> from theano import tensor
    >>> from blocks.bricks import MLP, Tanh
    >>> x = tensor.matrix('x')
    >>> mlp = MLP([Tanh(), Tanh()], [10, 10, 10])
    >>> y = mlp.apply(x)
    >>> model = Model(y)

    With :class:`Model` you can get access to the brick hierarchy. The
    brick hierarchy is defined by ``children`` attributes that every brick
    has.  The bricks that are not children of other bricks are called top
    bricks.  It is often useful to have access to top bricks of a brick
    hierarchy used to build a computation graph, and here is how you can do
    it:

    >>> model.get_top_bricks() #doctest: +ELLIPSIS
    [<blocks.bricks.sequences.MLP object at ...]

    You can also get "hierarchical" names for the parameters,
    which encode the position of the owning brick in the
    brick hierarchy.

    >>> model.get_parameter_dict() #doctest: +NORMALIZE_WHITESPACE
    OrderedDict([('/mlp/linear_1.b', b), ('/mlp/linear_0.b', b),
    ('/mlp/linear_0.W', W), ('/mlp/linear_1.W', W)])

    """
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs)
        bricks = [get_brick(var) for var
                  in self.variables + self.scan_variables if get_brick(var)]
        children = set(chain(*(brick.children for brick in bricks)))
        # Quadratic complexity: we should not have thousands of
        # top-level bricks.
        self.top_bricks = []
        for brick in bricks:
            if brick not in children and brick not in self.top_bricks:
                self.top_bricks.append(brick)
        names = Counter([brick.name for brick in self.top_bricks])
        repeated_names = [name for name, count in names.items() if count > 1]
        if repeated_names:
            raise ValueError("top bricks with the same name:"
                             " {}".format(', '.join(repeated_names)))
        parameter_list = []
        for parameter in self.parameters:
            if get_brick(parameter):
                parameter_list.append(
                    (get_brick(parameter).get_hierarchical_name(parameter),
                     parameter))
            else:
                parameter_list.append((parameter.name, parameter))
        self._parameter_dict = OrderedDict(parameter_list)

    def get_parameter_dict(self):
        """Return parameters with their hierarchical names.

        The parameter names are formed from the positions of their owner
        bricks in the brick hierarchy. The variable names are used for the
        parameters that do not belong to any brick.

        Returns
        -------
        parameter_dict : OrderedDict
            A dictionary of (hierarchical name, shared variable) pairs.

        """
        return self._parameter_dict

    def get_parameter_values(self):
        """Return the values of model parameters.

        The same hierarchical names as in :meth:`get_parameter_dict` are
        used to uniquely identify parameters.

        Returns
        -------
        parameter_values : OrderedDict
            Dictionary of (hierarchical name, :class:`~numpy.ndarray`)
            pairs.

        """
        return OrderedDict(
            (name, parameter.get_value())
            for name, parameter in self.get_parameter_dict().items())

    def set_parameter_values(self, parameter_values):
        """Set the values of model parameters.

        The same hierarchical names as in :meth:`get_parameter_dict` are
        used to uniquely identify parameters. Unknown or missing names are
        only logged as errors; a shape mismatch raises :class:`ValueError`,
        possibly after earlier parameters have already been updated.

        Parameters
        ----------
        parameter_values : OrderedDict
            Dictionary of (hierarchical name, :class:`~numpy.ndarray`)
            pairs.

        """
        parameters = self.get_parameter_dict()

        unknown = set(parameter_values) - set(parameters)
        missing = set(parameters) - set(parameter_values)
        if len(unknown):
            logger.error("unknown parameter names: {}\n".format(unknown))
        if len(missing):
            logger.error("missing values for parameters: {}\n".format(missing))

        for name, value in parameter_values.items():
            if name in parameters:
                model_shape = parameters[name].container.data.shape
                if model_shape != value.shape:
                    raise ValueError("Shape mismatch for parameter: {}. "
                                     "Expected {}, got {}."
                                     .format(name, model_shape, value.shape))
                parameters[name].set_value(value)

    def get_top_bricks(self):
        """Get the bricks that do not have parents.

        Returns
        -------
        bricks : list of :class:`~blocks.bricks.base.Brick`

        """
        return self.top_bricks
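`set_parameter_values()` logs unknown and missing names but raises on the first shape mismatch it meets, by which point parameters earlier in the loop may already have been overwritten. The sketch below shows one possible variation, not what Blocks does: check every shape first, then assign. It runs without Theano, so `FakeShared`, `get_values`, and `set_values` are hypothetical names, and the expected shape is read through `get_value().shape` rather than Theano's internal `container.data.shape`.

```python
import logging
from collections import OrderedDict

import numpy as np

logger = logging.getLogger(__name__)


class FakeShared:
    """Hypothetical stand-in for a Theano shared variable (value held as an ndarray)."""

    def __init__(self, value):
        self._value = np.asarray(value)

    def get_value(self):
        # Return a copy so callers cannot mutate the stored array in place.
        return self._value.copy()

    def set_value(self, value):
        self._value = np.asarray(value)


def get_values(parameters):
    """Snapshot {name: ndarray}, in the spirit of get_parameter_values()."""
    return OrderedDict((name, p.get_value()) for name, p in parameters.items())


def set_values(parameters, parameter_values):
    """Check every shape first, then assign (a variation on set_parameter_values)."""
    unknown = set(parameter_values) - set(parameters)
    missing = set(parameters) - set(parameter_values)
    if unknown:
        logger.error("unknown parameter names: %s", unknown)
    if missing:
        logger.error("missing values for parameters: %s", missing)

    known = [name for name in parameter_values if name in parameters]
    for name in known:
        expected = parameters[name].get_value().shape
        got = np.shape(parameter_values[name])
        if expected != got:
            raise ValueError("Shape mismatch for parameter: {}. "
                             "Expected {}, got {}.".format(name, expected, got))

    # Reached only if every known shape matched, so nothing is half-updated.
    for name in known:
        parameters[name].set_value(parameter_values[name])


if __name__ == "__main__":
    params = OrderedDict([("/mlp/linear_0.W", FakeShared(np.zeros((3, 3)))),
                          ("/mlp/linear_0.b", FakeShared(np.zeros(3)))])
    values = get_values(params)
    values["/mlp/linear_0.b"] = np.ones(3)
    set_values(params, values)
    print(params["/mlp/linear_0.b"].get_value())
```

The two-pass design costs one extra loop over the names, in exchange for leaving every parameter untouched when any supplied array has the wrong shape.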