Issues (119)

blocks/bricks/wrappers.py (3 issues)

from abc import ABCMeta, abstractmethod
from six import add_metaclass
from theano import tensor

from blocks.bricks.base import (
    application, Application, rename_function)
from blocks.utils import dict_subset, pack

# Pylint issue (Coding Style, Naming): the name _wrapped_class_doc does not
# conform to the constant naming conventions
# ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
# This check looks for invalid names for a range of different identifiers.
# You can set regular expressions to which the identifiers must conform if
# the defaults do not match your requirements. If your project includes a
# Pylint configuration file, the settings contained in that file take
# precedence.
_wrapped_class_doc = \
    """A wrapped brick class.

This brick was automatically constructed by wrapping :class:`.{0}` with
:class:`.{1}`.

See Also
--------
:class:`~blocks.bricks.wrappers.BrickWrapper`
    For explanation of brick wrapping.

:class:`.{0}`
:class:`.{1}`

"""

# Pylint issue (Coding Style, Naming): the name _wrapped_application_doc does
# not conform to the constant naming conventions
# ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
_wrapped_application_doc = \
    """{0}

See Also
--------
:meth:`{1}.{2}`
    For documentation of the wrapped application method.

"""

# Pylint issue (Coding Style, Naming): the name
# _with_extra_dims_application_prefix does not conform to the constant naming
# conventions ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
_with_extra_dims_application_prefix = \
    """Wraps the application method with reshapes.

Parameters
----------
extra_ndim : int, optional
    The number of extra dimensions. Default is zero.

"""


@add_metaclass(ABCMeta)
class BrickWrapper(object):
    """Base class for wrapper metaclasses.

    Sometimes one wants to extend a brick with the capability to handle
    inputs different from what it was designed to handle. A typical
    example is inputs with more dimensions than were foreseen at
    the development stage. One way to proceed in such a situation
    is to write a decorator that wraps all application methods of
    the brick class with additional logic executed before and after
    the application call. :class:`BrickWrapper` serves as a
    convenient base class for such decorators.

    Note that since directly applying a decorator to a :class:`Brick`
    subclass only takes effect after
    :func:`~blocks.bricks.base._Brick.__new__` is called, subclasses
    of :class:`BrickWrapper` should be applied by setting the `decorators`
    attribute of the new brick class, as in the example below:

    >>> from blocks.bricks.base import Brick
    >>> class WrappedBrick(Brick):
    ...     decorators = [WithExtraDims()]

    """
    def __call__(self, mcs, name, bases, namespace):
        """Calls :meth:`wrap` for all applications of the base class."""
        if len(bases) != 1:
            raise ValueError("can only wrap one class")
        base, = bases
        for attribute in base.__dict__.values():
            if isinstance(attribute, Application):
                self.wrap(attribute, namespace)
        namespace['__doc__'] = _wrapped_class_doc.format(
            base.__name__, self.__class__.__name__)

    @abstractmethod
    def wrap(self, wrapped, namespace):
        """Wrap an application of the base brick.

        This method should be overridden to write all required changes
        into its `namespace` argument.

        Parameters
        ----------
        wrapped : :class:`~blocks.bricks.base.Application`
            The application to be wrapped.
        namespace : dict
            The namespace of the class being created.

        """
        pass


class WithExtraDims(BrickWrapper):
    """Wraps a brick's applications to handle inputs with extra dimensions.

    A brick can often be reused even when data has more dimensions
    than in the default setting. An example is when one wants
    to apply :meth:`~blocks.bricks.Softmax.categorical_cross_entropy`
    to temporal data, that is, when an additional 'time' axis is
    prepended to both its `x` and `y` inputs.

    This wrapper adds the reshapes required to use the application
    methods of a brick with such data, by merging the extra dimensions
    with the first non-extra one. Two key assumptions
    are made: that all inputs and outputs have the same number of extra
    dimensions, and that these extra dimensions are equal throughout
    all inputs and outputs.

    While this might be inconvenient, the wrapped brick does not try to
    guess the number of extra dimensions, but expects it as the
    `extra_ndim` keyword argument. Simplicity and reliability motivated
    this design choice. Once Blocks provides a mechanism to request the
    expected number of dimensions for an input of a brick, this can be
    reconsidered.

    """
    def wrap(self, wrapped, namespace):
        def apply(self, application, *args, **kwargs):
            # extra_ndim is passed as a keyword argument (default 0) so that
            # it cannot be confused with positional inputs; it is therefore
            # read from **kwargs rather than declared in the signature.
            extra_ndim = kwargs.get('extra_ndim', 0)

            inputs = dict(zip(application.inputs, args))
            inputs.update(dict_subset(kwargs, application.inputs,
                                      must_have=False))
            # Copy, so that reshaping does not modify `inputs` in place.
            reshaped_inputs = dict(inputs)
            # Skip the reshapes entirely to avoid polluting the computation
            # graph with no-ops.
            if extra_ndim > 0:
                for name, input_ in inputs.items():
                    shape, ndim = input_.shape, input_.ndim
                    # Remember extra_dims for reshaping the outputs correctly.
                    # It does not matter which input they come from, since we
                    # assume the extra dimensions match for all inputs.
                    extra_dims = shape[:extra_ndim]
                    new_first_dim = tensor.prod(shape[:extra_ndim + 1])
                    new_shape = tensor.join(
                        0, new_first_dim[None], shape[extra_ndim + 1:])
                    reshaped_inputs[name] = input_.reshape(
                        new_shape, ndim=ndim - extra_ndim)
            outputs = wrapped.__get__(self, None)(**reshaped_inputs)
            if extra_ndim == 0:
                return outputs
            reshaped_outputs = []
            for output in pack(outputs):
                shape, ndim = output.shape, output.ndim
                new_shape = tensor.join(
                    0, extra_dims, (shape[0] // tensor.prod(extra_dims))[None],
                    shape[1:])
                reshaped_outputs.append(
                    output.reshape(new_shape, ndim=ndim + extra_ndim))
            return reshaped_outputs

        def apply_delegate(self):
            return wrapped.__get__(self, None)

        apply = application(rename_function(apply, wrapped.application_name))
        apply.__doc__ = _wrapped_application_doc.format(
            _with_extra_dims_application_prefix,
            wrapped.brick.__name__, wrapped.application_name)
        apply_delegate = apply.delegate(
            rename_function(apply_delegate,
                            wrapped.application_name + "_delegate"))
        namespace[wrapped.application_name] = apply
        namespace[wrapped.application_name + "_delegate"] = apply_delegate
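The reshaping logic in `WithExtraDims.wrap` can be tried out without Theano. The following is a minimal NumPy sketch, not part of Blocks: `apply_with_extra_dims` and `row_softmax` are hypothetical names, and NumPy arrays stand in for Theano tensors. It merges the first `extra_ndim + 1` axes of each input into one, applies the function, then splits the first axis of each output back into the extra dimensions. It relies on the same assumption as the wrapper: all inputs and outputs share identical extra dimensions.

```python
import numpy as np


def apply_with_extra_dims(func, *inputs, extra_ndim=0):
    """Hypothetical NumPy analogue of WithExtraDims.wrap (not Blocks API).

    Merges the first ``extra_ndim + 1`` axes of every input into one axis,
    applies ``func`` and restores the extra axes on every output.
    """
    if extra_ndim == 0:
        # Like the wrapper, perform no reshapes when there are no extra dims.
        return func(*inputs)
    # Assumed identical for all inputs, so taking them from the first is enough.
    extra_dims = inputs[0].shape[:extra_ndim]
    reshaped = []
    for x in inputs:
        # Merge the extra axes with the first non-extra axis.
        new_first = int(np.prod(x.shape[:extra_ndim + 1]))
        reshaped.append(x.reshape((new_first,) + x.shape[extra_ndim + 1:]))
    outputs = func(*reshaped)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    restored = []
    for y in outputs:
        # Split the merged first axis back into extra dims + original axis.
        first = y.shape[0] // int(np.prod(extra_dims))
        restored.append(y.reshape(tuple(extra_dims) + (first,) + y.shape[1:]))
    # Like the wrapper, always return a list when extra_ndim > 0.
    return restored


def row_softmax(x):
    # Hypothetical "brick": expects a 2D (batch, features) array.
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


x = np.arange(60, dtype=float).reshape(5, 3, 4)  # (time, batch, features)
out, = apply_with_extra_dims(row_softmax, x, extra_ndim=1)
print(out.shape)  # (5, 3, 4)
```

Because `row_softmax` works row by row, applying it to the merged (15, 4) array gives the same result as applying it separately to each of the 5 time steps, which is what lets a function written for 2D inputs be reused on 3D ones.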