| Metric | Value |
| --- | --- |
| Conditions | 30 |
| Total Lines | 105 |
| Lines | 0 |
| Ratio | 0 % |
Small methods make your code easier to understand, especially when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.
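A minimal, hypothetical sketch of that idea (none of the names below come from deepy): the comment that labelled a block of code becomes the name of the extracted method.

```python
def average_scaled_before(values):
    # normalize values to the 0-1 range
    lo, hi = min(values), max(values)
    scaled = [(v - lo) / (hi - lo) for v in values]
    return sum(scaled) / len(scaled)


def normalize_to_unit_range(values):
    """Scale values linearly so the smallest becomes 0 and the largest 1."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo) for v in values]


def average_scaled_after(values):
    # The call site now reads like the comment used to.
    scaled = normalize_to_unit_range(values)
    return sum(scaled) / len(scaled)


print(average_scaled_after([2.0, 4.0, 6.0]))  # 0.5
```

The extracted helper can also be tested and reused on its own, which the commented block could not.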
Commonly applied refactorings include:

* Extract Method

If many parameters or temporary variables are present:

* Replace Temp with Query
* Introduce Parameter Object
* Preserve Whole Object
Complex units like deepy.trainers.optimize_updates() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach is to look for fields/methods (or, in a long function like this one, related parameters and local variables) that share the same prefixes or suffixes.
Once you have determined the members that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often quicker to apply.
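As an illustration only (GradientSettings and TrainerSettings are made-up names, not part of deepy's TrainerConfig), the gradient-related options used by optimize_updates() below share a common prefix and concern, so they could be grouped with Extract Class:

```python
class GradientSettings(object):
    """Groups the gradient-related options that always travel together."""
    def __init__(self, clipping=None, tolerance=None, avoid_nan=False):
        self.clipping = clipping
        self.tolerance = tolerance
        self.avoid_nan = avoid_nan


class TrainerSettings(object):
    """The remaining trainer options keep a reference to the extracted part."""
    def __init__(self, method="SGD", weight_l2=0.0, gradient=None):
        self.method = method
        self.weight_l2 = weight_l2
        self.gradient = gradient if gradient is not None else GradientSettings()


settings = TrainerSettings(method="ADAM", gradient=GradientSettings(clipping=5.0))
print(settings.gradient.clipping)  # 5.0
```

Code that only cares about clipping can then receive a GradientSettings object instead of the whole config, which also shortens the parameter lists of any extracted methods.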
```python
#!/usr/bin/env python
# ... (preceding module lines omitted in this excerpt) ...

def optimize_updates(params, gradients, config=None, shapes=None):
    """
    General optimization function for Theano.
    Parameters:
        params - parameters
        gradients - gradients
        config - training config
    Returns:
        Theano updates
    :type config: deepy.TrainerConfig or dict
    """
    if config and isinstance(config, dict):
        config = TrainerConfig(config)

    # Clipping
    if config:
        clip_value = config.get("gradient_clipping", None)

        if clip_value:
            clip_constant = T.constant(clip_value, dtype=FLOATX)

            if config.avoid_compute_embed_norm:
                grad_norm = multiple_l2_norm([t[1] for t in zip(params, gradients) if not t[0].name.startswith("W_embed")])
            else:
                grad_norm = multiple_l2_norm(gradients)
            isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
            multiplier = ifelse(grad_norm < clip_constant,
                                T.constant(1., dtype=FLOATX), clip_constant / (grad_norm + EPSILON))

            # Clip
            clipped_gradients = []
            for param, g in zip(params, gradients):
                g = multiplier * g
                if config.avoid_nan:
                    g = T.switch(isnan, np.float32(0.1) * param, g)
                if config.gradient_tolerance:
                    g = ifelse(grad_norm > config.gradient_tolerance, T.zeros_like(g) + EPSILON, g)
                clipped_gradients.append(g)

            gradients = clipped_gradients
    # Regularization
    if config and config.weight_l2:
        regularized_gradients = []
        for param, grad in zip(params, gradients):
            grad = grad + (2 * config.weight_l2 * param)
            regularized_gradients.append(grad)
        gradients = regularized_gradients

    # Avoid nan but not computing the norm
    # This is not recommended
    if config and config.avoid_nan and not config.gradient_clipping:
        logging.info("avoid NaN gradients")
        new_gradients = []
        for grad in gradients:
            new_grad = ifelse(T.isnan(grad).any(), T.zeros_like(grad) + EPSILON, grad)
            new_gradients.append(new_grad)
        gradients = new_gradients

    # Find method
    method = "SGD"
    if config:
        method = config.get("method", method).upper()
    # Get Function
    func = None
    if method in ["SGD", "ADAGRAD", "ADADELTA", "FINETUNING_ADAGRAD"]:
        from cores.ada_family import ada_family_core
        func = ada_family_core
    elif method == "ADAM":
        from cores.adam import adam_core
        func = adam_core
    elif method == "RMSPROP":
        from cores.rmsprop import rmsprop_core
        func = rmsprop_core
    elif method == "MOMENTUM":
        from cores.momentum import momentum_core
        func = momentum_core

    if not func:
        raise NotImplementedError("method '%s' is not supported" % method)

    logging.info("optimize method=%s parameters=%s" % (method, str(params)))

    free_parameters = []
    return_vals = wrap_core(func, config, params, gradients)
    if type(return_vals) == list and type(return_vals[0]) == list:
        updates, free_parameters = return_vals
    else:
        updates = return_vals

    # No free param recording
    if config and not config.record_free_params:
        free_parameters = []

    # Weight bound
    if config.weight_bound:
        logging.info("apply weight bound of %.2f" % config.weight_bound)
        new_updates = []
        for param, update_value in updates:
            bounded_value = (update_value * (T.abs_(update_value) <= config.weight_bound) +
                             config.weight_bound * (update_value > config.weight_bound) +
                             -config.weight_bound * (update_value < -config.weight_bound))
            new_updates.append((param, bounded_value))
        updates = new_updates
    return updates, free_parameters
```
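To make the recommendation concrete, here is a minimal sketch of Extract Method applied to the `# Clipping` block above. It reuses the code exactly as listed and assumes the same module-level names the original relies on (`T`, `ifelse`, `np`, `FLOATX`, `EPSILON`, `multiple_l2_norm`); the helper name `_clip_gradients` is hypothetical and not part of deepy.

```python
def _clip_gradients(params, gradients, config):
    """Rescale gradients so their joint L2 norm stays below the configured
    gradient_clipping value; optionally guard against NaN/Inf norms."""
    clip_value = config.get("gradient_clipping", None)
    if not clip_value:
        return gradients

    clip_constant = T.constant(clip_value, dtype=FLOATX)
    if config.avoid_compute_embed_norm:
        # Skip embedding weights when computing the norm, as in the original.
        grad_norm = multiple_l2_norm(
            [g for p, g in zip(params, gradients) if not p.name.startswith("W_embed")])
    else:
        grad_norm = multiple_l2_norm(gradients)

    isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
    multiplier = ifelse(grad_norm < clip_constant,
                        T.constant(1., dtype=FLOATX),
                        clip_constant / (grad_norm + EPSILON))

    clipped_gradients = []
    for param, g in zip(params, gradients):
        g = multiplier * g
        if config.avoid_nan:
            g = T.switch(isnan, np.float32(0.1) * param, g)
        if config.gradient_tolerance:
            g = ifelse(grad_norm > config.gradient_tolerance,
                       T.zeros_like(g) + EPSILON, g)
        clipped_gradients.append(g)
    return clipped_gradients
```

With this helper, the clipping section of `optimize_updates()` collapses to a single call such as `gradients = _clip_gradients(params, gradients, config)` inside the `if config:` branch, and the remaining blocks (regularization, NaN handling, method selection, weight bound) are candidates for the same treatment.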