VariableFilter   B
last analyzed

Complexity

Total Complexity 40

Size/Duplication

Total Lines 130
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
dl 0
loc 130
rs 8.2608
c 0
b 0
f 0
wmc 40

2 Methods

Rating   Name   Duplication   Size   Complexity  
C __init__() 0 20 8
F __call__() 0 47 32

How to fix   Complexity   

Complex Class

Complex classes like VariableFilter often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from inspect import isclass
2
import re
3
4
from blocks.bricks.base import ApplicationCall, BoundApplication, Brick
5
from blocks.roles import has_roles
6
7
8
def get_annotation(var, cls):
    """Return the first annotation of ``var`` that is an instance of ``cls``.

    Annotations are read from ``var.tag.annotations``. A tag without that
    attribute is treated as having no annotations.

    Parameters
    ----------
    var : object
        A variable whose ``tag`` may carry an ``annotations`` sequence.
    cls : type or tuple of types
        The annotation type to look for (passed to :func:`isinstance`).

    Returns
    -------
    object or None
        The first matching annotation, or ``None`` when nothing matches.

    Notes
    -----
    Only the first match is returned. Should several annotations of the
    same type be present (which is not expected), the rest are ignored.

    """
    candidates = getattr(var.tag, 'annotations', [])
    matches = (item for item in candidates if isinstance(item, cls))
    return next(matches, None)
20
21
22
def get_brick(var):
    """Return the :class:`~.bricks.Brick` annotation attached to ``var``.

    This is :func:`get_annotation` specialised to ``Brick``. The first brick
    annotation is returned, or ``None`` if the variable carries none.

    """
    owner = get_annotation(var, cls=Brick)
    return owner
29
30
31
def get_application_call(var):
    """Return the ``ApplicationCall`` annotation attached to ``var``.

    This is :func:`get_annotation` specialised to ``ApplicationCall``. The
    first such annotation is returned, or ``None`` if there is none.

    """
    call = get_annotation(var, cls=ApplicationCall)
    return call
38
39
40
class VariableFilter(object):
    """Filters Theano variables based on a range of criteria.

    All criteria that are given must hold: they are applied one after
    another, each narrowing the result of the previous one.

    Parameters
    ----------
    roles : list of :class:`.VariableRole` instances, optional
        Matches any variable which has one of the roles given.
    bricks : list of :class:`~.bricks.Brick` classes or list of
             instances of :class:`~.bricks.Brick`, optional
        Matches any variable that is instance of any of the given classes
        or that is owned by any of the given brick instances. Note that an
        empty list matches nothing, whereas ``None`` disables this filter.
    each_role : bool, optional
        If ``True``, the variable needs to have all given roles.  If
        ``False``, a variable matching any of the roles given will be
        returned. ``False`` by default.
    name : str, optional
        The variable name. The Blocks name (i.e.
        `x.tag.name`) is used.
    name_regex : str, optional
        A regular expression for the variable name. The Blocks name (i.e.
        `x.tag.name`) is used. It is applied with :func:`re.match`, so it
        is anchored at the start of the name.
    theano_name : str, optional
        The variable name. The Theano name (i.e.
        `x.name`) is used.
    theano_name_regex : str, optional
        A regular expression for the variable name. The Theano name (i.e.
        `x.name`) is used. It is applied with :func:`re.match`, so it is
        anchored at the start of the name.
    applications : list of :class:`.BoundApplication`, optional
        Matches a variable that was produced by any of the applications
        given.

    Raises
    ------
    ValueError
        If `bricks` contains anything other than :class:`~.bricks.Brick`
        instances or subclasses, or if `applications` contains anything
        other than :class:`.BoundApplication` instances.

    Notes
    -----
    Note that only auxiliary variables, parameters, inputs and outputs are
    tagged with the brick that created them. Other Theano variables that
    were created in the process of applying a brick will be filtered out.

    Note that technically speaking, bricks are able to have non-shared
    variables as parameters. For example, we can use the transpose of
    another weight matrix as the parameter of a particular brick. This
    means that in some unusual cases, filtering by the :const:`PARAMETER`
    role alone will not be enough to retrieve all trainable parameters in
    your model; you will need to filter out the shared variables from these
    (using e.g. :func:`is_shared_variable`).

    Examples
    --------
    >>> from blocks.bricks import MLP, Linear, Logistic, Identity
    >>> from blocks.roles import BIAS
    >>> mlp = MLP(activations=[Identity(), Logistic()], dims=[20, 10, 20])
    >>> from theano import tensor
    >>> x = tensor.matrix()
    >>> y_hat = mlp.apply(x)
    >>> from blocks.graph import ComputationGraph
    >>> cg = ComputationGraph(y_hat)
    >>> from blocks.filter import VariableFilter
    >>> var_filter = VariableFilter(roles=[BIAS],
    ...                             bricks=[mlp.linear_transformations[0]])
    >>> var_filter(cg.variables)
    [b]

    """
    def __init__(self, roles=None, bricks=None, each_role=False, name=None,
                 name_regex=None, theano_name=None, theano_name_regex=None,
                 applications=None):
        # Both sequences are materialized before validation: checking them
        # with ``all()`` would otherwise exhaust a one-shot iterator (e.g. a
        # generator), leaving nothing to match against in ``__call__``.
        if bricks is not None:
            bricks = list(bricks)
            # ``issubclass`` raises TypeError for non-class arguments, so it
            # is guarded by ``isclass`` to report bad input as ValueError.
            if not all(isinstance(brick, Brick) or
                       (isclass(brick) and issubclass(brick, Brick))
                       for brick in bricks):
                raise ValueError('`bricks` should be a list of Bricks')
        if applications is not None:
            applications = list(applications)
            if not all(isinstance(application, BoundApplication)
                       for application in applications):
                raise ValueError('`applications` should be a list of '
                                 'BoundApplications')
        self.roles = roles
        self.bricks = bricks
        self.each_role = each_role
        self.name = name
        self.name_regex = name_regex
        self.theano_name = theano_name
        self.theano_name_regex = theano_name_regex
        self.applications = applications

    def __call__(self, variables):
        """Filter the given variables.

        Parameters
        ----------
        variables : list of :class:`~tensor.TensorVariable`

        Returns
        -------
        list
            The variables satisfying every configured criterion, in their
            original order. If no criterion is active, `variables` itself
            is returned unchanged.

        """
        if self.roles:
            variables = [var for var in variables
                         if has_roles(var, self.roles, self.each_role)]
        # ``is not None`` (not truthiness): an empty brick list matches nothing.
        if self.bricks is not None:
            variables = [var for var in variables if self._matches_bricks(var)]
        if self.name:
            variables = [var for var in variables
                         if hasattr(var.tag, 'name') and
                         self.name == var.tag.name]
        if self.name_regex:
            # Compile once per call instead of once per variable.
            name_pattern = re.compile(self.name_regex)
            variables = [var for var in variables
                         if hasattr(var.tag, 'name') and
                         name_pattern.match(var.tag.name)]
        if self.theano_name:
            variables = [var for var in variables
                         if var.name is not None and
                         self.theano_name == var.name]
        if self.theano_name_regex:
            theano_name_pattern = re.compile(self.theano_name_regex)
            variables = [var for var in variables
                         if var.name is not None and
                         theano_name_pattern.match(var.name)]
        if self.applications:
            variables = [var for var in variables
                         if self._matches_applications(var)]
        return variables

    def _matches_bricks(self, var):
        """Whether `var` was created by one of the bricks in `self.bricks`."""
        var_brick = get_brick(var)
        if var_brick is None:
            return False
        # A class entry matches any brick of that type; an instance entry
        # matches only that very brick (identity, not equality).
        return any((isclass(brick) and isinstance(var_brick, brick)) or
                   (isinstance(brick, Brick) and var_brick is brick)
                   for brick in self.bricks)

    def _matches_applications(self, var):
        """Whether `var` was produced by one of `self.applications`."""
        # Look the application call up once rather than twice per variable.
        application_call = get_application_call(var)
        return bool(application_call and
                    application_call.application in self.applications)
170