Completed
Pull Request — master (#1030), created by unknown at 04:44

VariableRole.__hash__()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
| ------ | ----- |
| cc     | 1     |
| dl     | 0     |
| loc    | 2     |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import re
2
3
4
def add_role(var, role):
    r"""Attach a role to a Theano variable, respecting the role hierarchy.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        The variable that receives the role. Its roles are kept as a list
        in ``var.tag.roles`` (created if the attribute does not exist yet).
    role : :class:`.VariableRole` instance
        The role to attach.

    Notes
    -----
    Roles form a hierarchy (e.g. :const:`WEIGHT` is a subrole of
    :const:`PARAMETER`). Adding a role first discards every existing role
    that the new role is an instance of (its parent roles and any role of
    the same class). The new role is then appended only if none of the
    remaining roles is one of its subroles. As a result, a role is never
    downgraded to a parent role. To replace e.g. :const:`WEIGHT` with
    :const:`PARAMETER` you must edit ``var.tag.roles`` manually.

    A new list is always assigned to ``var.tag.roles``; the previous list
    object is not modified.

    Examples
    --------
    >>> from theano import tensor
    >>> W = tensor.matrix()
    >>> from blocks.roles import PARAMETER, WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    PARAMETER
    >>> add_role(W, WEIGHT)
    >>> print(*W.tag.roles)
    WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    WEIGHT

    """
    current = getattr(var.tag, 'roles', [])
    # Drop the roles the new one specializes: its ancestors and any role
    # of exactly the same class.
    remaining = [existing for existing in current
                 if not isinstance(role, existing.__class__)]
    # A surviving subrole already implies the new role; don't add it then.
    covered = any(isinstance(existing, role.__class__)
                  for existing in remaining)
    if not covered:
        remaining.append(role)
    var.tag.roles = remaining
43
44
45
def has_roles(var, roles, match_all=False):
    r"""Check whether a variable carries the given roles, counting subroles.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        Variable being queried. Its roles are read from ``var.tag.roles``
        and treated as empty when that attribute is missing.
    roles : an iterable of :class:`.VariableRole` instances.
        The roles to look for. A requested role is matched by any role of
        the variable that is an instance of the requested role's class,
        i.e. the role itself or one of its subroles.
    match_all : bool, optional
        If ``True``, every requested role must be matched.
        If ``False``, a single match is sufficient.
        ``False`` by default.

    Returns
    -------
    bool
        Whether the requested roles are present. For an empty ``roles``
        this is ``True`` when ``match_all`` is set and ``False`` otherwise.

    """
    attached = getattr(var.tag, 'roles', [])

    def _satisfied(wanted):
        # The wanted role is present if some attached role is an instance
        # of its class (equal role or a subrole).
        return any(isinstance(have, wanted.__class__) for have in attached)

    combine = all if match_all else any
    return combine(_satisfied(wanted) for wanted in roles)
63
64
65
class VariableRole(object):
    """Base class for all variable roles.

    Roles are compared by class: two instances of the same role class are
    equal and hash alike, so role instances can be recreated freely and
    used in sets or as dictionary keys. The printable name is derived from
    the class name, e.g. ``BatchNormPopulationStdevRole`` is shown as
    ``BATCH_NORM_POPULATION_STDEV``.
    """
    # Each run of capitals, except one at the very start of the name, gets
    # an underscore inserted before it.
    _WORD_BOUNDARY = re.compile(r'(?!^)([A-Z]+)')
    # Conventional suffix of role class names, omitted from the display name.
    _SUFFIX = 'Role'

    def __eq__(self, other):
        return self.__class__ == other.__class__

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this two
        # equal roles would also compare unequal there.
        return not self == other

    def __repr__(self):
        name = self.__class__.__name__
        # Strip the "Role" suffix only when present; blindly dropping four
        # characters mangles subclass names that don't follow the convention.
        if name.endswith(self._SUFFIX):
            name = name[:-len(self._SUFFIX)]
        return self._WORD_BOUNDARY.sub(r'_\1', name).upper()

    def __hash__(self):
        # Consistent with __eq__: equal roles share a class and hence a name.
        return hash(str(self))
76
77
78
class InputRole(VariableRole):
    """Role marking the input of a :class:`.Brick`."""


#: The input of a :class:`.Brick`
INPUT = InputRole()
83
84
85
class OutputRole(VariableRole):
    """Role marking the output of a :class:`.Brick`."""


#: The output of a :class:`.Brick`
OUTPUT = OutputRole()
90
91
92
class CostRole(VariableRole):
    """Role marking a scalar cost used for training or regularization."""


#: A scalar cost that can be used to train or regularize
COST = CostRole()
97
98
99
class PersistentRole(VariableRole):
    """Role for any persistent quantity to be saved as part of the model."""


#: Any persistent quantity that should be saved as part of the model
PERSISTENT = PersistentRole()
104
105
106
class ParameterRole(PersistentRole):
    """Role marking a parameter of the model (a subrole of PERSISTENT)."""


#: A parameter of the model
PARAMETER = ParameterRole()
111
112
113
class AuxiliaryRole(VariableRole):
    """Role for variables added to the graph as annotations."""


#: Variables added to the graph as annotations
AUXILIARY = AuxiliaryRole()
118
119
120
class WeightRole(ParameterRole):
    """Role marking weight matrices of linear transformations.

    A subrole of PARAMETER.
    """


#: The weight matrices of linear transformations
WEIGHT = WeightRole()
125
126
127
class BiasRole(ParameterRole):
    """Role marking biases of linear transformations (a subrole of PARAMETER)."""


#: Biases of linear transformations
BIAS = BiasRole()
132
133
134
class InitialStateRole(ParameterRole):
    """Role marking the initial state of a recurrent network.

    A subrole of PARAMETER.
    """


#: Initial state of a recurrent network
INITIAL_STATE = InitialStateRole()
139
140
141
class FilterRole(WeightRole):
    """Role marking convolution filters (kernels); a subrole of WEIGHT."""


#: The filters (kernels) of a convolution operation
FILTER = FilterRole()
146
147
148
class DropoutRole(VariableRole):
    """Role marking inputs to which dropout has been applied."""


#: Inputs with applied dropout
DROPOUT = DropoutRole()
153
154
155
class CollectedRole(VariableRole):
    """Role of the replacement for a variable collected into a shared one."""


#: The replacement of a variable collected into a single shared variable
COLLECTED = CollectedRole()
160
161
162
class CollectorRole(ParameterRole):
    """Role of a shared variable that combines a collection of parameters.

    A subrole of PARAMETER.
    """


#: A collection of parameters combined into a single shared variable
COLLECTOR = CollectorRole()
167
168
169
class AlgorithmStateRole(VariableRole):
    """Role of shared variables used in the updates of algorithms."""


#: Shared variables used in algorithms updates
ALGORITHM_STATE = AlgorithmStateRole()
174
175
176
class AlgorithmHyperparameterRole(AlgorithmStateRole):
    """Role of hyperparameters associated with algorithms.

    A subrole of ALGORITHM_STATE.
    """


#: Hyperparameters associated with algorithms
ALGORITHM_HYPERPARAMETER = AlgorithmHyperparameterRole()
181
182
183
class AlgorithmBufferRole(AlgorithmStateRole):
    """Role of buffers associated with algorithms (a subrole of ALGORITHM_STATE)."""


#: Buffers associated with algorithms
ALGORITHM_BUFFER = AlgorithmBufferRole()
188
189
190
class BatchNormPopulationStatisticsRole(PersistentRole):
    """Base role for batch normalization population statistics.

    A subrole of PERSISTENT.
    """


#: Base role for batch normalization population statistics
BATCH_NORM_POPULATION_STATISTICS = BatchNormPopulationStatisticsRole()
195
196
197
class BatchNormPopulationMeanRole(BatchNormPopulationStatisticsRole):
    """Role of mean activations accumulated over the dataset."""


#: Mean activations accumulated over the dataset
BATCH_NORM_POPULATION_MEAN = BatchNormPopulationMeanRole()
202
203
204
class BatchNormPopulationStdevRole(BatchNormPopulationStatisticsRole):
    """Role of standard deviations of activations accumulated over the dataset."""


#: Standard deviations of activations accumulated over the dataset
BATCH_NORM_POPULATION_STDEV = BatchNormPopulationStdevRole()
209
210
211
class BatchNormGraphVariableRole(VariableRole):
    """Base role for variables used in within-graph batch norm replacement."""


#: Base for roles used for within-graph batch normalization replacement
BATCH_NORM_GRAPH_VARIABLE = BatchNormGraphVariableRole()
216
217
218
class BatchNormOffsetRole(BatchNormGraphVariableRole):
    """Role of the offset applied in a BatchNormalization application."""


#: Offset applied in a BatchNormalization application (or its
#: batch-normalized replacement)
BATCH_NORM_OFFSET = BatchNormOffsetRole()
224
225
226
class BatchNormDivisorRole(BatchNormGraphVariableRole):
    """Role of the divisor applied in a BatchNormalization application."""


#: Divisor applied in a BatchNormalization application (or its
#: batch-normalized replacement)
BATCH_NORM_DIVISOR = BatchNormDivisorRole()
232
233
234
class BatchNormMinibatchEstimateRole(BatchNormGraphVariableRole):
    """Role of variables produced by a batch normalization replacement.

    Such variables stand in place of the original population statistics.
    """


#: Role added to variables that are the result of a batch normalization
#: replacement, rather than the original population statistics variables.
BATCH_NORM_MINIBATCH_ESTIMATE = BatchNormMinibatchEstimateRole()
240
241
242
class BatchNormScaleParameterRole(ParameterRole):
    """Role of the batch normalization scale parameter (a subrole of PARAMETER)."""


#: Role given to the scale parameter, referred to as "gamma" in the
#: batch normalization manuscript, applied after normalizing.
BATCH_NORM_SCALE_PARAMETER = BatchNormScaleParameterRole()
248
249
250
class BatchNormShiftParameterRole(BiasRole):
    """Role of the batch normalization shift ("beta") parameter.

    Inherits from :class:`BiasRole` because there really is no functional
    difference with a normal bias, and indeed these are the only biases
    present inside a BatchNormalizedMLP.
    """


#: Role given to the shift parameter, referred to as "beta" in the
#: batch normalization manuscript, applied after normalizing and scaling.
#: Inherits from BIAS, because there really is no functional difference
#: with a normal bias, and indeed these are the only biases present
#: inside a BatchNormalizedMLP.
BATCH_NORM_SHIFT_PARAMETER = BatchNormShiftParameterRole()
259