Completed
Push — master ( e89cf9...699e27 )
by Dmitry
55:35
created

blocks/utils/__init__.py (7 issues)

Labels
Severity
1
from __future__ import print_function
0 ignored issues
show
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.recurrent.base -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.base -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
2
import sys
3
import contextlib
4
from collections import OrderedDict, deque
5
6
import numpy
7
import six
8
import theano
9
from theano import tensor
10
from theano import printing
11
from theano.gof.graph import Constant
12
from theano.tensor.shared_randomstreams import RandomStateSharedVariable
13
from theano.tensor.sharedvar import SharedVariable
14
15
16
def pack(arg):
    """Wrap a value in a list unless it already is a sequence.

    Parameters
    ----------
    arg : object
        A list is copied into a new list, a tuple is converted to a
        list, and any other object is wrapped in a singleton list.

    Returns
    -------
    list
        List containing the arguments

    """
    return list(arg) if isinstance(arg, (list, tuple)) else [arg]
36
37
38
def unpack(arg, singleton=False):
    """Extract the single element of a sequence, or pass through.

    Parameters
    ----------
    arg : object
        A list or tuple of length one yields its only element; a longer
        tuple is converted to a list before being returned; a longer
        list is copied; any non-sequence is returned unchanged.
    singleton : bool
        If ``True``, `arg` must be a one-element list or tuple and a
        :class:`ValueError` is raised otherwise. Defaults to ``False``.

    Returns
    -------
    object
        The single element, a list of length greater than one, or the
        original object.

    """
    # Guard clause: anything that is not a list/tuple passes through.
    if not isinstance(arg, (list, tuple)):
        return arg
    if len(arg) == 1:
        return arg[0]
    if singleton:
        raise ValueError("Expected a singleton, got {}".format(arg))
    return list(arg)
71
72
73
def shared_floatx_zeros_matching(shared_variable, name=None, **kwargs):
    r"""Create a zero-filled shared variable mirroring another one.

    Parameters
    ----------
    shared_variable : :class:'tensor.TensorSharedVariable'
        A Theano shared variable whose shape and broadcastable flags the
        new variable should copy.
    name : :obj:`str`, optional
        The name for the shared variable. Defaults to `None`.
    \*\*kwargs
        Keyword arguments to pass to the :func:`shared_floatx_zeros`
        function.

    Returns
    -------
    :class:'tensor.TensorSharedVariable'
        A new shared variable, initialized to all zeros, with the same
        shape and broadcastable flags as `shared_variable`.

    """
    if not is_shared_variable(shared_variable):
        raise ValueError('argument must be a shared variable')
    template_shape = shared_variable.get_value().shape
    return shared_floatx_zeros(
        template_shape, name=name,
        broadcastable=shared_variable.broadcastable, **kwargs)
101
102
103
def shared_floatx_zeros(shape, **kwargs):
    r"""Create a shared variable of type floatX initialized to zeros.

    Parameters
    ----------
    shape : tuple
        A tuple of integers representing the shape of the array.
    \*\*kwargs
        Keyword arguments to pass to the :func:`shared_floatx` function.

    Returns
    -------
    :class:'tensor.TensorSharedVariable'
        A Theano shared variable filled with zeros.

    """
    zeros = numpy.zeros(shape)
    return shared_floatx(zeros, **kwargs)
120
121
122
def shared_floatx_nans(shape, **kwargs):
    r"""Creates a shared variable array filled with nans.

    Parameters
    ----------
    shape : tuple
         A tuple of integers representing the shape of the array.
    \*\*kwargs
        Keyword arguments to pass to the :func:`shared_floatx` function.

    Returns
    -------
    :class:'tensor.TensorSharedVariable'
        A Theano shared variable filled with nans.

    """
    # numpy.full fills the array with NaN in a single allocation; the
    # former `numpy.nan * numpy.zeros(shape)` built a zero array first
    # and then multiplied it element-wise to the same result.
    return shared_floatx(numpy.full(shape, numpy.nan), **kwargs)
139
140
141
def shared_floatx(value, name=None, borrow=False, dtype=None, **kwargs):
    r"""Wrap a value in a Theano shared variable of type floatX.

    Parameters
    ----------
    value : :class:`~numpy.ndarray`
        The value to associate with the Theano shared.
    name : :obj:`str`, optional
        The name for the shared variable. Defaults to `None`.
    borrow : :obj:`bool`, optional
        If set to True, the given `value` will not be copied if possible.
        This can save memory and speed. Defaults to False.
    dtype : :obj:`str`, optional
        The `dtype` of the shared variable. Default value is
        :attr:`config.floatX`.
    \*\*kwargs
        Keyword arguments to pass to the :func:`~theano.shared` function.

    Returns
    -------
    :class:`tensor.TensorSharedVariable`
        A Theano shared variable with the requested value and `dtype`.

    """
    if dtype is None:
        dtype = theano.config.floatX
    data = theano._asarray(value, dtype=dtype)
    return theano.shared(data, name=name, borrow=borrow, **kwargs)
169
170
171
def shared_like(variable, name=None, **kwargs):
    r"""Build a shared variable matching a tensor variable's dtype/ndim.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`
        The variable whose dtype and ndim will be used to construct
        the new shared variable.
    name : :obj:`str` or :obj:`None`
        The name of the shared variable. If None, the name is derived
        from the variable's own name.
    \*\*kwargs
        Keyword arguments to pass to the :func:`~theano.shared` function.

    """
    variable = tensor.as_tensor_variable(variable)
    if name is None:
        name = "shared_{}".format(variable.name)
    # An empty array of the right rank and dtype serves as placeholder.
    placeholder = numpy.zeros((0,) * variable.ndim, dtype=variable.dtype)
    return theano.shared(placeholder, name=name, **kwargs)
192
193
194
def reraise_as(new_exc):
    """Reraise an exception as a different type or with a message.

    This function ensures that the original traceback is kept, making for
    easier debugging.

    Parameters
    ----------
    new_exc : :class:`Exception` or :obj:`str`
        The new error to be raised e.g. (ValueError("New message"))
        or a string that will be prepended to the original exception
        message

    Notes
    -----
    Note that when reraising exceptions, the arguments of the original
    exception are cast to strings and appended to the error message. If
    you want to retain the original exception arguments, please use:

    >>> try:
    ...     1 / 0
    ... except Exception as e:
    ...     reraise_as(Exception("Extra information", *e.args))
    Traceback (most recent call last):
      ...
    Exception: 'Extra information, ...

    Examples
    --------
    >>> class NewException(Exception):
    ...     def __init__(self, message):
    ...         super(NewException, self).__init__(message)
    >>> try:
    ...     do_something_crazy()
    ... except Exception:
    ...     reraise_as(NewException("Informative message"))
    Traceback (most recent call last):
      ...
    NewException: Informative message ...

    """
    # Capture the exception currently being handled; must be called from
    # inside an `except` block for these to be meaningful.
    orig_exc_type, orig_exc_value, orig_exc_traceback = sys.exc_info()

    # A bare string means: re-raise the same exception type with the
    # string as its new message.
    if isinstance(new_exc, six.string_types):
        new_exc = orig_exc_type(new_exc)

    if hasattr(new_exc, 'args'):
        if len(new_exc.args) > 0:
            # We add all the arguments to the message, to make sure that this
            # information isn't lost if this exception is reraised again
            new_message = ', '.join(str(arg) for arg in new_exc.args)
        else:
            new_message = ""
        new_message += '\n\nOriginal exception:\n\t' + orig_exc_type.__name__
        if hasattr(orig_exc_value, 'args') and len(orig_exc_value.args) > 0:
            if getattr(orig_exc_value, 'reraised', False):
                # Already passed through reraise_as once: its first arg is
                # the combined message, so avoid duplicating the rest.
                new_message += ': ' + str(orig_exc_value.args[0])
            else:
                new_message += ': ' + ', '.join(str(arg)
                                                for arg in orig_exc_value.args)
        new_exc.args = (new_message,) + new_exc.args[1:]

    # Preserve the causal chain (PEP 3134) and mark the exception so a
    # subsequent reraise_as call can detect it via the `reraised` flag.
    new_exc.__cause__ = orig_exc_value
    new_exc.reraised = True
    six.reraise(type(new_exc), new_exc, orig_exc_traceback)
259
260
261
def check_theano_variable(variable, n_dim, dtype_prefix):
    """Check number of dimensions and dtype of a Theano variable.

    If the input is not a Theano variable, it is converted to one. `None`
    input is handled as a special case: no checks are done.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable` or convertible to one
        A variable to check.
    n_dim : int
        Expected number of dimensions or None. If None (or 0), no check
        is performed.
    dtype_prefix : str
        Expected dtype prefix or None. If None (or empty), no check is
        performed.

    Raises
    ------
    ValueError
        If the number of dimensions or the dtype prefix does not match.

    """
    if variable is None:
        return

    if not isinstance(variable, tensor.Variable):
        variable = tensor.as_tensor_variable(variable)

    if n_dim and variable.ndim != n_dim:
        raise ValueError("Wrong number of dimensions:"
                         "\n\texpected {}, got {}".format(
                             n_dim, variable.ndim))

    if dtype_prefix and not variable.dtype.startswith(dtype_prefix):
        raise ValueError("Wrong dtype prefix:"
                         "\n\texpected starting with {}, got {}".format(
                             dtype_prefix, variable.dtype))
293
294
295
def is_graph_input(variable):
    """Check if variable is a user-provided graph input.

    To be considered an input the variable must have no owner, and not
    be a constant or shared variable.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`

    Returns
    -------
    bool
        ``True`` If the variable is a user-provided input to the graph.

    """
    # Anything produced by an Apply node, or stored as a constant or
    # shared value, is not an input the user has to provide.
    return not (variable.owner or
                isinstance(variable, (SharedVariable, Constant)))
314
315
316
def is_shared_variable(variable):
    """Check if a variable is a Theano shared variable.

    Notes
    -----
    This function excludes shared variables that store the state of Theano
    random number generators.

    """
    if not isinstance(variable, SharedVariable):
        return False
    # RNG state containers are shared variables too, but callers of this
    # helper never want them treated as model state.
    if isinstance(variable, RandomStateSharedVariable):
        return False
    return not hasattr(variable.tag, 'is_rng')
328
329
330
def dict_subset(dict_, keys, pop=False, must_have=True):
    """Return a subset of a dictionary corresponding to a set of keys.

    Parameters
    ----------
    dict_ : dict
        The dictionary.
    keys : iterable
        The keys of interest.
    pop : bool
        If ``True``, the pairs corresponding to the keys of interest are
        popped from the dictionary.
    must_have : bool
        If ``True``, a KeyError will be raised when trying to retrieve a
        key not present in the dictionary.

    Returns
    -------
    result : ``OrderedDict``
        An ordered dictionary of retrieved pairs. The order is the same as
        in the ``keys`` argument.

    """
    # Sentinel marking keys that were absent (only possible when
    # `must_have` is False); such pairs are filtered out below.
    marker = object()

    def fetch(key):
        if must_have:
            return dict_.pop(key) if pop else dict_[key]
        return dict_.pop(key, marker) if pop else dict_.get(key, marker)

    pairs = [(key, fetch(key)) for key in keys]
    return OrderedDict((k, v) for k, v in pairs if v is not marker)
366
367
368
def dict_union(*dicts, **kwargs):
    r"""Return union of a sequence of disjoint dictionaries.

    Parameters
    ----------
    dicts : dicts
        A set of dictionaries with no keys in common. If the first
        dictionary in the sequence is an instance of `OrderedDict`, the
        result will be OrderedDict.
    \*\*kwargs
        Keywords and values to add to the resulting dictionary.

    Raises
    ------
    ValueError
        If a key appears twice in the dictionaries or keyword arguments.

    """
    dicts = list(dicts)
    # Preserve ordering semantics when the caller starts from an
    # OrderedDict; otherwise a plain dict suffices.
    if dicts and isinstance(dicts[0], OrderedDict):
        merged = OrderedDict()
    else:
        merged = {}
    for mapping in dicts + [kwargs]:
        clashes = set(merged.keys()) & set(mapping.keys())
        if clashes:
            raise ValueError("The following keys have duplicate entries: {}"
                             .format(", ".join(str(key) for key in
                                               clashes)))
        merged.update(mapping)
    return merged
399
400
401
def repr_attrs(instance, *attrs):
    r"""Prints a representation of an object with certain attributes.

    Parameters
    ----------
    instance : object
        The object of which to print the string representation
    \*attrs
        Names of attributes that should be printed.

    Examples
    --------
    >>> class A(object):
    ...     def __init__(self, value):
    ...         self.value = value
    >>> a = A('a_value')
    >>> repr(a)  # doctest: +SKIP
    <blocks.utils.A object at 0x7fb2b4741a10>
    >>> repr_attrs(a, 'value')  # doctest: +SKIP
    <blocks.utils.A object at 0x7fb2b4741a10: value=a_value>

    """
    orig_repr_template = ("<{0.__class__.__module__}.{0.__class__.__name__} "
                          "object at {1:#x}")
    if attrs:
        repr_template = (orig_repr_template + ": " +
                         ", ".join(["{0}={{0.{0}}}".format(attr)
                                    for attr in attrs]))
    else:
        # Bug fix: `repr_template` was previously left unbound when no
        # attrs were given, so `repr_template += '>'` raised NameError.
        repr_template = orig_repr_template
    repr_template += '>'
    orig_repr_template += '>'
    try:
        return repr_template.format(instance, id(instance))
    except Exception:
        # Fall back to the plain representation if any attribute lookup
        # or formatting of the attribute values fails.
        return orig_repr_template.format(instance, id(instance))
435
436
437
def put_hook(variable, hook_fn, *args):
    r"""Put a hook on a Theano variables.

    Ensures that the hook function is executed every time when the value
    of the Theano variable is available.

    Parameters
    ----------
    variable : :class:`~tensor.TensorVariable`
        The variable to put a hook on.
    hook_fn : function
        The hook function. Should take a single argument: the variable's
        value.
    \*args : list
        Positional arguments to pass to the hook function.

    """
    # Print's global_fn receives (op, value); the op is irrelevant here,
    # we only forward the computed value to the hook.
    def _global_fn(_, value):
        return hook_fn(value, *args)

    return printing.Print(global_fn=_global_fn)(variable)
455
456
457
def ipdb_breakpoint(x):
    """A simple hook function for :func:`put_hook` that runs ipdb.

    Parameters
    ----------
    x : :class:`~numpy.ndarray`
        The value of the hooked variable.

    """
    # Imported lazily so that ipdb is only required when a breakpoint
    # hook is actually triggered.
    import ipdb
    ipdb.set_trace()
468
469
470
def print_sum(x, header=None):
    """Print the sum of `x`, prefixed by `header` (or 'print_sum')."""
    label = header or 'print_sum'
    print(label + ':', x.sum())
474
475
476
def print_shape(x, header=None):
    """Print the shape of `x`, prefixed by `header` (or 'print_shape')."""
    label = header or 'print_shape'
    print(label + ':', x.shape)
480
481
482
@contextlib.contextmanager
def change_recursion_limit(limit):
    """Temporarily raise the recursion limit.

    The limit is only ever increased: if the current limit already meets
    or exceeds `limit`, it is left untouched.

    Parameters
    ----------
    limit : int
        The minimum recursion limit to guarantee inside the ``with``
        block.

    """
    old_limit = sys.getrecursionlimit()
    if old_limit < limit:
        sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        # Bug fix: restore the previous limit even when the managed
        # block raises; previously an exception left the raised limit
        # in place for the rest of the process.
        sys.setrecursionlimit(old_limit)
490
491
492
def extract_args(expected, *args, **kwargs):
    r"""Route keyword and positional arguments to a list of names.

    A frequent situation is that a method of the class gets to
    know its positional arguments only when an instance of the class
    has been created. In such cases the signature of such method has to
    be `*args, **kwargs`. The downside of such signatures is that the
    validity of a call is not checked.

    Use :func:`extract_args` if your method knows at runtime, but not
    at evaluation/compile time, what arguments it actually expects,
    in order to check that they are correctly received.

    Parameters
    ----------
    expected : list of str
        A list of strings denoting names for the expected arguments,
        in order.
    args : iterable
        Positional arguments that have been passed.
    kwargs : Mapping
        Keyword arguments that have been passed.

    Returns
    -------
    routed_args : OrderedDict
        An OrderedDict mapping the names in `expected` to values drawn
        from either `args` or `kwargs` in the usual Python fashion.

    Raises
    ------
    KeyError
        If a keyword argument is passed, the key for which is not
        contained within `expected`.
    TypeError
        If an expected argument is accounted for in both the positional
        and keyword arguments.
    ValueError
        If certain arguments in `expected` are not assigned a value
        by either a positional or keyword argument.

    """
    # zip() (rather than equizip()) is deliberate: positional arguments
    # are matched to the leading names and may cover only a prefix.
    routed = dict(zip(expected, args))
    for name, value in kwargs.items():
        if name not in expected:
            raise KeyError('invalid input name: {}'.format(name))
        if name in routed:
            raise TypeError("got multiple values for "
                            "argument '{}'".format(name))
        routed[name] = value
    if set(expected) != set(routed):
        raise ValueError('missing values for inputs: {}'.format(
                         [name for name in expected
                          if name not in routed]))
    return OrderedDict((name, routed[name]) for name in expected)
550
551
552
def find_bricks(top_bricks, predicate):
    """Walk the brick hierarchy, return bricks that satisfy a predicate.

    Parameters
    ----------
    top_bricks : list
        A list of root bricks to search downward from.
    predicate : callable
        A callable that returns `True` for bricks that meet the
        desired criteria or `False` for those that don't.

    Returns
    -------
    found : list
        A list of all bricks that are descendants of any element of
        `top_bricks` that satisfy `predicate`, in breadth-first order,
        each listed at most once.

    """
    found = []
    seen = set()
    # Breadth-first traversal; `seen` guards against revisiting bricks
    # shared between several parents (the hierarchy may be a DAG).
    queue = deque(top_bricks)
    while queue:
        brick = queue.popleft()
        if brick in seen:
            continue
        seen.add(brick)
        if predicate(brick):
            found.append(brick)
        queue.extend(brick.children)
    return found
581