| Metric | Value |
| --- | --- |
| Conditions | 30 |
| Total Lines | 105 |
| Duplicated Lines | 0 |
| Duplication Ratio | 0 % |
| Changes | 5 |
| Bugs | 0 |
| Features | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Conversely, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment as a starting point for its name.
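A minimal sketch of that move, using hypothetical names rather than code from the listing below:

```python
class OrderProcessor:
    # Before: a comment labels an anonymous block inside a longer method.
    def process_before(self, order):
        # apply seasonal discount
        if order.date.month in (11, 12):
            order.total *= 0.9

    # After: the block is extracted, and the comment becomes the method name.
    def process_after(self, order):
        self.apply_seasonal_discount(order)

    def apply_seasonal_discount(self, order):
        if order.date.month in (11, 12):
            order.total *= 0.9
```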
Commonly applied refactorings include Extract Method.

If many parameters/temporary variables are present and interfere with extraction, consider Replace Temp with Query, Introduce Parameter Object, or Preserve Whole Object; if none of those help, Replace Method with Method Object moves the whole method into its own class.
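For instance, a minimal sketch of Introduce Parameter Object, with invented names (nothing here comes from this codebase): values that always travel together become one object, shortening every signature they pass through.

```python
from dataclasses import dataclass

@dataclass
class ClipSettings:
    """Hypothetical parameter object bundling values that were passed around together."""
    clip_value: float
    tolerance: float
    avoid_nan: bool = False

# Before: def clip(gradients, clip_value, tolerance, avoid_nan): ...
# After: a single argument carries the whole configuration.
def clip(gradients, settings):
    return [max(-settings.clip_value, min(g, settings.clip_value)) for g in gradients]
```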
Complex methods like optimize_updates() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods (or, in a long method, statement blocks) that share the same prefixes or suffixes.

Once you have determined the pieces that belong together, you can apply the Extract Class refactoring (a sketch applying it to optimize_updates() follows the listing below). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster to apply.
```python
#!/usr/bin/env python

# ... (imports and module-level setup elided in the original listing) ...

def optimize_updates(params, gradients, config=None, shapes=None):
    """
    General optimization function for Theano.
    Parameters:
        params - parameters
        gradients - gradients
        config - training config
    Returns:
        Theano updates
    :type config: deepy.TrainerConfig or dict
    """
    if config and isinstance(config, dict):
        config = TrainerConfig(config)

    # Clipping
    if config:
        clip_value = config.get("gradient_clipping", None)

        if clip_value:
            clip_constant = T.constant(clip_value, dtype=FLOATX)

            if config.avoid_compute_embed_norm:
                grad_norm = multiple_l2_norm([t[1] for t in zip(params, gradients)
                                              if not t[0].name.startswith("W_embed")])
            else:
                grad_norm = multiple_l2_norm(gradients)
            isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
            multiplier = ifelse(grad_norm < clip_constant,
                                T.constant(1., dtype=FLOATX), clip_constant / (grad_norm + EPSILON))

            # Clip
            clipped_gradients = []
            for param, g in zip(params, gradients):
                g = multiplier * g
                if config.avoid_nan:
                    g = T.switch(isnan, np.float32(0.1) * param, g)
                if config.gradient_tolerance:
                    g = ifelse(grad_norm > config.gradient_tolerance, T.zeros_like(g) + EPSILON, g)
                clipped_gradients.append(g)

            gradients = clipped_gradients

    # Regularization
    if config and config.weight_l2:
        regularized_gradients = []
        for param, grad in zip(params, gradients):
            grad = grad + (2 * config.weight_l2 * param)
            regularized_gradients.append(grad)
        gradients = regularized_gradients

    # Avoid NaN without computing the norm
    # (this is not recommended)
    if config and config.avoid_nan and not config.gradient_clipping:
        logging.info("avoid NaN gradients")
        new_gradients = []
        for grad in gradients:
            new_grad = ifelse(T.isnan(grad).any(), T.zeros_like(grad) + EPSILON, grad)
            new_gradients.append(new_grad)
        gradients = new_gradients

    # Find method
    method = "SGD"
    if config:
        method = config.get("method", method).upper()
    # Get function
    func = None
    if method in ["SGD", "ADAGRAD", "ADADELTA", "FINETUNING_ADAGRAD"]:
        from cores.ada_family import ada_family_core
        func = ada_family_core
    elif method == "ADAM":
        from cores.adam import adam_core
        func = adam_core
    elif method == "RMSPROP":
        from cores.rmsprop import rmsprop_core
        func = rmsprop_core
    elif method == "MOMENTUM":
        from cores.momentum import momentum_core
        func = momentum_core

    if not func:
        raise NotImplementedError("method '%s' is not supported" % method)

    logging.info("optimize method=%s parameters=%s" % (method, str(params)))

    free_parameters = []
    return_vals = wrap_core(func, config, params, gradients)
    if type(return_vals) == list and type(return_vals[0]) == list:
        updates, free_parameters = return_vals
    else:
        updates = return_vals

    # No free param recording
    if config and not config.record_free_params:
        free_parameters = []

    # Weight bound
    if config and config.weight_bound:
        logging.info("apply weight bound of %.2f" % config.weight_bound)
        new_updates = []
        for param, update_value in updates:
            bounded_value = (update_value * (T.abs_(update_value) <= config.weight_bound) +
                             config.weight_bound * (update_value > config.weight_bound) +
                             -config.weight_bound * (update_value < -config.weight_bound))
            new_updates.append((param, bounded_value))
        updates = new_updates
    return updates, free_parameters
```
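Applying the Extract Class advice to the listing above: the gradient-massaging steps (clipping, L2 regularization, NaN avoidance) read like one cohesive component. A minimal sketch of extracting them follows; the name GradientProcessor and the method boundaries are our invention, not part of deepy.

```python
class GradientProcessor(object):
    """Hypothetical extracted class owning every step that rewrites gradients."""

    def __init__(self, config):
        self.config = config

    def process(self, params, gradients):
        # Each stage below would absorb the corresponding commented
        # block from optimize_updates().
        gradients = self._clip(params, gradients)        # the "Clipping" block
        gradients = self._regularize(params, gradients)  # the "Regularization" block
        gradients = self._avoid_nan(gradients)           # the "Avoid NaN" block
        return gradients

    def _clip(self, params, gradients):
        ...

    def _regularize(self, params, gradients):
        ...

    def _avoid_nan(self, gradients):
        ...
```

optimize_updates() would then shrink to building a GradientProcessor, dispatching on method, and applying the weight bound, and each of those remaining blocks can be extracted the same way.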