1 | from abc import ABCMeta, abstractmethod |
||
2 | from six import add_metaclass |
||
3 | from theano import tensor |
||
4 | |||
5 | from blocks.bricks.base import ( |
||
6 | application, Application, rename_function) |
||
7 | from blocks.utils import dict_subset, pack |
||
8 | |||
# Template for the docstring of a class produced by a BrickWrapper.
# Placeholders: {0} is the name of the wrapped base class and {1} is the
# name of the wrapper class (both filled in by ``BrickWrapper.__call__``).
# The leading dots in the Sphinx roles request a suffix search, so the
# short class names resolve regardless of the module defining them.
_wrapped_class_doc = \
    """A wrapped brick class.

This brick was automatically constructed by wrapping :class:`.{0}` with
:class:`.{1}`.

See Also
--------
:class:`~blocks.bricks.wrappers.BrickWrapper`
    For explanation of brick wrapping.

:class:`.{0}`
:class:`.{1}`

"""

# Template for the docstring of a wrapped application method.
# Placeholders: {0} is a wrapper-specific description used as the body of
# the docstring, {1} is the name of the wrapped brick class and {2} is the
# name of the wrapped application. {1} is a bare class name, so the :meth:
# role uses a leading dot (as ``_wrapped_class_doc`` does) to let Sphinx
# resolve it by suffix search.
_wrapped_application_doc = \
    """{0}

See Also
--------
:meth:`.{1}.{2}`
    For documentation of the wrapped application method.

"""

# Description substituted for placeholder {0} of
# ``_wrapped_application_doc`` when ``WithExtraDims`` wraps an application.
_with_extra_dims_application_prefix = \
    """Wraps the application method with reshapes.

Parameters
----------
extra_ndim : int, optional
    The number of extra dimensions. Default is zero.

"""
||
44 | |||
45 | |||
@add_metaclass(ABCMeta)
class BrickWrapper(object):
    """Base class for wrapper metaclasses.

    Sometimes one wants to extend a brick with the capability to handle
    inputs different from what it was designed to handle. A typical
    example are inputs with more dimensions than was foreseen at
    the development stage. One way to proceed in such a situation
    is to write a decorator that wraps all application methods of
    the brick class by some additional logic before and after
    the application call. :class:`BrickWrapper` serves as a
    convenient base class for such decorators.

    Note that since directly applying a decorator to a :class:`Brick`
    subclass will only take place after
    :func:`~blocks.bricks.base._Brick.__new__` is called, subclasses
    of :class:`BrickWrapper` should be applied by setting the `decorators`
    attribute of the new brick class, like in the example below:

    >>> from blocks.bricks.base import Brick
    >>> class WrappedBrick(Brick):
    ...     decorators = [WithExtraDims()]

    """
    def __call__(self, mcs, name, bases, namespace):
        """Call :meth:`wrap` for every application of the base class.

        Parameters
        ----------
        mcs : type
            The metaclass. Not used here.
        name : str
            The name of the class being created. Not used here.
        bases : sequence
            The bases of the class being created. It must contain exactly
            one class: the brick whose applications are wrapped.
        namespace : dict
            The namespace of the class being created. It is modified in
            place: :meth:`wrap` writes into it and its ``'__doc__'`` entry
            is replaced with a generated docstring.

        Raises
        ------
        ValueError
            If `bases` does not contain exactly one class.

        """
        if len(bases) != 1:
            raise ValueError("can only wrap one class")
        base, = bases
        # Only applications defined directly on the base class are visited:
        # ``base.__dict__`` does not contain attributes ``base`` inherits.
        for attribute in base.__dict__.values():
            if isinstance(attribute, Application):
                self.wrap(attribute, namespace)
        namespace['__doc__'] = _wrapped_class_doc.format(
            base.__name__, self.__class__.__name__)

    @abstractmethod
    def wrap(self, wrapped, namespace):
        """Wrap an application of the base brick.

        This method should be overridden to write into its
        `namespace` argument all required changes.

        Parameters
        ----------
        wrapped : :class:`~blocks.bricks.base.Application`
            The application to be wrapped.
        namespace : dict
            The namespace of the class being created.

        """
        pass
||
99 | |||
100 | |||
class WithExtraDims(BrickWrapper):
    """Wraps a brick's applications to handle inputs with extra dimensions.

    A brick can be often reused even when data has more dimensions
    than in the default setting. An example is a situation when one wants
    to apply :meth:`~blocks.bricks.Softmax.categorical_cross_entropy`
    to temporal data, that is when an additional 'time' axis is prepended
    to both its `x` and `y` inputs.

    This wrapper adds reshapes required to use application
    methods of a brick with such data by merging the extra dimensions
    with the first non-extra one. Two key assumptions
    are made: that all inputs and outputs have the same number of extra
    dimensions and that these extra dimensions are equal throughout
    all inputs and outputs.

    While this might be inconvenient, the wrapped brick does not try to
    guess the number of extra dimensions, but demands it as an argument.
    The considerations of simplicity and reliability motivated this design
    choice. Upon availability in Blocks of a mechanism to request the
    expected number of dimensions for an input of a brick, this can be
    reconsidered.

    """
    def wrap(self, wrapped, namespace):
        """Replace `wrapped` in `namespace` with a reshaping version.

        Two entries are written into `namespace`:

        * under the name of `wrapped`, an application that merges the
          ``extra_ndim`` leading dimensions of every input into the first
          non-extra one, calls the wrapped application, and splits the
          first dimension of every output back into the extra dimensions;
        * under ``<application name> + "_delegate"``, its delegate, which
          returns the original application bound to the brick.

        Parameters
        ----------
        wrapped : :class:`~blocks.bricks.base.Application`
            The application to be wrapped.
        namespace : dict
            The namespace of the class being created.

        """
        def apply(self, application, *args, **kwargs):
            # ``extra_ndim`` is optional (default zero). It is read from
            # **kwargs so that it cannot be confused with positional inputs.
            extra_ndim = kwargs.get('extra_ndim', 0)
            if extra_ndim < 0:
                raise ValueError(
                    "extra_ndim must be non-negative, got {0}".format(
                        extra_ndim))

            # Positional arguments are matched to application inputs in
            # order; keyword arguments are kept only if they name an input.
            # NOTE(review): any other keyword argument is silently dropped.
            inputs = dict(zip(application.inputs, args))
            inputs.update(dict_subset(kwargs, application.inputs,
                                      must_have=False))
            # To prevent pollution of the computation graph with no-ops,
            # no reshaping at all is done when there are no extra dims.
            if extra_ndim == 0:
                return wrapped.__get__(self, None)(**inputs)
            if not inputs:
                # The extra dimensions are read off the inputs' shapes;
                # without any input the outputs could not be reshaped.
                raise ValueError(
                    "extra_ndim > 0 requires at least one input to infer "
                    "the extra dimensions from")

            reshaped_inputs = {}
            for name, input_ in inputs.items():
                shape, ndim = input_.shape, input_.ndim
                # Remember extra_dims for reshaping the outputs correctly.
                # It does not matter which input it comes from, since the
                # extra dimensions are assumed to match for all inputs.
                extra_dims = shape[:extra_ndim]
                # Merge the extra dimensions with the first non-extra one.
                new_first_dim = tensor.prod(shape[:extra_ndim + 1])
                new_shape = tensor.join(
                    0, new_first_dim[None], shape[extra_ndim + 1:])
                reshaped_inputs[name] = input_.reshape(
                    new_shape, ndim=ndim - extra_ndim)
            outputs = wrapped.__get__(self, None)(**reshaped_inputs)

            # NOTE(review): a list is returned even for a single output;
            # presumably the `application` machinery unpacks it — confirm.
            reshaped_outputs = []
            for output in pack(outputs):
                shape, ndim = output.shape, output.ndim
                # Split the merged first dimension back into the extra
                # dimensions followed by the original first dimension.
                new_shape = tensor.join(
                    0, extra_dims,
                    (shape[0] // tensor.prod(extra_dims))[None],
                    shape[1:])
                reshaped_outputs.append(
                    output.reshape(new_shape, ndim=ndim + extra_ndim))
            return reshaped_outputs

        def apply_delegate(self):
            # Delegate to the original application so that its properties
            # (inputs, outputs, ...) are looked up there.
            return wrapped.__get__(self, None)

        apply = application(rename_function(apply, wrapped.application_name))
        apply.__doc__ = _wrapped_application_doc.format(
            _with_extra_dims_application_prefix,
            wrapped.brick.__name__, wrapped.application_name)
        apply_delegate = apply.delegate(
            rename_function(apply_delegate,
                            wrapped.application_name + "_delegate"))
        namespace[wrapped.application_name] = apply
        namespace[wrapped.application_name + "_delegate"] = apply_delegate
||
174 |
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.