1 | import logging |
||
2 | import re |
||
3 | from collections import OrderedDict |
||
4 | |||
5 | from picklable_itertools.extras import equizip |
||
6 | import six |
||
7 | |||
8 | from blocks.bricks.base import Brick |
||
9 | from blocks.utils import dict_union |
||
10 | |||
logger = logging.getLogger(__name__)

# Text appended to the ValueError raised by Selector.get_parameters when two
# parameters end up with the same path; ``{}`` receives the colliding path.
name_collision_error_message = (
    "\n"
    "\n"
    "The '{}' name appears more than once. Make sure that all bricks' "
    "children have different names and that user-defined shared variables "
    "have unique names.\n"
)
||
18 | |||
19 | |||
class Path(object):
    """Encapsulates a path in a hierarchy of bricks.

    Currently the only allowed elements of paths are names of the bricks
    and names of parameters. The latter can only be put in the end of the
    path. It is planned to support regular expressions in some way later.

    Parameters
    ----------
    nodes : list or tuple of path nodes
        The nodes of the path.

    Attributes
    ----------
    nodes : tuple
        The tuple containing path nodes.

    """
    separator = "/"
    parameter_separator = "."
    # The capturing group makes ``split`` keep the separators, which is
    # what tells brick nodes apart from parameter nodes in ``parse``.
    separator_re = re.compile("([{}{}])".format(separator,
                                                parameter_separator))

    class BrickName(str):
        """Path node naming a brick; rendered with a leading ``/``."""

        def part(self):
            return Path.separator + self

    class ParameterName(str):
        """Path node naming a parameter; rendered with a leading ``.``."""

        def part(self):
            return Path.parameter_separator + self

    def __init__(self, nodes):
        if not isinstance(nodes, (list, tuple)):
            raise ValueError(
                "Path nodes must be a list or a tuple, got {}".format(
                    type(nodes).__name__))
        self.nodes = tuple(nodes)

    def __str__(self):
        return "".join(node.part() for node in self.nodes)

    def __add__(self, other):
        return Path(self.nodes + other.nodes)

    def _key(self):
        # BrickName and ParameterName are both ``str`` subclasses, so
        # comparing the bare strings would make "/W" equal to ".W".
        # Pairing every node with its type keeps the two kinds distinct.
        return tuple((type(node), node) for node in self.nodes)

    def __eq__(self, other):
        if not isinstance(other, Path):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self._key())

    @staticmethod
    def parse(string):
        """Constructs a path from its string representation.

        .. todo::

            More error checking.

        Parameters
        ----------
        string : str
            String representation of the path, e.g. ``"/mlp/linear_0.W"``.
            It must be empty or start with a separator.

        Returns
        -------
        :class:`Path`
            The parsed path.

        Raises
        ------
        ValueError
            If `string` contains text before its first separator.

        """
        # "/a/b.W" -> ['', '/', 'a', '/', 'b', '.', 'W']: the leading item
        # is whatever precedes the first separator and must be empty.
        elements = Path.separator_re.split(string)
        if elements[0]:
            raise ValueError(
                "Path '{}' must start with '{}' or '{}'".format(
                    string, Path.separator, Path.parameter_separator))
        elements = elements[1:]
        separators = elements[::2]
        parts = elements[1::2]
        if not len(elements) == 2 * len(separators) == 2 * len(parts):
            raise ValueError("Malformed path '{}'".format(string))

        nodes = []
        for separator, part in zip(separators, parts):
            if separator == Path.separator:
                nodes.append(Path.BrickName(part))
            elif separator == Path.parameter_separator:
                nodes.append(Path.ParameterName(part))
            else:
                # This can not happen if separator_re is a correct regexp.
                raise ValueError("Wrong separator {}".format(separator))

        return Path(nodes)
||
101 | |||
102 | |||
class Selector(object):
    """Selection of elements of a hierarchy of bricks.

    Parameters
    ----------
    bricks : list of :class:`~.bricks.Brick`
        The bricks of the selection. A single :class:`~.bricks.Brick` is
        accepted as well and is wrapped in a list.

    """
    def __init__(self, bricks):
        if isinstance(bricks, Brick):
            bricks = [bricks]
        self.bricks = bricks

    def select(self, path):
        """Select a subset of current selection matching the path given.

        .. warning::

            Current implementation is very inefficient (theoretical
            complexity is :math:`O(n^3)`, where :math:`n` is the number
            of bricks in the hierarchy). It can be sped up easily.

        Parameters
        ----------
        path : :class:`Path` or str
            The path for the desired selection. If a string is given
            it is parsed into a path.

        Returns
        -------
        Depending on the path given, one of the following:

        * :class:`Selector` with desired bricks. An empty path selects
          the current bricks.
        * list of :class:`~tensor.SharedTensorVariable`, when the path
          reaches a parameter name. The parameters are collected (via
          :meth:`get_parameters`) from the bricks matched so far and their
          descendants; a path starting with a parameter name collects them
          from the current selection.

        """
        if isinstance(path, six.string_types):
            path = Path.parse(path)

        # ``None`` stands for the virtual root above the selection, whose
        # children are ``self.bricks``.
        current_bricks = None
        for node in path.nodes:
            if isinstance(node, Path.ParameterName):
                owners = (self.bricks if current_bricks is None
                          else current_bricks)
                return list(Selector(owners).get_parameters(node).values())
            next_bricks = []
            if isinstance(node, Path.BrickName):
                if current_bricks is None:
                    candidates = self.bricks
                else:
                    candidates = [child for brick in current_bricks
                                  for child in brick.children]
                for candidate in candidates:
                    # Keep each matching brick once even if several of the
                    # current bricks list it as a child.
                    if (candidate.name == node and
                            candidate not in next_bricks):
                        next_bricks.append(candidate)
            current_bricks = next_bricks
        if current_bricks is None:
            current_bricks = self.bricks
        return Selector(current_bricks)

    def get_parameters(self, parameter_name=None):
        r"""Returns parameters from selected bricks and their descendants.

        Parameters
        ----------
        parameter_name : :class:`Path.ParameterName`, optional
            If given, only parameters with a `name` attribute equal to
            `parameter_name` are returned.

        Returns
        -------
        parameters : OrderedDict
            A dictionary of (`path`, `parameter`) pairs, where `path` is
            a string representation of the path in the brick hierarchy
            to the parameter (i.e. the slash-delimited path to the brick
            that owns the parameter, followed by a dot, followed by the
            parameter's name), and `parameter` is the Theano variable
            representing the parameter.

        Raises
        ------
        ValueError
            If two parameters resolve to the same path, e.g. two children
            of a brick share a name or a brick has two parameters with the
            same name.

        Examples
        --------
        >>> from blocks.bricks import MLP, Tanh
        >>> mlp = MLP([Tanh(), Tanh(), Tanh()], [5, 7, 11, 2])
        >>> mlp.allocate()
        >>> selector = Selector([mlp])
        >>> selector.get_parameters()  # doctest: +NORMALIZE_WHITESPACE
        OrderedDict([('/mlp/linear_0.W', W), ('/mlp/linear_0.b', b),
                     ('/mlp/linear_1.W', W), ('/mlp/linear_1.b', b),
                     ('/mlp/linear_2.W', W), ('/mlp/linear_2.b', b)])

        Or, select just the weights of the MLP by passing the parameter
        name `W`:

        >>> w_select = Selector([mlp])
        >>> w_select.get_parameters('W')  # doctest: +NORMALIZE_WHITESPACE
        OrderedDict([('/mlp/linear_0.W', W), ('/mlp/linear_1.W', W),
                     ('/mlp/linear_2.W', W)])

        """
        def recursion(brick):
            # TODO path logic should be separate
            result = OrderedDict()
            brick_path = Path([Path.BrickName(brick.name)])

            def add(path, parameter):
                # Every path must be unique; a silent overwrite would drop
                # a parameter from the result.
                if path in result:
                    raise ValueError(
                        "Name collision encountered while retrieving " +
                        "parameters." +
                        name_collision_error_message.format(path))
                result[path] = parameter

            for parameter in brick.parameters:
                if not parameter_name or parameter.name == parameter_name:
                    add(brick_path +
                        Path([Path.ParameterName(parameter.name)]),
                        parameter)
            for child in brick.children:
                for path, parameter in recursion(child).items():
                    add(brick_path + path, parameter)
            return result
        result = dict_union(*[recursion(brick)
                              for brick in self.bricks])
        return OrderedDict((str(key), value) for key, value in result.items())
||
221 |
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.