|
1
|
|
|
from inspect import isclass |
|
2
|
|
|
import re |
|
3
|
|
|
|
|
4
|
|
|
from blocks.bricks.base import ApplicationCall, BoundApplication, Brick |
|
5
|
|
|
from blocks.roles import has_roles |
|
6
|
|
|
|
|
7
|
|
|
|
|
8
|
|
|
def get_annotation(var, cls):
    """A helper function to retrieve an annotation of a particular type.

    Parameters
    ----------
    var : object
        A variable whose ``tag`` attribute may carry an ``annotations``
        list. A missing ``annotations`` attribute is treated as an empty
        list.
    cls : type or tuple of types
        The annotation type to look for; it is passed straight to
        :func:`isinstance`.

    Returns
    -------
    object or None
        The first annotation that is an instance of `cls`, or ``None``
        if there is no such annotation.

    Notes
    -----
    This function returns the first annotation of a particular type. If
    there are multiple--there shouldn't be--it will ignore them.

    """
    for annotation in getattr(var.tag, 'annotations', []):
        if isinstance(annotation, cls):
            return annotation
    # Nothing matched: make the "not found" result explicit.
    return None
|
20
|
|
|
|
|
21
|
|
|
|
|
22
|
|
|
def get_brick(var):
    """Retrieves the brick that created this variable.

    Looks up the first annotation of `var` that is an instance of
    :class:`Brick`. See :func:`get_annotation`.

    Parameters
    ----------
    var : object
        The variable whose annotations are inspected.

    Returns
    -------
    object
        Whatever :func:`get_annotation` returns for the :class:`Brick`
        type.

    """
    return get_annotation(var, Brick)
|
29
|
|
|
|
|
30
|
|
|
|
|
31
|
|
|
def get_application_call(var):
    """Retrieves the application call that created this variable.

    Looks up the first annotation of `var` that is an instance of
    :class:`ApplicationCall`. See :func:`get_annotation`.

    Parameters
    ----------
    var : object
        The variable whose annotations are inspected.

    Returns
    -------
    object
        Whatever :func:`get_annotation` returns for the
        :class:`ApplicationCall` type.

    """
    return get_annotation(var, ApplicationCall)
|
38
|
|
|
|
|
39
|
|
|
|
|
40
|
|
|
class VariableFilter(object):
    """Filters Theano variables based on a range of criteria.

    All criteria that are set are applied one after the other, so a
    variable is kept only if it satisfies every one of them.

    Parameters
    ----------
    roles : list of :class:`.VariableRole` instances, optional
        Matches any variable which has one of the roles given.
    bricks : list of :class:`~.bricks.Brick` classes or list of
             instances of :class:`~.bricks.Brick`, optional
        Matches any variable whose brick annotation is an instance of
        any of the given classes or is one of the given brick instances.
        Variables without a brick annotation never match. Note that an
        empty list (as opposed to ``None``) filters out every variable.
    each_role : bool, optional
        If ``True``, the variable needs to have all given roles.  If
        ``False``, a variable matching any of the roles given will be
        returned. ``False`` by default.
    name : str, optional
        The variable name. The Blocks name (i.e.
        `x.tag.name`) is used.
    name_regex : str, optional
        A regular expression for the variable name. The Blocks name (i.e.
        `x.tag.name`) is used. The pattern is applied with
        :func:`re.match`, i.e. it is anchored at the start of the name.
    theano_name : str, optional
        The variable name. The Theano name (i.e.
        `x.name`) is used.
    theano_name_regex : str, optional
        A regular expression for the variable name. The Theano name (i.e.
        `x.name`) is used. The pattern is applied with
        :func:`re.match`, i.e. it is anchored at the start of the name.
    applications : list of :class:`.BoundApplication`, optional
        Matches a variable that was produced by any of the applications
        given.

    Raises
    ------
    ValueError
        If `bricks` contains anything other than :class:`~.bricks.Brick`
        instances or subclasses, or if `applications` contains anything
        other than :class:`.BoundApplication` instances.

    Notes
    -----
    Note that only auxiliary variables, parameters, inputs and outputs are
    tagged with the brick that created them. Other Theano variables that
    were created in the process of applying a brick will be filtered out.

    Note that technically speaking, bricks are able to have non-shared
    variables as parameters. For example, we can use the transpose of
    another weight matrix as the parameter of a particular brick. This
    means that in some unusual cases, filtering by the :const:`PARAMETER`
    role alone will not be enough to retrieve all trainable parameters in
    your model; you will need to filter out the shared variables from these
    (using e.g. :func:`is_shared_variable`).

    Examples
    --------
    >>> from blocks.bricks import MLP, Linear, Logistic, Identity
    >>> from blocks.roles import BIAS
    >>> mlp = MLP(activations=[Identity(), Logistic()], dims=[20, 10, 20])
    >>> from theano import tensor
    >>> x = tensor.matrix()
    >>> y_hat = mlp.apply(x)
    >>> from blocks.graph import ComputationGraph
    >>> cg = ComputationGraph(y_hat)
    >>> from blocks.filter import VariableFilter
    >>> var_filter = VariableFilter(roles=[BIAS],
    ...                             bricks=[mlp.linear_transformations[0]])
    >>> var_filter(cg.variables)
    [b]

    """
    def __init__(self, roles=None, bricks=None, each_role=False, name=None,
                 name_regex=None, theano_name=None, theano_name_regex=None,
                 applications=None):
        # `issubclass` raises TypeError when its first argument is not a
        # class, so guard it with `isclass`; otherwise a bad entry (e.g. a
        # string) would escape as TypeError instead of the ValueError below.
        if bricks is not None and not all(
                isinstance(brick, Brick) or
                (isclass(brick) and issubclass(brick, Brick))
                for brick in bricks):
            raise ValueError('`bricks` should be a list of Bricks')
        if applications is not None and not all(
                isinstance(application, BoundApplication)
                for application in applications):
            raise ValueError('`applications` should be a list of '
                             'BoundApplications')
        self.roles = roles
        self.bricks = bricks
        self.each_role = each_role
        self.name = name
        self.name_regex = name_regex
        self.theano_name = theano_name
        self.theano_name_regex = theano_name_regex
        self.applications = applications

    def __call__(self, variables):
        """Filter the given variables.

        Parameters
        ----------
        variables : list of :class:`~tensor.TensorVariable`

        Returns
        -------
        list
            The variables that satisfy every criterion that is set, in
            their original order. If no criterion is set, `variables`
            itself is returned unchanged.

        """
        if self.roles:
            variables = [var for var in variables
                         if has_roles(var, self.roles, self.each_role)]
        if self.bricks is not None:
            filtered_variables = []
            for var in variables:
                var_brick = get_brick(var)
                if var_brick is None:
                    # Not annotated with a brick: can never match.
                    continue
                # A class entry matches by isinstance, an instance entry
                # matches by identity.
                if any((isclass(brick) and isinstance(var_brick, brick)) or
                       (isinstance(brick, Brick) and var_brick is brick)
                       for brick in self.bricks):
                    filtered_variables.append(var)
            variables = filtered_variables
        if self.name:
            variables = [var for var in variables
                         if hasattr(var.tag, 'name') and
                         self.name == var.tag.name]
        if self.name_regex:
            variables = [var for var in variables
                         if hasattr(var.tag, 'name') and
                         re.match(self.name_regex, var.tag.name)]
        if self.theano_name:
            variables = [var for var in variables
                         if (var.name is not None) and
                         self.theano_name == var.name]
        if self.theano_name_regex:
            variables = [var for var in variables
                         if (var.name is not None) and
                         re.match(self.theano_name_regex, var.name)]
        if self.applications:
            filtered_variables = []
            for var in variables:
                # Look the application call up once per variable.
                application_call = get_application_call(var)
                if (application_call and
                        application_call.application in self.applications):
                    filtered_variables.append(var)
            variables = filtered_variables
        return variables
|
170
|
|
|
|