| Total Complexity | 65 |
| Total Lines | 282 |
| Duplicated Lines | 0 % |
Complex classes like blocks.graph.ComputationGraph often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | """Annotated computation graph management.""" |
||
class ComputationGraph(object):
    r"""Encapsulates a managed Theano computation graph.

    This implies that it not only contains the variables required to
    compute the given outputs, but also all the auxiliary variables and
    updates that were attached to these variables through the annotation
    system.

    All variables are presented in topologically sorted order according to
    the apply nodes that they are an input to.

    Parameters
    ----------
    outputs : (list of) :class:`~tensor.TensorVariable`
        The output(s) of the computation graph.

    Attributes
    ----------
    inputs : list of :class:`~tensor.TensorVariable`
        The inputs of the computation graph. This does not include shared
        variables and constants.
    shared_variables : list of :class:`~tensor.TensorSharedVariable`
        All the shared variables in the graph.
    parameters : list of :class:`~tensor.TensorSharedVariable`
        All the shared variables which have the :const:`.PARAMETER` role.
    outputs : list of :class:`~tensor.TensorVariable`
        The outputs of the computation graph (as passed to the
        constructor).
    auxiliary_variables : list of :class:`~tensor.TensorVariable`
        All variables which have the :const:`.AUXILIARY` role.
    intermediary_variables : list of :class:`~tensor.TensorVariable`
        Any variable that is not part of :attr:`inputs` or :attr:`outputs`.
    variables : list of :class:`~tensor.TensorVariable`
        All variables (including auxiliary) in the managed graph.
    scans : list of :class:`~theano.scan_module.scan_op.Scan`
        All Scan ops used in this computation graph. Empty when the graph
        has no non-shared outputs.
    scan_variables : list of :class:`~tensor.TensorVariable`
        All variables of the inner graphs of Scan ops.
    updates : :class:`~tensor.TensorSharedVariable` updates
        All the updates found attached to the annotations.

    """
    def __init__(self, outputs):
        # A single variable is accepted as shorthand for a one-item list.
        if isinstance(outputs, Variable):
            outputs = [outputs]
        self.outputs = outputs
        self._get_variables()
        # Memoization table used by :meth:`has_inputs`.
        self._has_inputs = {}

    def __iter__(self):
        """Iterate over :attr:`variables` in their stored order."""
        return iter(self.variables)

    @property
    def inputs(self):
        """Inputs to the graph, excluding constants and shared variables."""
        return [var for var in self.variables if is_graph_input(var)]

    @property
    def intermediary_variables(self):
        """Variables that are neither graph inputs nor graph outputs."""
        # `inputs` is a property that rescans every variable on each
        # access, so evaluate it once instead of once per candidate.
        # `_get_variables` already stores these objects in sets, so they
        # are hashable.
        inputs = set(self.inputs)
        outputs = set(self.outputs)
        return [var for var in self.variables
                if var not in inputs and var not in outputs]

    @property
    def shared_variables(self):
        """All variables of the graph that are shared variables."""
        return [var for var in self.variables if is_shared_variable(var)]

    @property
    def parameters(self):
        """Shared variables that carry the ``PARAMETER`` role."""
        return [var for var in self.shared_variables
                if has_roles(var, [PARAMETER])]

    @property
    def auxiliary_variables(self):
        """Variables that carry the ``AUXILIARY`` role."""
        return [var for var in self.variables if has_roles(var, [AUXILIARY])]

    @property
    def scan_variables(self):
        """Variables of Scan ops."""
        return list(chain(*[g.variables for g in self._scan_graphs]))

    def _get_variables(self):
        """Collect variables, updates and auxiliary variables.

        In addition collects all :class:`.Scan` ops and recurses in the
        respective inner Theano graphs.

        Sets :attr:`variables`, :attr:`updates`, :attr:`scans` and the
        private ``_scan_graphs`` list on the instance.

        """
        updates = OrderedDict()

        shared_outputs = [o for o in self.outputs if is_shared_variable(o)]
        usual_outputs = [o for o in self.outputs if not is_shared_variable(o)]
        variables = shared_outputs

        # Define the Scan bookkeeping unconditionally so that `scans` and
        # `scan_variables` also work for graphs without non-shared outputs.
        self.scans = []
        self._scan_graphs = []

        if usual_outputs:
            # Sort apply nodes topologically, get variables and remove
            # duplicates
            inputs = graph.inputs(self.outputs)
            sorted_apply_nodes = graph.io_toposort(inputs, usual_outputs)
            self.scans = list(unique([node.op for node in sorted_apply_nodes
                                      if isinstance(node.op, Scan)]))
            # One nested graph per Scan op, built from the op's outputs.
            self._scan_graphs = [ComputationGraph(scan.outputs)
                                 for scan in self.scans]

            # `seen.add` returns None, so the `or` both tests membership
            # and records the variable, keeping first occurrences only.
            seen = set()
            main_vars = (
                [var for var in list(chain(
                    *[apply_node.inputs for apply_node in sorted_apply_nodes]))
                 if not (var in seen or seen.add(var))] +
                [var for var in self.outputs if var not in seen])

            # While preserving order add auxiliary variables, and collect
            # updates
            seen = set()
            # Intermediate variables could be auxiliary
            seen_avs = set(main_vars)
            variables = []
            for var in main_vars:
                variables.append(var)
                for annotation in getattr(var.tag, 'annotations', []):
                    if annotation not in seen:
                        seen.add(annotation)
                        new_avs = [
                            av for av in annotation.auxiliary_variables
                            if not (av in seen_avs or seen_avs.add(av))]
                        variables.extend(new_avs)
                        updates = dict_union(updates, annotation.updates)

        self.variables = variables
        self.updates = updates

    def dict_of_inputs(self):
        """Return a mapping from an input name to the input."""
        return {var.name: var for var in self.inputs}

    def replace(self, replacements):
        """Replace certain variables in the computation graph.

        Parameters
        ----------
        replacements : dict
            The mapping from variables to be replaced to the corresponding
            substitutes.

        Returns
        -------
        :class:`ComputationGraph`
            A new graph built from the outputs with the replacements
            applied.

        Examples
        --------
        >>> import theano
        >>> from theano import tensor, function
        >>> x = tensor.scalar('x')
        >>> y = x + 2
        >>> z = y + 3
        >>> a = z + 5

        Let's suppose we have dependent replacements like

        >>> replacements = {y: x * 2, z: y * 3}
        >>> cg = ComputationGraph([a])
        >>> theano.pprint(a)  # doctest: +NORMALIZE_WHITESPACE
        '(((x + TensorConstant{2}) + TensorConstant{3}) +
        TensorConstant{5})'
        >>> cg_new = cg.replace(replacements)
        >>> theano.pprint(
        ...     cg_new.outputs[0])  # doctest: +NORMALIZE_WHITESPACE
        '(((x * TensorConstant{2}) * TensorConstant{3}) +
        TensorConstant{5})'

        First two sums turned into multiplications

        >>> float(function(cg_new.inputs, cg_new.outputs)(3.)[0])
        23.0

        """
        # Due to theano specifics we have to make one replacement at a time
        replacements = OrderedDict(replacements)

        outputs_cur = self.outputs

        # `replacements` with previous replacements applied. We have to track
        # variables in the new graph corresponding to original replacements.
        replacement_keys_cur = []
        replacement_vals_cur = []
        # Sort `replacements` in topological order
        # variables in self.variables are in topological order
        remaining_replacements = replacements.copy()
        for variable in self.variables:
            if variable in replacements:
                if has_roles(variable, [AUXILIARY]):
                    warnings.warn(
                        "replace method was asked to replace a variable ({}) "
                        "that is an auxiliary variable.".format(variable))
                replacement_keys_cur.append(variable)
                # self.variables should not contain duplicates,
                # otherwise pop() may fail.
                replacement_vals_cur.append(
                    remaining_replacements.pop(variable))

        # Whatever is left was never encountered in self.variables
        if remaining_replacements:
            warnings.warn(
                "replace method was asked to replace a variable(s) ({}) "
                "that is not a part of the computational "
                "graph.".format(str(remaining_replacements.keys())))

        # Replace step-by-step in topological order
        while replacement_keys_cur:
            replace_what = replacement_keys_cur[0]
            replace_by = replacement_vals_cur[0]
            # We also want to make changes in future replacements, so the
            # pending keys and values are cloned together with the outputs.
            outputs_new = theano.clone(
                outputs_cur + replacement_keys_cur[1:] +
                replacement_vals_cur[1:],
                replace={replace_what: replace_by})
            # Reconstruct outputs, keys, and values. The statement order
            # matters: the values slice relies on `replacement_keys_cur`
            # already being one element shorter.
            outputs_cur = outputs_new[:len(outputs_cur)]
            replacement_keys_cur = outputs_new[len(outputs_cur):
                                               len(outputs_cur) +
                                               len(replacement_keys_cur) - 1]
            replacement_vals_cur = outputs_new[len(outputs_cur) +
                                               len(replacement_keys_cur):]

        return ComputationGraph(outputs_cur)

    def get_theano_function(self, additional_updates=None, **kwargs):
        r"""Create Theano function from the graph contained.

        Parameters
        ----------
        additional_updates : optional
            Extra updates, in any form accepted by ``OrderedDict``, merged
            with :attr:`updates` through ``dict_union``. Ignored when
            falsy.
        \*\*kwargs : dict
            Keyword arguments to theano.function.
            Useful for specifying compilation modes or profiling.

        Returns
        -------
        The result of ``theano.function`` called with :attr:`inputs`,
        :attr:`outputs` and the merged updates.

        """
        updates = self.updates
        if additional_updates:
            updates = dict_union(updates, OrderedDict(additional_updates))
        return theano.function(self.inputs, self.outputs, updates=updates,
                               **kwargs)

    def get_snapshot(self, data):
        """Evaluate all role-carrying Theano variables on given data.

        Parameters
        ----------
        data : dict of (data source, data) pairs
            Data for input variables. The sources should match with the
            names of the input variables.

        Returns
        -------
        Dictionary of (variable, variable value on given data) pairs.

        """
        role_variables = [var for var in self.variables
                          if hasattr(var.tag, "roles") and
                          not is_shared_variable(var)]
        # One holder per role variable; the compiled function writes each
        # variable's value into its holder through the update mechanism.
        value_holders = [shared_like(var) for var in role_variables]
        function = self.get_theano_function(equizip(value_holders,
                                                    role_variables))
        # Run once for its side effects; the return value is not needed.
        function(*(data[input_.name] for input_ in self.inputs))
        return OrderedDict([(var, value_holder.get_value(borrow=True))
                            for var, value_holder in equizip(role_variables,
                                                             value_holders)])

    def has_inputs(self, variable):
        """Check if a variable depends on input variables.

        Parameters
        ----------
        variable
            The variable to examine. Results are memoized per variable.

        Returns
        -------
        bool
            ``True`` if the given variable depends on input variables,
            ``False`` otherwise.

        """
        if variable not in self._has_inputs:
            self._has_inputs[variable] = False
            if is_graph_input(variable):
                self._has_inputs[variable] = True
            elif getattr(variable, 'owner', None):
                for dependency in variable.owner.inputs:
                    if self.has_inputs(dependency):
                        self._has_inputs[variable] = True
                        # One dependency on an input settles the answer.
                        break
        return self._has_inputs[variable]
||
| 309 | |||
| 541 |