| Metric | Value |
| --- | --- |
| Conditions | 30 |
| Total Lines | 105 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 5 |
| Bugs | 0 |
| Features | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Moreover, if a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name, as in the sketch below.
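For instance, here is a minimal, self-contained sketch of that move (plain NumPy rather than the Theano code listed below; `clip_gradients` and `build_updates` are hypothetical names, not part of the analyzed module):

```python
import numpy as np

def clip_gradients(gradients, clip_value, epsilon=1e-8):
    """Scale gradients so their global L2 norm does not exceed clip_value.

    The method name comes straight from the '# Clipping' comment it replaces.
    """
    if not clip_value:
        return gradients
    norm = np.sqrt(sum(np.sum(np.square(g)) for g in gradients))
    if norm <= clip_value:
        return gradients
    scale = clip_value / (norm + epsilon)
    return [g * scale for g in gradients]

def build_updates(params, gradients, clip_value=None):
    # The commented block has become a single, well-named call.
    gradients = clip_gradients(gradients, clip_value)
    return list(zip(params, gradients))
```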
Commonly applied refactorings include:

- Extract Method: move a coherent fragment of the long method into its own, well-named method.
- If many parameters/temporary variables are present and get in the way of extraction: Replace Temp with Query, Introduce Parameter Object, or Replace Method with Method Object (a sketch of the parameter-object idea follows this list).
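A minimal sketch of Introduce Parameter Object, assuming a hypothetical `GradientOptions` bundle (the field names mirror the config keys used in the listing below):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GradientOptions:
    """Bundles the settings that always travel together through the gradient code."""
    gradient_clipping: Optional[float] = None
    gradient_tolerance: Optional[float] = None
    avoid_nan: bool = False
    weight_l2: Optional[float] = None
    weight_bound: Optional[float] = None

def postprocess_gradients(params, gradients, opts: GradientOptions):
    # One parameter object instead of five loose values keeps the signature
    # of every extracted helper short and stable.
    ...
```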
Complex classes, and long functions like optimize_updates(), often do many different things at once. To break such a unit down, we need to identify a cohesive component within it. A common way to find one is to look for fields and methods that share the same prefix or suffix.

Once you have determined which fields belong together, apply the Extract Class refactoring; a sketch follows below. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often quicker to carry out.
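As a rough sketch of that direction (all names here are hypothetical, not part of deepy), the gradient post-processing stages of optimize_updates() could be collected into one cohesive component:

```python
class GradientPostprocessor:
    """Cohesive component owning the gradient-related stages of optimize_updates()."""

    def __init__(self, config):
        self.config = config

    def process(self, params, gradients):
        # Each stage mirrors one commented block of the original function.
        gradients = self.clip(params, gradients)        # "# Clipping"
        gradients = self.regularize(params, gradients)  # "# Regularization"
        return gradients

    def clip(self, params, gradients):
        ...  # placeholder: body would come from the Clipping block below

    def regularize(self, params, gradients):
        ...  # placeholder: body would come from the Regularization block below
```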
```python
#!/usr/bin/env python

# ... lines 2-18 of the file (imports and module-level constants) are omitted in the original listing ...

def optimize_updates(params, gradients, config=None, shapes=None):
    """
    General optimization function for Theano.
    Parameters:
        params - parameters
        gradients - gradients
        config - training config
    Returns:
        Theano updates
    :type config: deepy.TrainerConfig or dict
    """
    if config and isinstance(config, dict):
        config = TrainerConfig(config)

    # Clipping
    if config:
        clip_value = config.get("gradient_clipping", None)

        if clip_value:
            clip_constant = T.constant(clip_value, dtype=FLOATX)

            if config.avoid_compute_embed_norm:
                grad_norm = multiple_l2_norm([t[1] for t in zip(params, gradients) if not t[0].name.startswith("W_embed")])
            else:
                grad_norm = multiple_l2_norm(gradients)
            isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
            multiplier = ifelse(grad_norm < clip_constant,
                                T.constant(1., dtype=FLOATX), clip_constant / (grad_norm + EPSILON))

            # Clip
            clipped_gradients = []
            for param, g in zip(params, gradients):
                g = multiplier * g
                if config.avoid_nan:
                    g = T.switch(isnan, np.float32(0.1) * param, g)
                if config.gradient_tolerance:
                    g = ifelse(grad_norm > config.gradient_tolerance, T.zeros_like(g) + EPSILON, g)
                clipped_gradients.append(g)

            gradients = clipped_gradients
    # Regularization
    if config and config.weight_l2:
        regularized_gradients = []
        for param, grad in zip(params, gradients):
            grad = grad + (2 * config.weight_l2 * param)
            regularized_gradients.append(grad)
        gradients = regularized_gradients

    # Avoid nan but not computing the norm
    # This is not recommended
    if config and config.avoid_nan and not config.gradient_clipping:
        logging.info("avoid NaN gradients")
        new_gradients = []
        for grad in gradients:
            new_grad = ifelse(T.isnan(grad).any(), T.zeros_like(grad) + EPSILON, grad)
            new_gradients.append(new_grad)
        gradients = new_gradients


    # Find method
    method = "SGD"
    if config:
        method = config.get("method", method).upper()
    # Get Function
    func = None
    if method in ["SGD", "ADAGRAD", "ADADELTA", "FINETUNING_ADAGRAD"]:
        from cores.ada_family import ada_family_core
        func = ada_family_core
    elif method == "ADAM":
        from cores.adam import adam_core
        func = adam_core
    elif method == "RMSPROP":
        from cores.rmsprop import rmsprop_core
        func = rmsprop_core
    elif method == "MOMENTUM":
        from cores.momentum import momentum_core
        func = momentum_core

    if not func:
        raise NotImplementedError("method '%s' is not supported" % method)

    logging.info("optimize method=%s parameters=%s" % (method, str(params)))

    free_parameters = []
    return_vals = wrap_core(func, config, params, gradients)
    if type(return_vals) == list and type(return_vals[0]) == list:
        updates, free_parameters = return_vals
    else:
        updates = return_vals

    # No free param recording
    if config and not config.record_free_params:
        free_parameters = []

    # Weight bound
    if config.weight_bound:
        logging.info("apply weight bound of %.2f" % config.weight_bound)
        new_updates = []
        for param, update_value in updates:
            bounded_value = (update_value * (T.abs_(update_value) <= config.weight_bound) +
                             config.weight_bound * (update_value > config.weight_bound) +
                             -config.weight_bound * (update_value < -config.weight_bound))
            new_updates.append((param, bounded_value))
        updates = new_updates
    return updates, free_parameters
```
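To make the extraction advice concrete on this code: the "Find method" / "Get Function" chain near the end of optimize_updates() could be replaced by a small lookup table, removing several conditions at once. This is a hypothetical sketch, not deepy's actual API; only the module paths mirror the imports in the listing above:

```python
import importlib

# Hypothetical registry replacing the if/elif dispatch.
_CORE_LOADERS = {
    "SGD": "cores.ada_family:ada_family_core",
    "ADAGRAD": "cores.ada_family:ada_family_core",
    "ADADELTA": "cores.ada_family:ada_family_core",
    "FINETUNING_ADAGRAD": "cores.ada_family:ada_family_core",
    "ADAM": "cores.adam:adam_core",
    "RMSPROP": "cores.rmsprop:rmsprop_core",
    "MOMENTUM": "cores.momentum:momentum_core",
}

def load_core(method):
    """Return the optimizer core for `method`, mirroring the original error handling."""
    try:
        module_name, attr = _CORE_LOADERS[method.upper()].split(":")
    except KeyError:
        raise NotImplementedError("method '%s' is not supported" % method)
    return getattr(importlib.import_module(module_name), attr)
```

Loading the module lazily via importlib keeps the behaviour of the original inline imports: an optimizer backend is only imported when it is actually requested.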