VariableRole   A
last analyzed

Complexity

Total Complexity 2

Size/Duplication

Total Lines 8
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
dl 0
loc 8
rs 10
c 0
b 0
f 0
wmc 2

2 Methods

Rating   Name   Duplication   Size   Complexity  
A __repr__() 0 3 1
A __eq__() 0 2 1
1
import re
2
3
4
def add_role(var, role):
    r"""Add a role to a given Theano variable.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        The variable to assign the new role to. Its roles are kept as a
        list in ``var.tag.roles``; a missing attribute means "no roles".
    role : :class:`.VariableRole` instance
        The role to attach to `var`.

    Notes
    -----
    Some roles are subroles of others (e.g. :const:`WEIGHT` is a subrole
    of :const:`PARAMETER`). Any attached role whose class is the class of
    `role` or one of its parent classes is discarded, so a more specific
    role replaces a more general one. Conversely, this function will not
    add a role if a more specific role has already been added. If you
    need to replace a role with a parent role (e.g. replace
    :const:`WEIGHT` with :const:`PARAMETER`) you must do so manually.
    A fresh list is always assigned to ``var.tag.roles``; the previous
    list object is not modified.

    Examples
    --------
    >>> from theano import tensor
    >>> W = tensor.matrix()
    >>> from blocks.roles import PARAMETER, WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    PARAMETER
    >>> add_role(W, WEIGHT)
    >>> print(*W.tag.roles)
    WEIGHT
    >>> add_role(W, PARAMETER)
    >>> print(*W.tag.roles)
    WEIGHT

    """
    kept = []
    for existing in getattr(var.tag, 'roles', []):
        # ``role`` supersedes roles of its own class or of an ancestor class.
        if isinstance(role, existing.__class__):
            continue
        kept.append(existing)
    # Attach ``role`` only if no strictly more specific role survived.
    if all(not isinstance(other, role.__class__) for other in kept):
        kept.append(role)
    var.tag.roles = kept
43
44
45
def has_roles(var, roles, match_all=False):
    r"""Test if a variable has given roles taking subroles into account.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        Variable being queried. Its roles are read from ``var.tag.roles``;
        a variable without that attribute is treated as having no roles.
    roles : an iterable of :class:`.VariableRole` instances.
        The roles to look for. A requested role is matched by an attached
        role of the same class or of any subclass of it (a subrole).
    match_all : bool, optional
        If ``True``, checks if the variable has all given roles.
        If ``False``, any of the roles is sufficient.
        ``False`` by default.

    Returns
    -------
    bool
        Whether the variable carries the requested roles. For an empty
        `roles`, this is ``True`` when `match_all` is set and ``False``
        otherwise.

    """
    attached = getattr(var.tag, 'roles', [])

    def _is_attached(wanted):
        # isinstance (not equality) so that subroles count as matches.
        return any(isinstance(candidate, wanted.__class__)
                   for candidate in attached)

    found = (_is_attached(wanted) for wanted in roles)
    if match_all:
        return all(found)
    return any(found)
63
64
65
class VariableRole(object):
    """Base class for all variable roles.

    Roles compare by class: two role instances are equal exactly when they
    are instances of the same class, and they hash accordingly so roles can
    be stored in sets or used as dictionary keys.
    """
    def __eq__(self, other):
        return self.__class__ == other.__class__

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep them in sync.
        return not self == other

    def __hash__(self):
        # Defining ``__eq__`` alone sets ``__hash__`` to ``None`` on Python 3,
        # which made every role unhashable. Hashing the class keeps the
        # invariant "equal roles have equal hashes".
        return hash(self.__class__)

    def __repr__(self):
        """Return the constant-style name of the role.

        A trailing ``Role`` suffix is removed from the class name, an
        underscore is inserted before every non-leading run of capitals,
        and the result is upper-cased (``WeightRole`` -> ``WEIGHT``).
        """
        name = self.__class__.__name__
        # Only strip the suffix when it is actually present, instead of
        # blindly chopping four characters off arbitrary class names.
        if name.endswith('Role'):
            name = name[:-4]
        return re.sub(r'(?!^)([A-Z]+)', r'_\1', name).upper()
73
74
75
class InputRole(VariableRole):
    """Role of variables that are inputs of a brick."""


#: The input of a :class:`~.bricks.Brick`
INPUT = InputRole()
80
81
82
class OutputRole(VariableRole):
    """Role of variables that are outputs of a brick."""


#: The output of a :class:`~.bricks.Brick`
OUTPUT = OutputRole()
87
88
89
class CostRole(VariableRole):
    """Role of scalar costs used for training or regularization."""


#: A scalar cost that can be used to train or regularize
COST = CostRole()
94
95
96
class PersistentRole(VariableRole):
    """Role of quantities that are saved as part of the model."""


#: Any persistent quantity that should be saved as part of the model
PERSISTENT = PersistentRole()
101
102
103
class ParameterRole(PersistentRole):
    """Role of model parameters; a subrole of :class:`PersistentRole`."""


#: A parameter of the model
PARAMETER = ParameterRole()
108
109
110
class AuxiliaryRole(VariableRole):
    """Role of variables attached to the graph as annotations."""


#: Variables added to the graph as annotations
AUXILIARY = AuxiliaryRole()
115
116
117
class WeightRole(ParameterRole):
    """Role of weight matrices; a subrole of :class:`ParameterRole`."""


#: The weight matrices of linear transformations
WEIGHT = WeightRole()
122
123
124
class BiasRole(ParameterRole):
    """Role of bias vectors; a subrole of :class:`ParameterRole`."""


#: Biases of linear transformations
BIAS = BiasRole()
129
130
131
class InitialStateRole(ParameterRole):
    """Role of recurrent initial states; a subrole of :class:`ParameterRole`."""


#: Initial state of a recurrent network
INITIAL_STATE = InitialStateRole()
136
137
138
class FilterRole(WeightRole):
    """Role of convolution filters; a subrole of :class:`WeightRole`."""


#: The filters (kernels) of a convolution operation
FILTER = FilterRole()
143
144
145
class DropoutRole(VariableRole):
    """Role of inputs to which dropout has been applied."""


#: Inputs with applied dropout
DROPOUT = DropoutRole()
150
151
152
class CollectedRole(VariableRole):
    """Role of variables replaced by a slice of a collecting shared variable."""


#: The replacement of a variable collected into a single shared variable
COLLECTED = CollectedRole()
157
158
159
class CollectorRole(ParameterRole):
    """Role of a shared variable holding several collected parameters.

    A subrole of :class:`ParameterRole`.
    """


#: A collection of parameters combined into a single shared variable
COLLECTOR = CollectorRole()
164
165
166
class AlgorithmStateRole(VariableRole):
    """Role of shared variables that belong to a training algorithm."""


#: Shared variables used in algorithm updates
ALGORITHM_STATE = AlgorithmStateRole()
171
172
173
class AlgorithmHyperparameterRole(AlgorithmStateRole):
    """Role of algorithm hyperparameters.

    A subrole of :class:`AlgorithmStateRole`.
    """


#: Hyperparameters associated with algorithms
ALGORITHM_HYPERPARAMETER = AlgorithmHyperparameterRole()
178
179
180
class AlgorithmBufferRole(AlgorithmStateRole):
    """Role of algorithm buffers; a subrole of :class:`AlgorithmStateRole`."""


#: Buffers associated with algorithms
ALGORITHM_BUFFER = AlgorithmBufferRole()
185
186
187
class BatchNormPopulationStatisticsRole(PersistentRole):
    """Base role of batch normalization population statistics.

    A subrole of :class:`PersistentRole`, so these statistics are
    treated as persistent model state.
    """


#: Base role for batch normalization population statistics
BATCH_NORM_POPULATION_STATISTICS = BatchNormPopulationStatisticsRole()
192
193
194
class BatchNormPopulationMeanRole(BatchNormPopulationStatisticsRole):
    """Role of accumulated population means of activations."""


#: Mean activations accumulated over the dataset
BATCH_NORM_POPULATION_MEAN = BatchNormPopulationMeanRole()
199
200
201
class BatchNormPopulationStdevRole(BatchNormPopulationStatisticsRole):
    """Role of accumulated population standard deviations of activations."""


#: Standard deviations of activations accumulated over the dataset
BATCH_NORM_POPULATION_STDEV = BatchNormPopulationStdevRole()
206
207
208
class BatchNormGraphVariableRole(VariableRole):
    """Base role of variables involved in batch normalization graph rewrites."""


#: Base for roles used for within-graph batch normalization replacement
BATCH_NORM_GRAPH_VARIABLE = BatchNormGraphVariableRole()
213
214
215
class BatchNormOffsetRole(BatchNormGraphVariableRole):
    """Role of the offset used by a batch normalization application.

    A subrole of :class:`BatchNormGraphVariableRole`.
    """


# Every line of a multi-line Sphinx doc comment must start with ``#:``;
# a plain ``#`` continuation line hides the whole comment from autodoc.
#: Offset applied in a BatchNormalization application (or its
#: batch-normalized replacement)
BATCH_NORM_OFFSET = BatchNormOffsetRole()
221
222
223
class BatchNormDivisorRole(BatchNormGraphVariableRole):
    """Role of the divisor used by a batch normalization application.

    A subrole of :class:`BatchNormGraphVariableRole`.
    """


#: Divisor applied in a BatchNormalization application (or its
#: batch-normalized replacement)
BATCH_NORM_DIVISOR = BatchNormDivisorRole()
229
230
231
class BatchNormMinibatchEstimateRole(BatchNormGraphVariableRole):
    """Role of minibatch estimates produced by a batch normalization rewrite.

    A subrole of :class:`BatchNormGraphVariableRole`.
    """


#: Role added to variables that are the result of a batch normalization
#: replacement, rather than the original population statistics variables.
BATCH_NORM_MINIBATCH_ESTIMATE = BatchNormMinibatchEstimateRole()
237
238
239
class BatchNormScaleParameterRole(ParameterRole):
    """Role of the batch normalization scale parameter.

    A subrole of :class:`ParameterRole`.
    """


#: Role given to the scale parameter, referred to as "gamma" in the
#: batch normalization manuscript, applied after normalizing.
BATCH_NORM_SCALE_PARAMETER = BatchNormScaleParameterRole()
245
246
247
class BatchNormShiftParameterRole(BiasRole):
    """Role of the batch normalization shift parameter.

    Inherits from :class:`BiasRole` because there really is no functional
    difference with a normal bias, and indeed these are the only biases
    present inside a BatchNormalizedMLP.
    """


#: Role given to the shift parameter, referred to as "beta" in the
#: batch normalization manuscript, applied after normalizing and scaling.
BATCH_NORM_SHIFT_PARAMETER = BatchNormShiftParameterRole()
256