Selector.select()   F
Complexity
    Conditions: 11

Size
    Total Lines: 42

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric  Value
cc      11
dl      0
loc     42
rs      3.1764
c       0
b       0
f       0

Complexity

Complex code units like Selector.select() often do many different things. To break such a unit down, first identify a cohesive component within its class. A common way to find one is to look for fields or methods that share the same prefix or suffix.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
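For a single method such as Selector.select(), the closest analogue of the advice above is Extract Method: pulling the nested matching loop into its own helper lowers the number of conditions in the caller. The following is a minimal sketch under stated assumptions, not the project's actual refactoring. `FakeBrick`, `match_children` and `select_bricks` are hypothetical names introduced here for illustration, and duplicates are detected by object identity rather than by the `not in` list test used in the listing.

```python
class FakeBrick(object):
    """Hypothetical stand-in for a brick: only a name and children."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


def match_children(parents, roots, name):
    """Return the children of `parents` called `name`, without duplicates.

    A parent of None stands for the top level, whose children are `roots`.
    Order of first appearance is preserved.
    """
    matches = []
    seen = set()  # ids of bricks already collected (identity-based dedup)
    for parent in parents:
        children = parent.children if parent is not None else roots
        for child in children:
            if child.name == name and id(child) not in seen:
                seen.add(id(child))
                matches.append(child)
    return matches


def select_bricks(roots, names):
    """Narrow the selection by one hierarchy level per brick name."""
    current = [None]
    for name in names:
        current = match_children(current, roots, name)
    return current


linear_0 = FakeBrick("linear_0")
linear_1 = FakeBrick("linear_1")
mlp = FakeBrick("mlp", [linear_0, linear_1])

selected = select_bricks([mlp], ["mlp", "linear_1"])
print([brick.name for brick in selected])
```

With the helper in place, the caller keeps a single loop and no nested conditions, and the set-based membership test avoids rescanning the list of matches for every candidate.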

import logging
import re
from collections import OrderedDict

from picklable_itertools.extras import equizip
import six

from blocks.bricks.base import Brick
from blocks.utils import dict_union

logger = logging.getLogger(__name__)

# Pylint issue (Coding Style, Naming): the name
# name_collision_error_message does not conform to the constant naming
# conventions ((([A-Z_][A-Z0-9_]*)|(__.*__))$). The regular expressions
# used by this check are configurable, and a project's Pylint
# configuration file takes precedence over the defaults.
name_collision_error_message = """

The '{}' name appears more than once. Make sure that all bricks' children \
have different names and that user-defined shared variables have unique names.
"""


class Path(object):
    """Encapsulates a path in a hierarchy of bricks.

    Currently the only allowed elements of paths are names of the bricks
    and names of parameters. The latter can only be put in the end of the
    path. It is planned to support regular expressions in some way later.

    Parameters
    ----------
    nodes : list or tuple of path nodes
        The nodes of the path.

    Attributes
    ----------
    nodes : tuple
        The tuple containing path nodes.

    """
    separator = "/"
    parameter_separator = "."
    separator_re = re.compile("([{}{}])".format(separator,
                                                parameter_separator))

    class BrickName(str):

        def part(self):
            return Path.separator + self

    class ParameterName(str):

        def part(self):
            return Path.parameter_separator + self

    def __init__(self, nodes):
        if not isinstance(nodes, (list, tuple)):
            raise ValueError
        self.nodes = tuple(nodes)

    def __str__(self):
        return "".join([node.part() for node in self.nodes])

    def __add__(self, other):
        return Path(self.nodes + other.nodes)

    def __eq__(self, other):
        return self.nodes == other.nodes

    def __hash__(self):
        return hash(self.nodes)

    @staticmethod
    def parse(string):
        """Constructs a path from its string representation.

        .. todo::

            More error checking.

        Parameters
        ----------
        string : str
            String representation of the path.

        """
        elements = Path.separator_re.split(string)[1:]
        separators = elements[::2]
        parts = elements[1::2]
        if not len(elements) == 2 * len(separators) == 2 * len(parts):
            raise ValueError

        nodes = []
        for separator, part in equizip(separators, parts):
            if separator == Path.separator:
                nodes.append(Path.BrickName(part))
            elif separator == Path.parameter_separator:
                nodes.append(Path.ParameterName(part))
            else:
                # This cannot happen if separator_re is a correct regexp
                raise ValueError("Wrong separator {}".format(separator))

        return Path(nodes)


class Selector(object):
    """Selection of elements of a hierarchy of bricks.

    Parameters
    ----------
    bricks : list of :class:`~.bricks.Brick`
        The bricks of the selection.

    """
    def __init__(self, bricks):
        if isinstance(bricks, Brick):
            bricks = [bricks]
        self.bricks = bricks

    def select(self, path):
        """Select a subset of current selection matching the path given.

        .. warning::

            Current implementation is very inefficient (theoretical
            complexity is :math:`O(n^3)`, where :math:`n` is the number
            of bricks in the hierarchy). It can be sped up easily.

        Parameters
        ----------
        path : :class:`Path` or str
            The path for the desired selection. If a string is given
            it is parsed into a path.

        Returns
        -------
        Depending on the path given, one of the following:

        * :class:`Selector` with desired bricks.
        * list of :class:`~tensor.SharedTensorVariable`.

        """
        if isinstance(path, six.string_types):
            path = Path.parse(path)

        # A leading None stands for the top level, whose children are
        # the bricks of this selection.
        current_bricks = [None]
        for node in path.nodes:
            next_bricks = []
            if isinstance(node, Path.ParameterName):
                # A parameter name ends the walk: return the parameters.
                return list(Selector(
                    current_bricks).get_parameters(node).values())
            if isinstance(node, Path.BrickName):
                for brick in current_bricks:
                    children = brick.children if brick else self.bricks
                    matching_bricks = [child for child in children
                                       if child.name == node]
                    # Keep each matching brick only once.
                    for match in matching_bricks:
                        if match not in next_bricks:
                            next_bricks.append(match)
            current_bricks = next_bricks
        return Selector(current_bricks)

    def get_parameters(self, parameter_name=None):
        r"""Returns parameters from selected bricks and their descendants.

        Parameters
        ----------
        parameter_name : :class:`Path.ParameterName`, optional
            If given, only parameters with a `name` attribute equal to
            `parameter_name` are returned.

        Returns
        -------
        parameters : OrderedDict
            A dictionary of (`path`, `parameter`) pairs, where `path` is
            a string representation of the path in the brick hierarchy
            to the parameter (i.e. the slash-delimited path to the brick
            that owns the parameter, followed by a dot, followed by the
            parameter's name), and `parameter` is the Theano variable
            representing the parameter.

        Examples
        --------
        >>> from blocks.bricks import MLP, Tanh
        >>> mlp = MLP([Tanh(), Tanh(), Tanh()], [5, 7, 11, 2])
        >>> mlp.allocate()
        >>> selector = Selector([mlp])
        >>> selector.get_parameters()  # doctest: +NORMALIZE_WHITESPACE
        OrderedDict([('/mlp/linear_0.W', W), ('/mlp/linear_0.b', b),
        ('/mlp/linear_1.W', W), ('/mlp/linear_1.b', b),
        ('/mlp/linear_2.W', W), ('/mlp/linear_2.b', b)])

        Or, select just the weights of the MLP by passing the parameter
        name `W`:

        >>> w_select = Selector([mlp])
        >>> w_select.get_parameters('W')  # doctest: +NORMALIZE_WHITESPACE
        OrderedDict([('/mlp/linear_0.W', W), ('/mlp/linear_1.W', W),
        ('/mlp/linear_2.W', W)])

        """
        def recursion(brick):
            # TODO path logic should be separate
            result = [
                (Path([Path.BrickName(brick.name),
                       Path.ParameterName(parameter.name)]),
                 parameter)
                for parameter in brick.parameters
                if not parameter_name or parameter.name == parameter_name]
            result = OrderedDict(result)
            for child in brick.children:
                for path, parameter in recursion(child).items():
                    new_path = Path([Path.BrickName(brick.name)]) + path
                    if new_path in result:
                        raise ValueError(
                            "Name collision encountered while retrieving " +
                            "parameters." +
                            name_collision_error_message.format(new_path))
                    result[new_path] = parameter
            return result
        result = dict_union(*[recursion(brick)
                            for brick in self.bricks])
        return OrderedDict((str(key), value) for key, value in result.items())
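
The separator handling in Path.parse can be tried in isolation, without the blocks package installed. The sketch below is illustrative only: `split_path` and the "brick"/"parameter" labels are hypothetical names that stand in for Path.BrickName and Path.ParameterName. It relies on the fact that `re.split` keeps the separators in its result when the pattern contains a capturing group.

```python
import re

SEPARATOR = "/"
PARAMETER_SEPARATOR = "."
# The capturing group makes re.split return the separators as well.
SEPARATOR_RE = re.compile("([{}{}])".format(SEPARATOR, PARAMETER_SEPARATOR))


def split_path(string):
    """Return (kind, name) pairs for a path such as '/mlp/linear_0.W'."""
    # Drop whatever precedes the first separator, as the listing does.
    elements = SEPARATOR_RE.split(string)[1:]
    separators = elements[::2]   # every even position is a separator
    parts = elements[1::2]       # every odd position is a node name
    return [("brick" if separator == SEPARATOR else "parameter", part)
            for separator, part in zip(separators, parts)]


pairs = split_path("/mlp/linear_0.W")
print(pairs)
```

Like the listing, this sketch silently discards any text before the first separator, which is one of the cases the "More error checking" todo in Path.parse would need to cover.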