1
|
|
|
from __future__ import print_function |
|
|
|
|
2
|
|
|
import sys |
3
|
|
|
import contextlib |
4
|
|
|
from collections import OrderedDict, deque |
5
|
|
|
|
6
|
|
|
import numpy |
7
|
|
|
import six |
8
|
|
|
import theano |
9
|
|
|
from theano import tensor |
10
|
|
|
from theano import printing |
11
|
|
|
from theano.gof.graph import Constant |
12
|
|
|
from theano.tensor.shared_randomstreams import RandomStateSharedVariable |
13
|
|
|
from theano.tensor.sharedvar import SharedVariable |
14
|
|
|
|
15
|
|
|
|
16
|
|
|
def pack(arg):
    """Wrap `arg` in a list unless it already is a sequence.

    Parameters
    ----------
    arg : object
        A list, a tuple, or any other Python object. Lists and tuples are
        copied into a new list; anything else becomes the sole element of
        a one-item list.

    Returns
    -------
    list
        A fresh list holding the argument(s).

    """
    return list(arg) if isinstance(arg, (list, tuple)) else [arg]
36
|
|
|
|
37
|
|
|
|
38
|
|
|
def unpack(arg, singleton=False):
    """Unwrap a one-element list or tuple; pass everything else through.

    Parameters
    ----------
    arg : object
        A list, a tuple, or any other Python object. A sequence of length
        one yields its only element. A sequence of any other length
        (including zero) is returned as a new list. Non-sequences are
        returned unchanged.
    singleton : bool
        When ``True``, a list/tuple argument must have exactly one element;
        otherwise a :class:`ValueError` is raised. ``False`` by default.

    Returns
    -------
    object
        The single element, a list, or `arg` itself; never a tuple built
        from a list/tuple argument.

    Raises
    ------
    ValueError
        If `singleton` is ``True`` and a list/tuple of length other than
        one is passed.

    """
    if not isinstance(arg, (list, tuple)):
        return arg
    if len(arg) == 1:
        return arg[0]
    if singleton:
        raise ValueError("Expected a singleton, got {}".format(arg))
    return list(arg)
71
|
|
|
|
72
|
|
|
|
73
|
|
|
def shared_floatx_zeros_matching(shared_variable, name=None, **kwargs):
    r"""Create another shared variable with matching shape and broadcast.

    Parameters
    ----------
    shared_variable : :class:`tensor.TensorSharedVariable`
        A Theano shared variable with the desired shape and broadcastable
        flags.
    name : :obj:`str`, optional
        The name for the shared variable. Defaults to `None`.
    \*\*kwargs
        Keyword arguments to pass to the :func:`shared_floatx_zeros`
        function.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A new shared variable, initialized to all zeros, with the same
        shape and broadcastable flags as `shared_variable`.

    Raises
    ------
    ValueError
        If :func:`is_shared_variable` rejects `shared_variable`.

    """
    if not is_shared_variable(shared_variable):
        raise ValueError('argument must be a shared variable')
    # Only the shape is needed, so let Theano skip the defensive copy of
    # the (possibly large) stored value.
    shape = shared_variable.get_value(borrow=True).shape
    return shared_floatx_zeros(shape,
                               name=name,
                               broadcastable=shared_variable.broadcastable,
                               **kwargs)
101
|
|
|
|
102
|
|
|
|
103
|
|
|
def shared_floatx_zeros(shape, **kwargs):
    r"""Creates a shared variable array filled with zeros.

    Parameters
    ----------
    shape : tuple
        A tuple of integers representing the shape of the array.
    \*\*kwargs
        Keyword arguments to pass to the :func:`shared_floatx` function
        (e.g. `name`, `dtype`, `broadcastable`).

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A Theano shared variable filled with zeros.

    """
    # The zeros are created with NumPy's default dtype; shared_floatx casts
    # them to the requested dtype (floatX unless overridden in kwargs).
    return shared_floatx(numpy.zeros(shape), **kwargs)
120
|
|
|
|
121
|
|
|
|
122
|
|
|
def shared_floatx_nans(shape, **kwargs):
    r"""Creates a shared variable array filled with nans.

    Parameters
    ----------
    shape : tuple
        A tuple of integers representing the shape of the array.
    \*\*kwargs
        Keyword arguments to pass to the :func:`shared_floatx` function.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A Theano shared variable filled with nans.

    """
    # Fill directly with NaN instead of multiplying an array of zeros.
    return shared_floatx(numpy.full(shape, numpy.nan), **kwargs)
139
|
|
|
|
140
|
|
|
|
141
|
|
|
def shared_floatx(value, name=None, borrow=False, dtype=None, **kwargs):
    r"""Wrap `value` in a Theano shared variable, cast to floatX by default.

    Parameters
    ----------
    value : :class:`~numpy.ndarray`
        The value stored in the new shared variable. It is converted with
        ``theano._asarray`` to the target dtype first.
    name : :obj:`str`, optional
        The name for the shared variable. Defaults to `None`.
    borrow : :obj:`bool`, optional
        Forwarded to :func:`~theano.shared`; when True the converted array
        may be used without copying. Defaults to False.
    dtype : :obj:`str`, optional
        Target dtype. When `None` (default), ``theano.config.floatX`` is
        used.
    \*\*kwargs
        Extra keyword arguments forwarded to :func:`~theano.shared`.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A Theano shared variable with the requested value and `dtype`.

    """
    target_dtype = theano.config.floatX if dtype is None else dtype
    converted = theano._asarray(value, dtype=target_dtype)
    return theano.shared(converted, name=name, borrow=borrow, **kwargs)
169
|
|
|
|
170
|
|
|
|
171
|
|
|
def shared_like(variable, name=None, **kwargs):
    r"""Build an empty shared variable able to hold values of `variable`.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`
        Anything accepted by ``tensor.as_tensor_variable``; its `dtype` and
        `ndim` determine those of the new shared variable.
    name : :obj:`str` or :obj:`None`
        Name of the shared variable. When None it is derived as
        ``"shared_<variable.name>"`` (which yields ``"shared_None"`` for an
        unnamed variable).
    \*\*kwargs
        Keyword arguments to pass to the :func:`~theano.shared` function.

    Returns
    -------
    shared variable
        A shared variable initialised with an array whose every dimension
        has length zero.

    """
    tensor_var = tensor.as_tensor_variable(variable)
    if name is None:
        name = "shared_{}".format(tensor_var.name)
    empty_shape = (0,) * tensor_var.ndim
    placeholder = numpy.zeros(empty_shape, dtype=tensor_var.dtype)
    return theano.shared(placeholder, name=name, **kwargs)
192
|
|
|
|
193
|
|
|
|
194
|
|
|
def reraise_as(new_exc):
    """Re-raise the exception currently being handled under a new guise.

    Must be called from inside an ``except`` block: the active exception is
    read from :func:`sys.exc_info`, its traceback is preserved, and its
    type and arguments are appended to the new exception's message.

    Parameters
    ----------
    new_exc : :class:`Exception` or :obj:`str`
        The exception to raise instead, e.g. ``ValueError("New message")``.
        A string is turned into an instance of the *original* exception
        type with that string as its message.

    Notes
    -----
    The arguments of both exceptions are converted to strings and merged
    into the first argument of `new_exc`. To keep the original arguments as
    separate arguments, pass them explicitly:

    >>> try:
    ...     1 / 0
    ... except Exception as e:
    ...     reraise_as(Exception("Extra information", *e.args))
    Traceback (most recent call last):
      ...
    Exception: 'Extra information, ...

    The raised exception gets ``__cause__`` set to the original exception
    and a ``reraised = True`` attribute. When an exception marked this way
    is re-raised again, only its first argument (the already merged
    message) is appended.

    Examples
    --------
    >>> class NewException(Exception):
    ...     def __init__(self, message):
    ...         super(NewException, self).__init__(message)
    >>> try:
    ...     do_something_crazy()
    ... except Exception:
    ...     reraise_as(NewException("Informative message"))
    Traceback (most recent call last):
      ...
    NewException: Informative message ...

    """
    exc_type, exc_value, exc_tb = sys.exc_info()

    if isinstance(new_exc, six.string_types):
        new_exc = exc_type(new_exc)

    if hasattr(new_exc, 'args'):
        # Fold every argument into the message so none of it is lost if
        # this exception gets re-raised once more. Empty args give "".
        message = ', '.join(str(arg) for arg in new_exc.args)
        message += '\n\nOriginal exception:\n\t' + exc_type.__name__
        orig_args = getattr(exc_value, 'args', ())
        if len(orig_args) > 0:
            if getattr(exc_value, 'reraised', False):
                # Already merged by an earlier call: its first argument
                # carries the full history.
                detail = str(orig_args[0])
            else:
                detail = ', '.join(str(arg) for arg in orig_args)
            message += ': ' + detail
        new_exc.args = (message,) + new_exc.args[1:]

    new_exc.__cause__ = exc_value
    new_exc.reraised = True
    six.reraise(type(new_exc), new_exc, exc_tb)
259
|
|
|
|
260
|
|
|
|
261
|
|
|
def check_theano_variable(variable, n_dim, dtype_prefix):
    """Check number of dimensions and dtype of a Theano variable.

    If the input is not a Theano variable, it is converted to one. `None`
    input is handled as a special case: no checks are done.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable` or convertible to one
        A variable to check.
    n_dim : int or None
        Expected number of dimensions. ``0`` requests a scalar. If None,
        no check is performed.
    dtype_prefix : str or None
        Expected dtype prefix (e.g. ``'float'`` or ``'int'``). If None, no
        check is performed.

    Raises
    ------
    ValueError
        If the number of dimensions or the dtype prefix does not match.

    """
    if variable is None:
        return

    if not isinstance(variable, tensor.Variable):
        variable = tensor.as_tensor_variable(variable)

    # Compare with None explicitly: n_dim == 0 is a legitimate expectation
    # (a scalar) and must not be treated as "skip the check".
    if n_dim is not None and variable.ndim != n_dim:
        raise ValueError("Wrong number of dimensions:"
                         "\n\texpected {}, got {}".format(
                             n_dim, variable.ndim))

    if dtype_prefix and not variable.dtype.startswith(dtype_prefix):
        raise ValueError("Wrong dtype prefix:"
                         "\n\texpected starting with {}, got {}".format(
                             dtype_prefix, variable.dtype))
293
|
|
|
|
294
|
|
|
|
295
|
|
|
def is_graph_input(variable):
    """Tell whether `variable` is an input the user must supply.

    A user-provided input has no owner (it is not produced by any apply
    node) and is neither a shared variable nor a constant.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`

    Returns
    -------
    bool
        ``True`` if the variable is a user-provided input to the graph.

    """
    if variable.owner:
        return False
    return not isinstance(variable, (SharedVariable, Constant))
314
|
|
|
|
315
|
|
|
|
316
|
|
|
def is_shared_variable(variable):
    """Tell whether `variable` is a Theano shared variable.

    Notes
    -----
    Shared variables holding random number generator state are excluded:
    both instances of ``RandomStateSharedVariable`` and any shared variable
    whose ``tag`` carries an ``is_rng`` attribute.

    """
    if not isinstance(variable, SharedVariable):
        return False
    if isinstance(variable, RandomStateSharedVariable):
        return False
    return not hasattr(variable.tag, 'is_rng')
328
|
|
|
|
329
|
|
|
|
330
|
|
|
def dict_subset(dict_, keys, pop=False, must_have=True):
    """Return a subset of a dictionary corresponding to a set of keys.

    Parameters
    ----------
    dict_ : dict
        The dictionary.
    keys : iterable
        The keys of interest.
    pop : bool
        If ``True``, the pairs corresponding to the keys of interest are
        popped from the dictionary.
    must_have : bool
        If ``True``, a :class:`KeyError` is raised when trying to retrieve
        a key not present in the dictionary. If ``False``, missing keys are
        silently left out of the result.

    Returns
    -------
    result : ``OrderedDict``
        An ordered dictionary of retrieved pairs. The order is the same as
        in the ``keys`` argument.

    Raises
    ------
    KeyError
        If `must_have` is ``True`` and one of `keys` is missing. When `pop`
        is also ``True``, keys processed before the missing one have
        already been removed from `dict_`.

    """
    # Private sentinel so that keys mapped to ``None`` are still returned.
    not_found = object()

    def extract(k):
        if pop:
            if must_have:
                return dict_.pop(k)
            return dict_.pop(k, not_found)
        if must_have:
            return dict_[k]
        return dict_.get(k, not_found)

    result = [(key, extract(key)) for key in keys]
    return OrderedDict((k, v) for k, v in result if v is not not_found)
366
|
|
|
|
367
|
|
|
|
368
|
|
|
def dict_union(*dicts, **kwargs):
    r"""Merge disjoint dictionaries (and keyword arguments) into one.

    Parameters
    ----------
    dicts : dicts
        Dictionaries that must not share any key. The result is an
        `OrderedDict` when the first of them is an `OrderedDict`, and a
        plain dict otherwise.
    \*\*kwargs
        Additional key/value pairs, merged last.

    Raises
    ------
    ValueError
        If a key appears more than once across the dictionaries and
        keyword arguments.

    """
    sources = list(dicts)
    if sources and isinstance(sources[0], OrderedDict):
        merged = OrderedDict()
    else:
        merged = {}
    for mapping in sources + [kwargs]:
        clashes = set(merged.keys()) & set(mapping.keys())
        if clashes:
            raise ValueError("The following keys have duplicate entries: {}"
                             .format(", ".join(str(key) for key in
                                               clashes)))
        merged.update(mapping)
    return merged
399
|
|
|
|
400
|
|
|
|
401
|
|
|
def repr_attrs(instance, *attrs):
    r"""Prints a representation of an object with certain attributes.

    Parameters
    ----------
    instance : object
        The object of which to print the string representation
    \*attrs
        Names of attributes that should be printed. May be empty, in which
        case the plain default-style representation is returned.

    Returns
    -------
    str
        ``<module.Class object at 0x...: attr=value, ...>``. If any listed
        attribute cannot be read or formatted, the representation without
        attributes is returned instead.

    Examples
    --------
    >>> class A(object):
    ...     def __init__(self, value):
    ...         self.value = value
    >>> a = A('a_value')
    >>> repr(a) # doctest: +SKIP
    <blocks.utils.A object at 0x7fb2b4741a10>
    >>> repr_attrs(a, 'value') # doctest: +SKIP
    <blocks.utils.A object at 0x7fb2b4741a10: value=a_value>

    """
    orig_repr_template = ("<{0.__class__.__module__}.{0.__class__.__name__} "
                          "object at {1:#x}")
    # Start from the plain template so a call without attributes is valid.
    repr_template = orig_repr_template
    if attrs:
        # Each attribute becomes "name={0.name}", looked up on the instance
        # (argument 0) at format time.
        repr_template += ": " + ", ".join("{0}={{0.{0}}}".format(attr)
                                          for attr in attrs)
    repr_template += '>'
    orig_repr_template += '>'
    try:
        return repr_template.format(instance, id(instance))
    except Exception:
        # Best effort: a missing or unformattable attribute must not make
        # repr() itself fail.
        return orig_repr_template.format(instance, id(instance))
435
|
|
|
|
436
|
|
|
|
437
|
|
|
def put_hook(variable, hook_fn, *args):
    r"""Attach a hook to a Theano variable.

    The variable is wrapped in a ``theano.printing.Print`` op whose
    ``global_fn`` calls `hook_fn` with the variable's computed value.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`
        The variable to put a hook on.
    hook_fn : function
        The hook function. It is called as ``hook_fn(value, *args)``, where
        `value` is the variable's computed value.
    \*args : list
        Extra positional arguments passed to `hook_fn` after the value.

    Returns
    -------
    The output of the ``Print`` op applied to `variable`. Presumably this
    returned variable must be used in place of `variable` for the hook to
    run (Theano Print-op semantics; confirm against the Theano docs).

    """
    def _forward_to_hook(_op, value):
        # Print passes (op, value); only the value is handed to the hook.
        return hook_fn(value, *args)

    print_op = printing.Print(global_fn=_forward_to_hook)
    return print_op(variable)
455
|
|
|
|
456
|
|
|
|
457
|
|
|
def ipdb_breakpoint(x):
    """Hook function for :func:`put_hook` that drops into ipdb.

    `ipdb` is imported lazily, so it is only required when this hook
    actually fires.

    Parameters
    ----------
    x : :class:`~numpy.ndarray`
        The value of the hooked variable. It is not used directly but can
        be inspected from the debugger.

    """
    from ipdb import set_trace
    set_trace()
468
|
|
|
|
469
|
|
|
|
470
|
|
|
def print_sum(x, header=None):
    """Print ``<header>: <x.sum()>`` to stdout.

    Usable as a `hook_fn` for :func:`put_hook`. An empty or missing
    `header` falls back to ``'print_sum'``.

    """
    label = header if header else 'print_sum'
    print(label + ':', x.sum())
474
|
|
|
|
475
|
|
|
|
476
|
|
|
def print_shape(x, header=None):
    """Print ``<header>: <x.shape>`` to stdout.

    Usable as a `hook_fn` for :func:`put_hook`. An empty or missing
    `header` falls back to ``'print_shape'``.

    """
    label = header if header else 'print_shape'
    print(label + ':', x.shape)
480
|
|
|
|
481
|
|
|
|
482
|
|
|
@contextlib.contextmanager
def change_recursion_limit(limit):
    """Temporarily raise the recursion limit.

    The limit is only changed if `limit` exceeds the current one; a lower
    value leaves it untouched. On exit, the limit in effect at entry is
    restored, even if the body of the ``with`` statement raised.

    Parameters
    ----------
    limit : int
        The minimum recursion limit to use inside the ``with`` block.

    """
    old_limit = sys.getrecursionlimit()
    if old_limit < limit:
        sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        # Restore in all cases so an exception inside the block does not
        # leave the interpreter with a permanently altered limit.
        sys.setrecursionlimit(old_limit)
490
|
|
|
|
491
|
|
|
|
492
|
|
|
def extract_args(expected, *args, **kwargs):
    r"""Route keyword and positional arguments to a list of names.

    Useful for methods whose signature must be ``*args, **kwargs`` because
    the accepted argument names are only known at runtime (once an
    instance exists). This helper performs the validation Python would
    normally do for a fixed signature.

    Parameters
    ----------
    expected : list of str
        Names of the expected arguments, in order.
    args : iterable
        Positional arguments that have been passed. They are matched to
        `expected` in order; any beyond ``len(expected)`` are ignored.
    kwargs : Mapping
        Keyword arguments that have been passed.

    Returns
    -------
    routed_args : OrderedDict
        Maps each name in `expected`, in order, to its value taken from
        either `args` or `kwargs`.

    Raises
    ------
    KeyError
        If a keyword argument is not one of `expected`.
    TypeError
        If an argument is given both positionally and by keyword.
    ValueError
        If some names in `expected` receive no value.

    """
    # zip() (not an equal-length zip) on purpose: it truncates to the
    # shorter of `expected` and `args`.
    routed = {}
    for arg_name, arg_value in zip(expected, args):
        routed[arg_name] = arg_value

    for arg_name, arg_value in kwargs.items():
        if arg_name not in expected:
            raise KeyError('invalid input name: {}'.format(arg_name))
        if arg_name in routed:
            raise TypeError("got multiple values for "
                            "argument '{}'".format(arg_name))
        routed[arg_name] = arg_value

    missing = [arg_name for arg_name in expected if arg_name not in routed]
    if missing:
        raise ValueError('missing values for inputs: {}'.format(missing))
    return OrderedDict((arg_name, routed[arg_name]) for arg_name in expected)
550
|
|
|
|
551
|
|
|
|
552
|
|
|
def find_bricks(top_bricks, predicate):
    """Collect bricks in the hierarchy that satisfy `predicate`.

    The hierarchy is traversed breadth-first through each brick's
    ``children`` attribute. Every brick is visited at most once, so shared
    children appear a single time in the result.

    Parameters
    ----------
    top_bricks : list
        Root bricks to start the search from.
    predicate : callable
        Called once per visited brick; truthy results select the brick.

    Returns
    -------
    found : list
        All bricks reachable from `top_bricks` (roots included) for which
        `predicate` holds, in breadth-first visiting order.

    """
    matches = []
    seen = set()
    queue = deque(top_bricks)
    while queue:
        brick = queue.popleft()
        if brick in seen:
            continue
        seen.add(brick)
        if predicate(brick):
            matches.append(brick)
        queue.extend(brick.children)
    return matches
581
|
|
|
|
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.