1
|
|
|
import re |
2
|
|
|
|
3
|
|
|
|
def add_role(var, role):
    r"""Attach a role to a Theano variable, honouring the role hierarchy.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        The variable that receives the role. Its roles are kept as a list
        in ``var.tag.roles``, which is created if it does not exist yet.
    role : :class:`.VariableRole` instance
        The role to attach.

    Notes
    -----
    Some roles are subroles of others (e.g. :const:`WEIGHT` is a subrole
    of :const:`PARAMETER`). Adding a role first discards every existing
    role it supersedes, i.e. any role whose class is the new role's class
    or one of its parent classes. The new role is then appended unless a
    more specific role is still present, in which case nothing is added.
    As a consequence, re-adding a role that is already present moves it to
    the end of the list. If you need to replace a role with a parent role
    (e.g. replace :const:`WEIGHT` with :const:`PARAMETER`) you must do so
    manually.

    Examples
    --------
    >>> from theano import tensor
    >>> W = tensor.matrix()
    >>> from blocks.roles import PARAMETER, WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    PARAMETER
    >>> add_role(W, WEIGHT)
    >>> print(*W.tag.roles)
    WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    WEIGHT

    """
    role_class = role.__class__
    kept = []
    for existing in getattr(var.tag, 'roles', []):
        # The new role supersedes roles of its own class or of any ancestor.
        if not isinstance(role, existing.__class__):
            kept.append(existing)
    # A remaining role that is an instance of ``role_class`` is at least as
    # specific as the new one, so adding the new role would lose nothing.
    already_refined = any(isinstance(existing, role_class)
                          for existing in kept)
    if not already_refined:
        kept.append(role)
    var.tag.roles = kept


def has_roles(var, roles, match_all=False):
    r"""Check whether a variable carries the given roles, counting subroles.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        The variable being queried. A variable without a ``roles`` tag is
        treated as having no roles.
    roles : an iterable of :class:`.VariableRole` instances
        The roles to look for. A requested role is satisfied by the role
        itself or by any of its subroles (e.g. :const:`WEIGHT` satisfies
        :const:`PARAMETER`).
    match_all : bool, optional
        If ``True``, every requested role must be satisfied. If ``False``
        (the default), a single satisfied role is enough.

    Returns
    -------
    bool
        Whether the requested roles are satisfied. For an empty `roles`
        this is ``True`` when `match_all` is set and ``False`` otherwise.

    """
    present = getattr(var.tag, 'roles', [])

    def _satisfied(wanted):
        # Subclass instances count, so a subrole satisfies its parent role.
        return any(isinstance(have, wanted.__class__) for have in present)

    outcomes = (_satisfied(wanted) for wanted in roles)
    if match_all:
        return all(outcomes)
    return any(outcomes)


class VariableRole(object):
    """Base class for all variable roles.

    Roles compare by class: two role instances are equal exactly when they
    are instances of the same role class, and equal roles hash equally so
    they can be stored in sets or used as dictionary keys.

    Role classes follow the ``<Name>Role`` naming convention. :meth:`__repr__`
    drops the ``Role`` suffix and renders the remaining CamelCase name in
    UPPER_SNAKE_CASE (e.g. ``BatchNormPopulationMeanRole`` is shown as
    ``BATCH_NORM_POPULATION_MEAN``).

    """
    def __eq__(self, other):
        return self.__class__ == other.__class__

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self == other

    def __hash__(self):
        # Python 3 sets ``__hash__`` to None on classes that define only
        # ``__eq__``; hash by class to stay consistent with ``__eq__``.
        return hash(self.__class__)

    def __repr__(self):
        name = self.__class__.__name__
        # Strip only the conventional suffix; chopping four characters
        # unconditionally mangles names that do not end in ``Role``.
        if name.endswith('Role'):
            name = name[:-len('Role')]
        # Insert an underscore before every run of capitals except at the
        # very start, then upper-case the result.
        return re.sub(r'(?!^)([A-Z]+)', r'_\1', name).upper()


class InputRole(VariableRole):
    """Role of variables that are inputs of a brick."""


#: The input of a :class:`~.bricks.Brick`
INPUT = InputRole()


class OutputRole(VariableRole):
    """Role of variables that are outputs of a brick."""


#: The output of a :class:`~.bricks.Brick`
OUTPUT = OutputRole()


class CostRole(VariableRole):
    """Role of scalar costs."""


#: A scalar cost that can be used to train or regularize
COST = CostRole()


class PersistentRole(VariableRole):
    """Base role of quantities that are saved as part of the model."""


#: Any persistent quantity that should be saved as part of the model
PERSISTENT = PersistentRole()


class ParameterRole(PersistentRole):
    """Role of model parameters; a subrole of the persistent role."""


#: A parameter of the model
PARAMETER = ParameterRole()


class AuxiliaryRole(VariableRole):
    """Role of variables attached to the graph as annotations."""


#: Variables added to the graph as annotations
AUXILIARY = AuxiliaryRole()


class WeightRole(ParameterRole):
    """Role of weight matrices; a subrole of the parameter role."""


#: The weight matrices of linear transformations
WEIGHT = WeightRole()


class BiasRole(ParameterRole):
    """Role of biases; a subrole of the parameter role."""


#: Biases of linear transformations
BIAS = BiasRole()


class InitialStateRole(ParameterRole):
    """Role of recurrent initial states; a subrole of the parameter role."""


#: Initial state of a recurrent network
INITIAL_STATE = InitialStateRole()


class FilterRole(WeightRole):
    """Role of convolution filters; a subrole of the weight role."""


#: The filters (kernels) of a convolution operation
FILTER = FilterRole()


class DropoutRole(VariableRole):
    """Role of inputs to which dropout has been applied."""


#: Inputs with applied dropout
DROPOUT = DropoutRole()


class CollectedRole(VariableRole):
    """Role of variables that replace values gathered into one shared variable."""


#: The replacement of a variable collected into a single shared variable
COLLECTED = CollectedRole()


class CollectorRole(ParameterRole):
    """Role of shared variables that hold several combined parameters."""


#: A collection of parameters combined into a single shared variable
COLLECTOR = CollectorRole()


class AlgorithmStateRole(VariableRole):
    """Base role of shared variables used by training algorithms."""


#: Shared variables used in algorithm updates
ALGORITHM_STATE = AlgorithmStateRole()


class AlgorithmHyperparameterRole(AlgorithmStateRole):
    """Role of algorithm hyperparameters; a subrole of algorithm state."""


#: Hyperparameters associated with algorithms
ALGORITHM_HYPERPARAMETER = AlgorithmHyperparameterRole()


class AlgorithmBufferRole(AlgorithmStateRole):
    """Role of algorithm buffers; a subrole of algorithm state."""


#: Buffers associated with algorithms
ALGORITHM_BUFFER = AlgorithmBufferRole()


class BatchNormPopulationStatisticsRole(PersistentRole):
    """Base role of batch normalization population statistics."""


#: Base role for batch normalization population statistics
BATCH_NORM_POPULATION_STATISTICS = BatchNormPopulationStatisticsRole()


class BatchNormPopulationMeanRole(BatchNormPopulationStatisticsRole):
    """Role of population means used by batch normalization."""


#: Mean activations accumulated over the dataset
BATCH_NORM_POPULATION_MEAN = BatchNormPopulationMeanRole()


class BatchNormPopulationStdevRole(BatchNormPopulationStatisticsRole):
    """Role of population standard deviations used by batch normalization."""


#: Standard deviations of activations accumulated over the dataset
BATCH_NORM_POPULATION_STDEV = BatchNormPopulationStdevRole()


class BatchNormGraphVariableRole(VariableRole):
    """Base role of variables involved in in-graph batch norm replacement."""


#: Base for roles used for within-graph batch normalization replacement
BATCH_NORM_GRAPH_VARIABLE = BatchNormGraphVariableRole()


class BatchNormOffsetRole(BatchNormGraphVariableRole):
    """Role of the offset used by a batch normalization application."""


#: Offset applied in a BatchNormalization application (or its
#: batch-normalized replacement)
BATCH_NORM_OFFSET = BatchNormOffsetRole()


class BatchNormDivisorRole(BatchNormGraphVariableRole):
    """Role of the divisor used by a batch normalization application."""


#: Divisor applied in a BatchNormalization application (or its
#: batch-normalized replacement)
BATCH_NORM_DIVISOR = BatchNormDivisorRole()


class BatchNormMinibatchEstimateRole(BatchNormGraphVariableRole):
    """Role of minibatch estimates produced by batch norm replacement."""


#: Role added to variables that are the result of a batch normalization
#: replacement, rather than the original population statistics variables.
BATCH_NORM_MINIBATCH_ESTIMATE = BatchNormMinibatchEstimateRole()


class BatchNormScaleParameterRole(ParameterRole):
    """Role of the batch normalization scale; a subrole of the parameter role."""


#: Role given to the scale parameter, referred to as "scale" in the
#: batch normalization manuscript, applied after normalizing.
BATCH_NORM_SCALE_PARAMETER = BatchNormScaleParameterRole()


class BatchNormShiftParameterRole(BiasRole):
    """Role of the batch normalization shift; a subrole of the bias role.

    It inherits from the bias role because there is no functional
    difference from an ordinary bias; indeed these are the only biases
    present inside a BatchNormalizedMLP.

    """


#: Role given to the shift parameter, referred to as "beta" in the
#: batch normalization manuscript, applied after normalizing and scaling.
BATCH_NORM_SHIFT_PARAMETER = BatchNormShiftParameterRole()