1
|
|
|
from inspect import isclass |
2
|
|
|
import re |
3
|
|
|
|
4
|
|
|
from blocks.bricks.base import ApplicationCall, BoundApplication, Brick |
5
|
|
|
from blocks.roles import has_roles |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
def get_annotation(var, cls):
    """A helper function to retrieve an annotation of a particular type.

    Parameters
    ----------
    var : object
        An object carrying a ``tag`` attribute (typically a Theano
        variable). Its annotations are read from ``var.tag.annotations``;
        if that attribute is missing, ``var`` is treated as having no
        annotations.
    cls : type or tuple of types
        The annotation type to look for, as accepted by :func:`isinstance`.

    Returns
    -------
    object or None
        The first annotation that is an instance of `cls`, or ``None``
        when no annotation matches.

    Notes
    -----
    This function returns the first annotation of a particular type. If
    there are multiple--there shouldn't be--it will ignore them.

    """
    candidates = getattr(var.tag, 'annotations', [])
    matching = (item for item in candidates if isinstance(item, cls))
    # ``next`` with a default stops at the first hit and yields None
    # when the generator is exhausted without one.
    return next(matching, None)
20
|
|
|
|
21
|
|
|
|
22
|
|
|
def get_brick(var):
    """Retrieves the brick that created this variable.

    This is :func:`get_annotation` restricted to :class:`Brick`
    annotations: the first :class:`Brick` instance found among
    ``var.tag.annotations`` is returned, or ``None`` if there is none.

    See :func:`get_annotation`.

    """
    owner = get_annotation(var=var, cls=Brick)
    return owner
29
|
|
|
|
30
|
|
|
|
31
|
|
|
def get_application_call(var):
    """Retrieves the application call that created this variable.

    This is :func:`get_annotation` restricted to
    :class:`ApplicationCall` annotations: the first
    :class:`ApplicationCall` instance found among
    ``var.tag.annotations`` is returned, or ``None`` if there is none.

    See :func:`get_annotation`.

    """
    call = get_annotation(var=var, cls=ApplicationCall)
    return call
38
|
|
|
|
39
|
|
|
|
40
|
|
|
class VariableFilter(object):
    """Filters Theano variables based on a range of criteria.

    Parameters
    ----------
    roles : list of :class:`.VariableRole` instances, optional
        Matches any variable which has one of the roles given (or all of
        them, see `each_role`).
    bricks : list of :class:`~.bricks.Brick` classes or list of
             instances of :class:`~.bricks.Brick`, optional
        Matches any variable whose owning brick (as returned by
        :func:`get_brick`) is an instance of any of the given classes or
        is one of the given brick instances. Variables that are not
        tagged with a brick never match.
    each_role : bool, optional
        If ``True``, the variable needs to have all given roles. If
        ``False``, a variable matching any of the roles given will be
        returned. ``False`` by default.
    name : str, optional
        The variable name. The Blocks name (i.e.
        `x.tag.name`) is used.
    name_regex : str, optional
        A regular expression for the variable name. The Blocks name (i.e.
        `x.tag.name`) is used. The pattern is applied with
        :func:`re.match`, so it is anchored at the start of the name.
    theano_name : str, optional
        The variable name. The Theano name (i.e.
        `x.name`) is used.
    theano_name_regex : str, optional
        A regular expression for the variable name. The Theano name (i.e.
        `x.name`) is used. The pattern is applied with :func:`re.match`.
    applications : list of :class:`.BoundApplication`, optional
        Matches a variable that was produced by any of the applications
        given.

    Raises
    ------
    ValueError
        If `bricks` contains an element that is neither a :class:`Brick`
        instance nor a :class:`Brick` subclass, or if `applications`
        contains an element that is not a :class:`BoundApplication`.

    Notes
    -----
    A variable is returned only if it passes every filter that is set.
    Filters given a falsy value (``None``, an empty list, an empty
    string) are skipped, with one exception: an empty list passed as
    `bricks` matches no variables at all.

    Note that only auxiliary variables, parameters, inputs and outputs are
    tagged with the brick that created them. Other Theano variables that
    were created in the process of applying a brick will be filtered out.

    Note that technically speaking, bricks are able to have non-shared
    variables as parameters. For example, we can use the transpose of
    another weight matrix as the parameter of a particular brick. This
    means that in some unusual cases, filtering by the :const:`PARAMETER`
    role alone will not be enough to retrieve all trainable parameters in
    your model; you will need to filter out the shared variables from these
    (using e.g. :func:`is_shared_variable`).

    Examples
    --------
    >>> from blocks.bricks import MLP, Linear, Logistic, Identity
    >>> from blocks.roles import BIAS
    >>> mlp = MLP(activations=[Identity(), Logistic()], dims=[20, 10, 20])
    >>> from theano import tensor
    >>> x = tensor.matrix()
    >>> y_hat = mlp.apply(x)
    >>> from blocks.graph import ComputationGraph
    >>> cg = ComputationGraph(y_hat)
    >>> from blocks.filter import VariableFilter
    >>> var_filter = VariableFilter(roles=[BIAS],
    ...                             bricks=[mlp.linear_transformations[0]])
    >>> var_filter(cg.variables)
    [b]

    """
    def __init__(self, roles=None, bricks=None, each_role=False, name=None,
                 name_regex=None, theano_name=None, theano_name_regex=None,
                 applications=None):
        # ``issubclass`` raises TypeError for non-class arguments, so guard
        # it with ``isclass``; anything that is neither a Brick instance nor
        # a Brick subclass must be reported as the documented ValueError.
        if bricks is not None and not all(
                isinstance(brick, Brick) or
                (isclass(brick) and issubclass(brick, Brick))
                for brick in bricks):
            raise ValueError('`bricks` should be a list of Bricks')
        if applications is not None and not all(
                isinstance(application, BoundApplication)
                for application in applications):
            raise ValueError('`applications` should be a list of '
                             'BoundApplications')
        self.roles = roles
        self.bricks = bricks
        self.each_role = each_role
        self.name = name
        self.name_regex = name_regex
        self.theano_name = theano_name
        self.theano_name_regex = theano_name_regex
        self.applications = applications

    def __call__(self, variables):
        """Filter the given variables.

        Parameters
        ----------
        variables : list of :class:`~tensor.TensorVariable`

        Returns
        -------
        list
            The variables that pass every configured filter, in their
            original order. If no filter is configured, `variables` itself
            is returned as is.

        """
        if self.roles:
            variables = [var for var in variables
                         if has_roles(var, self.roles, self.each_role)]
        if self.bricks is not None:
            filtered_variables = []
            for var in variables:
                var_brick = get_brick(var)
                if var_brick is None:
                    # Untagged variables cannot be attributed to any brick.
                    continue
                for brick in self.bricks:
                    # A class matches any brick of that type; an instance
                    # matches only that very brick (identity, not equality).
                    if isclass(brick) and isinstance(var_brick, brick):
                        filtered_variables.append(var)
                        break
                    elif isinstance(brick, Brick) and var_brick is brick:
                        filtered_variables.append(var)
                        break
            variables = filtered_variables
        if self.name:
            variables = [var for var in variables
                         if hasattr(var.tag, 'name') and
                         self.name == var.tag.name]
        if self.name_regex:
            variables = [var for var in variables
                         if hasattr(var.tag, 'name') and
                         re.match(self.name_regex, var.tag.name)]
        if self.theano_name:
            variables = [var for var in variables
                         if (var.name is not None) and
                         self.theano_name == var.name]
        if self.theano_name_regex:
            variables = [var for var in variables
                         if (var.name is not None) and
                         re.match(self.theano_name_regex, var.name)]
        if self.applications:
            filtered_variables = []
            for var in variables:
                # Look the annotation up once instead of twice per variable.
                application_call = get_application_call(var)
                if (application_call and
                        application_call.application in self.applications):
                    filtered_variables.append(var)
            variables = filtered_variables
        return variables
170
|
|
|
|