Completed
Pull Request — master (#941)
by David
20:04 queued 10:23
created

blocks.utils.find_bricks()   B

Complexity

Conditions 4

Size

Total Lines 29

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 4
dl 0
loc 29
rs 8.5806
1
from __future__ import print_function
2
import sys
3
import contextlib
4
from collections import OrderedDict, deque
5
6
import numpy
7
import six
8
import theano
9
from theano import tensor
10
from theano import printing
11
from theano.gof.graph import Constant
12
from theano.tensor.shared_randomstreams import RandomStateSharedVariable
13
from theano.tensor.sharedvar import SharedVariable
14
15
16
def pack(arg):
    """Wrap `arg` in a list unless it already is a list or tuple.

    Parameters
    ----------
    arg : object
        A list, a tuple, or any other Python object.

    Returns
    -------
    list
        A new list holding the elements of `arg` when `arg` is a list or
        tuple (lists are shallow-copied, tuples converted); otherwise a
        one-element list containing `arg` itself.

    """
    if not isinstance(arg, (list, tuple)):
        return [arg]
    return list(arg)
36
37
38
def unpack(arg, singleton=False):
    """Extract the only element of a one-element list or tuple.

    Parameters
    ----------
    arg : object
        A list, a tuple, or any other Python object.
    singleton : bool, optional
        When ``True``, a list or tuple argument must hold exactly one
        element and a :class:`ValueError` is raised otherwise.
        Non-sequence arguments are still returned unchanged. ``False`` by
        default.

    Returns
    -------
    object
        The single element if `arg` is a list or tuple of length one; a
        new list with the elements of `arg` if it is a list or tuple of
        any other length (including zero); `arg` itself otherwise.

    Raises
    ------
    ValueError
        If `singleton` is ``True`` and `arg` is a list or tuple whose
        length is not one.

    """
    if not isinstance(arg, (list, tuple)):
        return arg
    if len(arg) == 1:
        return arg[0]
    if singleton:
        raise ValueError("Expected a singleton, got {}".format(arg))
    return list(arg)
71
72
73
def shared_floatx_zeros(shape, **kwargs):
    r"""Create a floatX shared variable initialised with zeros.

    Parameters
    ----------
    shape : tuple
        The shape of the array, as a tuple of integers.
    \*\*kwargs
        Forwarded unchanged to :func:`shared_floatx`.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A Theano shared variable whose value is all zeros.

    """
    zeros = numpy.zeros(shape)
    return shared_floatx(zeros, **kwargs)
90
91
92
def shared_floatx_nans(shape, **kwargs):
    r"""Create a floatX shared variable initialised with NaNs.

    Parameters
    ----------
    shape : tuple
        The shape of the array, as a tuple of integers.
    \*\*kwargs
        Forwarded unchanged to :func:`shared_floatx`.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A Theano shared variable whose value is all NaNs.

    """
    # A float64 array of NaNs; shared_floatx casts it to the target dtype.
    nans = numpy.full(shape, numpy.nan)
    return shared_floatx(nans, **kwargs)
109
110
111
def shared_floatx(value, name=None, borrow=False, dtype=None, **kwargs):
    r"""Wrap `value` in a Theano shared variable of type floatX.

    Parameters
    ----------
    value : :class:`~numpy.ndarray`
        The initial value of the shared variable.
    name : :obj:`str`, optional
        Name given to the shared variable. Defaults to `None`.
    borrow : :obj:`bool`, optional
        When True, avoid copying `value` where possible, saving memory and
        time. Defaults to False.
    dtype : :obj:`str`, optional
        Target `dtype` for the shared variable. When omitted,
        :attr:`config.floatX` is used.
    \*\*kwargs
        Extra keyword arguments forwarded to :func:`~theano.shared`.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A Theano shared variable holding `value` cast to `dtype`.

    """
    target_dtype = theano.config.floatX if dtype is None else dtype
    array = theano._asarray(value, dtype=target_dtype)
    return theano.shared(array, name=name, borrow=borrow, **kwargs)
139
140
141
def shared_like(variable, name=None, **kwargs):
    r"""Create an empty shared variable matching a tensor's dtype and ndim.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`
        The variable (or anything convertible to one) whose `dtype` and
        `ndim` define the new shared variable.
    name : :obj:`str` or :obj:`None`
        Name of the shared variable. When None, it is built as
        ``"shared_<variable.name>"``.
    \*\*kwargs
        Extra keyword arguments forwarded to :func:`~theano.shared`.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A shared variable whose initial value is an empty array (every
        dimension of size 0) with the same `ndim` and `dtype`.

    """
    tensor_var = tensor.as_tensor_variable(variable)
    if name is not None:
        shared_name = name
    else:
        shared_name = "shared_{}".format(tensor_var.name)
    empty_value = numpy.zeros((0,) * tensor_var.ndim, dtype=tensor_var.dtype)
    return theano.shared(empty_value, name=shared_name, **kwargs)
162
163
164
def reraise_as(new_exc):
    """Reraise an exception as a different type or with a message.

    This function ensures that the original traceback is kept, making for
    easier debugging. It is meant to be called from inside an ``except``
    block, because it reads the exception currently being handled through
    :func:`sys.exc_info`.

    Parameters
    ----------
    new_exc : :class:`Exception` or :obj:`str`
        The new error to be raised e.g. (ValueError("New message"))
        or a string that will be prepended to the original exception
        message. A string is wrapped in an exception of the same type as
        the original exception.

    Raises
    ------
    Exception
        Always raises `new_exc` (or the exception built from the string),
        using the traceback of the original exception.

    Notes
    -----
    Note that when reraising exceptions, the arguments of the original
    exception are cast to strings and appended to the error message. If
    you want to retain the original exception arguments, please use:

    >>> try:
    ...     1 / 0
    ... except Exception as e:
    ...     reraise_as(Exception("Extra information", *e.args))
    Traceback (most recent call last):
      ...
    Exception: 'Extra information, ...

    Examples
    --------
    >>> class NewException(Exception):
    ...     def __init__(self, message):
    ...         super(NewException, self).__init__(message)
    >>> try:
    ...     do_something_crazy()
    ... except Exception:
    ...     reraise_as(NewException("Informative message"))
    Traceback (most recent call last):
      ...
    NewException: Informative message ...

    """
    orig_exc_type, orig_exc_value, orig_exc_traceback = sys.exc_info()

    # A plain string becomes an exception of the original exception's type.
    if isinstance(new_exc, six.string_types):
        new_exc = orig_exc_type(new_exc)

    if hasattr(new_exc, 'args'):
        if len(new_exc.args) > 0:
            # We add all the arguments to the message, to make sure that this
            # information isn't lost if this exception is reraised again
            new_message = ', '.join(str(arg) for arg in new_exc.args)
        else:
            new_message = ""
        new_message += '\n\nOriginal exception:\n\t' + orig_exc_type.__name__
        if hasattr(orig_exc_value, 'args') and len(orig_exc_value.args) > 0:
            # An exception produced by an earlier reraise_as call already
            # holds the whole message chain in args[0] (args are rewritten
            # just below), so only that first argument is appended.
            if getattr(orig_exc_value, 'reraised', False):
                new_message += ': ' + str(orig_exc_value.args[0])
            else:
                new_message += ': ' + ', '.join(str(arg)
                                                for arg in orig_exc_value.args)
        # The merged message replaces args[0]; any further args are kept.
        new_exc.args = (new_message,) + new_exc.args[1:]

    # Record the original exception as the explicit cause.
    new_exc.__cause__ = orig_exc_value
    # Flag read by a later reraise_as call (see the `reraised` check above).
    new_exc.reraised = True
    # six.reraise raises new_exc with the original traceback attached.
    six.reraise(type(new_exc), new_exc, orig_exc_traceback)
229
230
231
def check_theano_variable(variable, n_dim, dtype_prefix):
    """Check number of dimensions and dtype of a Theano variable.

    If the input is not a Theano variable, it is converted to one. `None`
    input is handled as a special case: no checks are done.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable` or convertible to one
        A variable to check.
    n_dim : int
        Expected number of dimensions or None. If None, no check is
        performed. ``0`` (a scalar) is checked like any other value.
    dtype_prefix : str
        Expected prefix of the variable's dtype string (e.g. ``'float'``)
        or None. If None (or empty), no check is performed.

    Raises
    ------
    ValueError
        If the number of dimensions or the dtype prefix does not match.

    """
    if variable is None:
        return

    if not isinstance(variable, tensor.Variable):
        variable = tensor.as_tensor_variable(variable)

    # Compare against None explicitly: a truthiness test would silently
    # skip the check when a scalar (n_dim == 0) is expected.
    if n_dim is not None and variable.ndim != n_dim:
        raise ValueError("Wrong number of dimensions:"
                         "\n\texpected {}, got {}".format(
                             n_dim, variable.ndim))

    if dtype_prefix and not variable.dtype.startswith(dtype_prefix):
        raise ValueError("Wrong dtype prefix:"
                         "\n\texpected starting with {}, got {}".format(
                             dtype_prefix, variable.dtype))
263
264
265
def is_graph_input(variable):
    """Tell whether `variable` is an input the user must provide.

    A variable qualifies when it is not produced by any operation (it has
    no owner) and is neither a shared variable nor a constant.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`
        The variable to inspect.

    Returns
    -------
    bool
        ``True`` when the variable is a user-provided graph input.

    """
    if variable.owner:
        return False
    return not isinstance(variable, (SharedVariable, Constant))
284
285
286
def is_shared_variable(variable):
    """Tell whether `variable` is a (non-RNG) Theano shared variable.

    Returns
    -------
    bool
        ``True`` for instances of :class:`SharedVariable`, except those
        that hold random number generator state: instances of
        :class:`RandomStateSharedVariable` and variables whose ``tag``
        has an ``is_rng`` attribute.

    """
    if not isinstance(variable, SharedVariable):
        return False
    if isinstance(variable, RandomStateSharedVariable):
        return False
    return not hasattr(variable.tag, 'is_rng')
298
299
300
def dict_subset(dict_, keys, pop=False, must_have=True):
    """Return a subset of a dictionary corresponding to a set of keys.

    Parameters
    ----------
    dict_ : dict
        The dictionary.
    keys : iterable
        The keys of interest.
    pop : bool
        If ``True``, the pairs corresponding to the keys of interest are
        popped from the dictionary.
    must_have : bool
        If ``True``, a :class:`KeyError` is raised when one of `keys` is
        not present in the dictionary. If ``False``, missing keys are
        skipped.

    Returns
    -------
    result : ``OrderedDict``
        An ordered dictionary of retrieved pairs. The order is the same as
        in the ``keys`` argument.

    Raises
    ------
    KeyError
        If `must_have` is ``True`` and a key is missing. When `pop` is also
        ``True``, the keys preceding the missing one have already been
        removed from `dict_` at that point.

    """
    # Sentinel object: distinguishes "key absent" from a stored ``None``.
    missing = object()

    def extract(key):
        if pop:
            return dict_.pop(key) if must_have else dict_.pop(key, missing)
        return dict_[key] if must_have else dict_.get(key, missing)

    pairs = [(key, extract(key)) for key in keys]
    return OrderedDict((key, value) for key, value in pairs
                       if value is not missing)
336
337
338
def dict_union(*dicts, **kwargs):
    r"""Return union of a sequence of disjoint dictionaries.

    Parameters
    ----------
    dicts : dicts
        A set of dictionaries with no keys in common. If the first
        dictionary in the sequence is an instance of `OrderedDict`, the
        result will be OrderedDict.
    \*\*kwargs
        Keywords and values to add to the resulting dictionary.

    Returns
    -------
    dict or OrderedDict
        A new dictionary containing every key/value pair from `dicts`
        followed by those from `kwargs`.

    Raises
    ------
    ValueError
        If a key appears twice in the dictionaries or keyword arguments.

    """
    dicts = list(dicts)
    if dicts and isinstance(dicts[0], OrderedDict):
        result = OrderedDict()
    else:
        result = {}
    for d in dicts + [kwargs]:
        # Only the incoming keys are tested against ``result``; rebuilding
        # a set of all accumulated keys per dict made the loop quadratic.
        duplicate_keys = [key for key in d if key in result]
        if duplicate_keys:
            raise ValueError("The following keys have duplicate entries: {}"
                             .format(", ".join(str(key)
                                               for key in duplicate_keys)))
        result.update(d)
    return result
369
370
371
def repr_attrs(instance, *attrs):
    r"""Prints a representation of an object with certain attributes.

    Parameters
    ----------
    instance : object
        The object of which to print the string representation
    \*attrs
        Names of attributes that should be printed. May be empty, in which
        case the plain ``<module.Class object at 0x...>`` form is returned.

    Returns
    -------
    str
        The representation string. If formatting any requested attribute
        fails (e.g. it does not exist), the plain form is returned instead.

    Examples
    --------
    >>> class A(object):
    ...     def __init__(self, value):
    ...         self.value = value
    >>> a = A('a_value')
    >>> repr(a)  # doctest: +SKIP
    <blocks.utils.A object at 0x7fb2b4741a10>
    >>> repr_attrs(a, 'value')  # doctest: +SKIP
    <blocks.utils.A object at 0x7fb2b4741a10: value=a_value>

    """
    orig_repr_template = ("<{0.__class__.__module__}.{0.__class__.__name__} "
                          "object at {1:#x}")
    # Start from the plain template so that a call without ``attrs`` works;
    # previously ``repr_template`` was left unbound in that case.
    repr_template = orig_repr_template
    if attrs:
        # Each attribute becomes e.g. "value={0.value}" in the template.
        repr_template += ": " + ", ".join(["{0}={{0.{0}}}".format(attr)
                                           for attr in attrs])
    repr_template += '>'
    orig_repr_template += '>'
    try:
        return repr_template.format(instance, id(instance))
    except Exception:
        # Best effort: fall back to the attribute-less representation.
        return orig_repr_template.format(instance, id(instance))
405
406
407
def put_hook(variable, hook_fn, *args):
    r"""Attach a hook to a Theano variable.

    The variable is wrapped in a :class:`theano.printing.Print` op whose
    ``global_fn`` invokes the hook with the variable's computed value.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`
        The variable to attach the hook to.
    hook_fn : callable
        Called as ``hook_fn(value, *args)``, where ``value`` is the
        variable's computed value.
    \*args : list
        Extra positional arguments passed to `hook_fn` after the value.

    Returns
    -------
    Variable
        The output of the ``Print`` op applied to `variable`; use it in
        place of `variable` so the hook becomes part of the graph.

    """
    def call_hook(_, value):
        return hook_fn(value, *args)

    printer = printing.Print(global_fn=call_hook)
    return printer(variable)
425
426
427
def ipdb_breakpoint(x):
    """Hook function for :func:`put_hook` that opens an ipdb debugger.

    Parameters
    ----------
    x : :class:`~numpy.ndarray`
        The value of the hooked variable. It is not used directly, but it
        is visible as a local inside the debugger session.

    """
    # Imported lazily so that ipdb is only required when a breakpoint fires.
    from ipdb import set_trace
    set_trace()
438
439
440
def print_sum(x, header=None):
    """Hook function that prints the sum of `x`.

    Parameters
    ----------
    x : object
        A value with a ``sum()`` method, e.g. a numpy array.
    header : str, optional
        Label printed before the sum. Any falsy value (None, ``''``)
        falls back to ``'print_sum'``.

    """
    label = header or 'print_sum'
    print(label + ':', x.sum())
444
445
446
def print_shape(x, header=None):
    """Hook function that prints the shape of `x`.

    Parameters
    ----------
    x : object
        A value with a ``shape`` attribute, e.g. a numpy array.
    header : str, optional
        Label printed before the shape. Any falsy value (None, ``''``)
        falls back to ``'print_shape'``.

    """
    label = header or 'print_shape'
    print(label + ':', x.shape)
450
451
452
@contextlib.contextmanager
def change_recursion_limit(limit):
    """Temporarily change the Python recursion limit.

    Parameters
    ----------
    limit : int
        The recursion limit to use inside the ``with`` block.

    Notes
    -----
    The previous limit is restored when the block exits, including when it
    exits because of an exception.

    """
    old_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        # Without ``finally`` an exception raised inside the ``with`` body
        # would leave the modified limit in place for the whole process.
        sys.setrecursionlimit(old_limit)
459
460
461
def extract_args(expected, *args, **kwargs):
    r"""Bind positional and keyword arguments to a list of expected names.

    Methods whose parameter names are only known once an instance exists
    must be declared as ``*args, **kwargs``, which means Python no longer
    validates the call. This helper performs that validation at runtime
    and binds the arguments the way a normal signature would.

    Parameters
    ----------
    expected : list of str
        Names of the expected arguments, in positional order.
    args : iterable
        Positional arguments that were passed. They are matched to the
        leading names of `expected`; any beyond ``len(expected)`` are
        silently ignored.
    kwargs : Mapping
        Keyword arguments that were passed.

    Returns
    -------
    routed_args : OrderedDict
        Maps every name in `expected`, in order, to the value it received
        either positionally or by keyword.

    Raises
    ------
    KeyError
        If a keyword argument's name is not in `expected`.
    TypeError
        If an argument is given both positionally and by keyword.
    ValueError
        If some name in `expected` receives no value at all.

    """
    # Plain zip() on purpose: it stops at the shorter sequence, so only as
    # many names as there are positional arguments get bound here.
    bound = dict(zip(expected, args))
    for key, value in kwargs.items():
        if key not in expected:
            raise KeyError('invalid input name: {}'.format(key))
        if key in bound:
            raise TypeError("got multiple values for "
                            "argument '{}'".format(key))
        bound[key] = value
    unbound = [key for key in expected if key not in bound]
    if unbound:
        raise ValueError('missing values for inputs: {}'.format(unbound))
    return OrderedDict((key, bound[key]) for key in expected)
519
520
521
def find_bricks(top_bricks, predicate):
    """Collect every brick in a hierarchy that satisfies a predicate.

    The hierarchy is traversed breadth-first via each brick's ``children``
    attribute. A brick reachable along several paths is visited (and
    tested) only once.

    Parameters
    ----------
    top_bricks : list
        Root bricks from which the search starts.
    predicate : callable
        Returns `True` for bricks that should be collected and `False`
        otherwise.

    Returns
    -------
    found : list
        The matching bricks among `top_bricks` and all their descendants,
        in breadth-first order.

    """
    queue = deque(top_bricks)
    seen = set()
    matches = []
    while queue:
        brick = queue.popleft()
        if brick in seen:
            continue
        seen.add(brick)
        if predicate(brick):
            matches.append(brick)
        queue.extend(brick.children)
    return matches
550