|
1
|
|
|
import re |
|
2
|
|
|
|
|
3
|
|
|
|
|
4
|
|
|
def add_role(var, role):
    r"""Attach a role to a Theano variable, respecting the role hierarchy.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        The variable that receives the role. Its roles are stored as a
        list in ``var.tag.roles``, which this function (re)assigns.
    role : :class:`.VariableRole` instance
        The role to attach.

    Notes
    -----
    Roles form a hierarchy through subclassing (e.g. :const:`WEIGHT` is a
    subrole of :const:`PARAMETER`). Every existing role that ``role`` is an
    instance of (a parent role, or a role of the same class) is removed
    before ``role`` is appended to the end of the list. If the variable
    already carries a strictly more specific role, ``role`` is not added.
    To swap a role for one of its parents (e.g. :const:`WEIGHT` for
    :const:`PARAMETER`) you must edit ``var.tag.roles`` by hand.

    Examples
    --------
    >>> from theano import tensor
    >>> W = tensor.matrix()
    >>> from blocks.roles import PARAMETER, WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    PARAMETER
    >>> add_role(W, WEIGHT)
    >>> print(*W.tag.roles)
    WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    WEIGHT

    """
    new_cls = role.__class__
    kept = []
    for existing in getattr(var.tag, 'roles', []):
        # ``role`` refines (or duplicates) this one, so it replaces it.
        if isinstance(role, existing.__class__):
            continue
        kept.append(existing)
    # A strictly more specific role that is already present wins.
    already_covered = any(isinstance(existing, new_cls) for existing in kept)
    if not already_covered:
        kept.append(role)
    var.tag.roles = kept
|
43
|
|
|
|
|
44
|
|
|
|
|
45
|
|
|
def has_roles(var, roles, match_all=False):
    r"""Check whether a variable carries some or all of the given roles.

    A role counts as present if the variable has that role or any of its
    subroles (e.g. a variable tagged :const:`WEIGHT` also has
    :const:`PARAMETER`).

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        The variable to inspect. Its roles are read from
        ``var.tag.roles``; a missing attribute means "no roles".
    roles : an iterable of :class:`.VariableRole` instances.
        The roles to look for.
    match_all : bool, optional
        If ``True``, every role in `roles` must be present.
        If ``False`` (the default), one present role is sufficient.

    Returns
    -------
    bool
        Whether the requested roles are present. An empty `roles` gives
        ``True`` when `match_all` is set and ``False`` otherwise.

    """
    present = getattr(var.tag, 'roles', [])

    def _is_present(wanted):
        # Satisfied by a role of ``wanted``'s class or of any subclass.
        return any(isinstance(have, wanted.__class__) for have in present)

    combine = all if match_all else any
    return combine(_is_present(wanted) for wanted in roles)
|
63
|
|
|
|
|
64
|
|
|
|
|
65
|
|
|
class VariableRole(object):
    """Base class for all variable roles.

    Roles compare by class: two role instances are equal, and hash
    equally, exactly when they are instances of the same role class.
    Their string form is derived from the class name, with a trailing
    ``Role`` removed and CamelCase turned into UPPER_SNAKE_CASE, e.g.
    ``InitialStateRole`` is shown as ``INITIAL_STATE``.

    """
    def __eq__(self, other):
        # For non-role operands defer to the other object (Python data
        # model convention) instead of answering ``False`` outright.
        if not isinstance(other, VariableRole):
            return NotImplemented
        return self.__class__ == other.__class__

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it so
        # both operators agree on every interpreter.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __repr__(self):
        name = self.__class__.__name__
        # Strip the conventional ``Role`` suffix only when it is present;
        # unconditionally dropping four characters mangles other names.
        if name.endswith('Role'):
            name = name[:-len('Role')]
        # Insert an underscore before every run of capitals except one at
        # the very start, then upper-case the result.
        return re.sub(r'(?!^)([A-Z]+)', r'_\1', name).upper()

    def __hash__(self):
        # Consistent with __eq__: same class -> same repr -> same hash.
        return hash(str(self))
|
76
|
|
|
|
|
77
|
|
|
|
|
78
|
|
|
class InputRole(VariableRole):
    """Role marking the input variables of a :class:`.Brick`."""


#: Role of the inputs of a :class:`.Brick`.
INPUT = InputRole()
|
83
|
|
|
|
|
84
|
|
|
|
|
85
|
|
|
class OutputRole(VariableRole):
    """Role marking the output variables of a :class:`.Brick`."""


#: Role of the outputs of a :class:`.Brick`.
OUTPUT = OutputRole()
|
90
|
|
|
|
|
91
|
|
|
|
|
92
|
|
|
class CostRole(VariableRole):
    """Role marking scalar costs used for training or regularization."""


#: A scalar cost usable for training or for regularization.
COST = CostRole()
|
97
|
|
|
|
|
98
|
|
|
|
|
99
|
|
|
class PersistentRole(VariableRole):
    """Role marking quantities that are saved as part of the model."""


# ``#:`` (not plain ``#``) so Sphinx picks this up like its siblings.
#: Any persistent quantity that should be saved as part of the model.
PERSISTENT = PersistentRole()
|
104
|
|
|
|
|
105
|
|
|
|
|
106
|
|
|
class ParameterRole(PersistentRole):
    """Role marking model parameters; a subrole of :const:`PERSISTENT`."""


#: A parameter of the model.
PARAMETER = ParameterRole()
|
111
|
|
|
|
|
112
|
|
|
|
|
113
|
|
|
class AuxiliaryRole(VariableRole):
    """Role marking variables added to the graph as annotations."""


#: Variables that were added to the graph as annotations.
AUXILIARY = AuxiliaryRole()
|
118
|
|
|
|
|
119
|
|
|
|
|
120
|
|
|
class WeightRole(ParameterRole):
    """Role marking weight matrices; a subrole of :const:`PARAMETER`."""


#: Weight matrices of linear transformations.
WEIGHT = WeightRole()
|
125
|
|
|
|
|
126
|
|
|
|
|
127
|
|
|
class BiasRole(ParameterRole):
    """Role marking biases; a subrole of :const:`PARAMETER`."""


#: Bias vectors of linear transformations.
BIAS = BiasRole()
|
132
|
|
|
|
|
133
|
|
|
|
|
134
|
|
|
class InitialStateRole(ParameterRole):
    """Role marking the initial state of a recurrent network.

    A subrole of :const:`PARAMETER`.
    """


#: The initial state of a recurrent network.
INITIAL_STATE = InitialStateRole()
|
139
|
|
|
|
|
140
|
|
|
|
|
141
|
|
|
class FilterRole(WeightRole):
    """Role marking convolution filters; a subrole of :const:`WEIGHT`."""


#: Filters (kernels) of a convolution operation.
FILTER = FilterRole()
|
146
|
|
|
|
|
147
|
|
|
|
|
148
|
|
|
class DropoutRole(VariableRole):
    """Role marking inputs to which dropout has been applied."""


#: Inputs after dropout has been applied to them.
DROPOUT = DropoutRole()
|
153
|
|
|
|
|
154
|
|
|
|
|
155
|
|
|
class CollectedRole(VariableRole):
    """Role marking replacements of variables merged into one shared one."""


#: Replacement of a variable that was collected into a single shared
#: variable.
COLLECTED = CollectedRole()
|
160
|
|
|
|
|
161
|
|
|
|
|
162
|
|
|
class CollectorRole(ParameterRole):
    """Role marking a shared variable that combines several parameters.

    A subrole of :const:`PARAMETER`.
    """


#: A collection of parameters combined into one shared variable.
COLLECTOR = CollectorRole()
|
167
|
|
|
|
|
168
|
|
|
|
|
169
|
|
|
class AlgorithmStateRole(VariableRole):
    """Role marking shared variables used in algorithm updates."""


#: Shared variables that take part in algorithm updates.
ALGORITHM_STATE = AlgorithmStateRole()
|
174
|
|
|
|
|
175
|
|
|
|
|
176
|
|
|
class AlgorithmHyperparameterRole(AlgorithmStateRole):
    """Role marking algorithm hyperparameters.

    A subrole of :const:`ALGORITHM_STATE`.
    """


#: Hyperparameters associated with algorithms.
ALGORITHM_HYPERPARAMETER = AlgorithmHyperparameterRole()
|
181
|
|
|
|
|
182
|
|
|
|
|
183
|
|
|
class AlgorithmBufferRole(AlgorithmStateRole):
    """Role marking algorithm buffers; a subrole of :const:`ALGORITHM_STATE`.
    """


#: Buffers associated with algorithms.
ALGORITHM_BUFFER = AlgorithmBufferRole()
|
188
|
|
|
|
|
189
|
|
|
|
|
190
|
|
|
class BatchNormPopulationStatisticsRole(PersistentRole):
    """Base role for batch normalization population statistics.

    A subrole of :const:`PERSISTENT`.
    """


#: Base role shared by all batch normalization population statistics.
BATCH_NORM_POPULATION_STATISTICS = BatchNormPopulationStatisticsRole()
|
195
|
|
|
|
|
196
|
|
|
|
|
197
|
|
|
class BatchNormPopulationMeanRole(BatchNormPopulationStatisticsRole):
    """Role marking mean activations accumulated over the dataset."""


#: Mean of the activations, accumulated over the dataset.
BATCH_NORM_POPULATION_MEAN = BatchNormPopulationMeanRole()
|
202
|
|
|
|
|
203
|
|
|
|
|
204
|
|
|
class BatchNormPopulationStdevRole(BatchNormPopulationStatisticsRole):
    """Role marking activation standard deviations over the dataset."""


#: Standard deviation of the activations, accumulated over the dataset.
BATCH_NORM_POPULATION_STDEV = BatchNormPopulationStdevRole()
|
209
|
|
|
|
|
210
|
|
|
|
|
211
|
|
|
class BatchNormGraphVariableRole(VariableRole):
    """Base role for variables used in within-graph batch norm replacement.
    """


#: Base of the roles used for within-graph batch normalization
#: replacement.
BATCH_NORM_GRAPH_VARIABLE = BatchNormGraphVariableRole()
|
216
|
|
|
|
|
217
|
|
|
|
|
218
|
|
|
class BatchNormOffsetRole(BatchNormGraphVariableRole):
    """Role marking the offset applied by batch normalization."""


# Every line of a multi-line Sphinx doc-comment must start with ``#:``.
#: Offset applied in a BatchNormalization application (or its
#: batch-normalized replacement).
BATCH_NORM_OFFSET = BatchNormOffsetRole()
|
224
|
|
|
|
|
225
|
|
|
|
|
226
|
|
|
class BatchNormDivisorRole(BatchNormGraphVariableRole):
    """Role marking the divisor applied by batch normalization."""


# Every line of a multi-line Sphinx doc-comment must start with ``#:``.
#: Divisor applied in a BatchNormalization application (or its
#: batch-normalized replacement).
BATCH_NORM_DIVISOR = BatchNormDivisorRole()
|
232
|
|
|
|
|
233
|
|
|
|
|
234
|
|
|
class BatchNormMinibatchEstimateRole(BatchNormGraphVariableRole):
    """Role marking variables produced by a batch normalization replacement.
    """


# Every line of a multi-line Sphinx doc-comment must start with ``#:``.
#: Role added to variables that are the result of a batch normalization
#: replacement, rather than the original population statistics variables.
BATCH_NORM_MINIBATCH_ESTIMATE = BatchNormMinibatchEstimateRole()
|
240
|
|
|
|
|
241
|
|
|
|
|
242
|
|
|
class BatchNormScaleParameterRole(ParameterRole):
    """Role marking the batch normalization scale parameter.

    A subrole of :const:`PARAMETER`.
    """


# Every line of a multi-line Sphinx doc-comment must start with ``#:``.
#: Role given to the scale parameter, referred to as "gamma" in the
#: batch normalization manuscript, applied after normalizing.
BATCH_NORM_SCALE_PARAMETER = BatchNormScaleParameterRole()
|
248
|
|
|
|
|
249
|
|
|
|
|
250
|
|
|
class BatchNormShiftParameterRole(BiasRole):
    """Role marking the batch normalization shift parameter.

    A subrole of :const:`BIAS`.
    """


# Every line of a multi-line Sphinx doc-comment must start with ``#:``.
#: Role given to the shift parameter, referred to as "beta" in the
#: batch normalization manuscript, applied after normalizing and scaling.
#: Inherits from BIAS, because there really is no functional difference
#: with a normal bias, and indeed these are the only biases present
#: inside a BatchNormalizedMLP.
BATCH_NORM_SHIFT_PARAMETER = BatchNormShiftParameterRole()
|
259
|
|
|
|