| Total Complexity | 40 |
| Total Lines | 130 |
| Duplicated Lines | 0 % |
Complex classes like VariableFilter often do many different things. To break such a class down, identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share the same prefix or suffix.
Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
| 1 | from inspect import isclass |
||
class VariableFilter(object):
    """Filters Theano variables based on a range of criteria.

    Parameters
    ----------
    roles : list of :class:`.VariableRole` instances, optional
        Matches any variable which has one of the roles given. An empty
        list (or ``None``) disables role filtering.
    bricks : list of :class:`.Brick` classes or list of instances of
        :class:`.Brick`, optional
        Matches any variable whose brick is an instance of any of the
        given classes or is one of the given brick instances. ``None``
        disables brick filtering, whereas an empty list matches no
        variable at all.
    each_role : bool, optional
        If ``True``, the variable needs to have all given roles. If
        ``False``, a variable matching any of the roles given will be
        returned. ``False`` by default.
    name : str, optional
        The variable name. The Blocks name (i.e. `x.tag.name`) is used.
    name_regex : str, optional
        A regular expression for the variable name, matched at the start
        of the name (as with :func:`re.match`). The Blocks name (i.e.
        `x.tag.name`) is used.
    theano_name : str, optional
        The variable name. The Theano name (i.e. `x.name`) is used.
    theano_name_regex : str, optional
        A regular expression for the variable name, matched at the start
        of the name (as with :func:`re.match`). The Theano name (i.e.
        `x.name`) is used.
    applications : list of :class:`.BoundApplication`, optional
        Matches a variable that was produced by any of the applications
        given.

    Raises
    ------
    ValueError
        If `bricks` contains anything other than bricks or brick classes,
        or if `applications` contains anything other than
        :class:`.BoundApplication` instances.

    Notes
    -----
    Note that only auxiliary variables, parameters, inputs and outputs are
    tagged with the brick that created them. Other Theano variables that
    were created in the process of applying a brick will be filtered out.

    Note that technically speaking, bricks are able to have non-shared
    variables as parameters. For example, we can use the transpose of
    another weight matrix as the parameter of a particular brick. This
    means that in some unusual cases, filtering by the :const:`PARAMETER`
    role alone will not be enough to retrieve all trainable parameters in
    your model; you will need to filter out the shared variables from these
    (using e.g. :func:`is_shared_variable`).

    Variables whose name is unset or ``None`` never match any of the
    name-based criteria.

    Examples
    --------
    >>> from blocks.bricks import MLP, Linear, Logistic, Identity
    >>> from blocks.roles import BIAS
    >>> mlp = MLP(activations=[Identity(), Logistic()], dims=[20, 10, 20])
    >>> from theano import tensor
    >>> x = tensor.matrix()
    >>> y_hat = mlp.apply(x)
    >>> from blocks.graph import ComputationGraph
    >>> cg = ComputationGraph(y_hat)
    >>> from blocks.filter import VariableFilter
    >>> var_filter = VariableFilter(roles=[BIAS],
    ...                             bricks=[mlp.linear_transformations[0]])
    >>> var_filter(cg.variables)
    [b]

    """
    def __init__(self, roles=None, bricks=None, each_role=False, name=None,
                 name_regex=None, theano_name=None, theano_name_regex=None,
                 applications=None):
        # Check `isclass` before `issubclass`: the latter raises TypeError
        # for non-class arguments, which would mask the intended ValueError.
        if bricks is not None and not all(
                isinstance(brick, Brick) or
                (isclass(brick) and issubclass(brick, Brick))
                for brick in bricks):
            raise ValueError('`bricks` should be a list of Bricks')
        if applications is not None and not all(
                isinstance(application, BoundApplication)
                for application in applications):
            raise ValueError('`applications` should be a list of '
                             'BoundApplications')
        self.roles = roles
        self.bricks = bricks
        self.each_role = each_role
        self.name = name
        self.name_regex = name_regex
        self.theano_name = theano_name
        self.theano_name_regex = theano_name_regex
        self.applications = applications

    def __call__(self, variables):
        """Filter the given variables.

        Parameters
        ----------
        variables : list of :class:`~tensor.TensorVariable`

        Returns
        -------
        list
            The variables satisfying every criterion given to the
            constructor, in their original order. If no criterion is set,
            `variables` itself is returned unchanged.

        """
        if self.roles:
            variables = [var for var in variables
                         if has_roles(var, self.roles, self.each_role)]
        if self.bricks is not None:
            variables = [var for var in variables
                         if self._owned_by_selected_brick(var)]
        if self.name:
            variables = [var for var in variables
                         if self._name_equals(self._blocks_name(var),
                                              self.name)]
        if self.name_regex:
            # Compile once per call rather than once per variable.
            pattern = re.compile(self.name_regex)
            variables = [var for var in variables
                         if self._name_matches(self._blocks_name(var),
                                               pattern)]
        if self.theano_name:
            variables = [var for var in variables
                         if self._name_equals(var.name, self.theano_name)]
        if self.theano_name_regex:
            pattern = re.compile(self.theano_name_regex)
            variables = [var for var in variables
                         if self._name_matches(var.name, pattern)]
        if self.applications:
            variables = [var for var in variables
                         if self._produced_by_selected_application(var)]
        return variables

    @staticmethod
    def _blocks_name(var):
        """Return the Blocks name ``var.tag.name``, or ``None`` if unset."""
        return getattr(var.tag, 'name', None)

    @staticmethod
    def _name_equals(var_name, name):
        """Return whether `var_name` is set and equal to `name`."""
        return var_name is not None and name == var_name

    @staticmethod
    def _name_matches(var_name, pattern):
        """Return whether `var_name` is set and matches compiled `pattern`."""
        # Matching against None would raise TypeError; treat it as no match.
        return var_name is not None and pattern.match(var_name) is not None

    def _owned_by_selected_brick(self, var):
        """Return whether `var` belongs to one of the selected bricks."""
        var_brick = get_brick(var)
        if var_brick is None:
            # Variables not tagged with a brick never match (see Notes).
            return False
        return any(
            (isclass(brick) and isinstance(var_brick, brick)) or
            (isinstance(brick, Brick) and var_brick is brick)
            for brick in self.bricks)

    def _produced_by_selected_application(self, var):
        """Return whether `var` was produced by a selected application."""
        # Look the application call up once instead of twice per variable.
        application_call = get_application_call(var)
        return bool(application_call) and (
            application_call.application in self.applications)
||
| 170 |