Metrics — Conditions: 30 · Total Lines: 105 · Duplicated Lines: 0 · Duplication Ratio: 0% · Proposed Changes: 5 · Bugs: 0 · Features: 0
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include: Extract Method. If many parameters/temporary variables are present: Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.
Complex methods like optimize_updates() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | #!/usr/bin/env python |
||
def optimize_updates(params, gradients, config=None, shapes=None):
    """
    General optimization function for Theano.

    Builds the Theano update pairs that apply the configured optimization
    method to ``params``, after optionally clipping, regularizing, and
    NaN-guarding the gradients.

    Parameters:
        params - list of shared variables (the model parameters)
        gradients - list of gradient expressions, aligned with ``params``
        config - training config; a dict is wrapped into TrainerConfig
        shapes - unused; kept for backward compatibility with callers
    Returns:
        (updates, free_parameters) - the Theano update list, plus any extra
        shared variables created by the optimizer core (empty unless
        ``config.record_free_params`` is set)
    :type config: deepy.TrainerConfig or dict
    """
    if config and isinstance(config, dict):
        config = TrainerConfig(config)

    # Clipping: rescale all gradients so their joint L2 norm stays below
    # the configured threshold.
    if config:
        clip_value = config.get("gradient_clipping", None)

        if clip_value:
            clip_constant = T.constant(clip_value, dtype=FLOATX)

            if config.avoid_compute_embed_norm:
                # Exclude embedding weights from the norm — presumably they
                # dominate the norm; TODO confirm against training configs.
                grad_norm = multiple_l2_norm([t[1] for t in zip(params, gradients) if not t[0].name.startswith("W_embed")])
            else:
                grad_norm = multiple_l2_norm(gradients)
            isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
            # Multiplier is 1 when the norm is within bounds, otherwise the
            # factor that scales the norm back to clip_value.
            multiplier = ifelse(grad_norm < clip_constant,
                                T.constant(1., dtype=FLOATX), clip_constant / (grad_norm + EPSILON))

            # Clip
            clipped_gradients = []
            for param, g in zip(params, gradients):
                g = multiplier * g
                if config.avoid_nan:
                    # On NaN/Inf norm, replace the gradient with a small
                    # pull toward zero instead of propagating NaN.
                    g = T.switch(isnan, np.float32(0.1) * param, g)
                if config.gradient_tolerance:
                    # Drop the whole update when the norm explodes past the
                    # tolerance (EPSILON keeps the expression non-constant).
                    g = ifelse(grad_norm > config.gradient_tolerance, T.zeros_like(g) + EPSILON, g)
                clipped_gradients.append(g)

            gradients = clipped_gradients
    # Regularization: add the derivative of the L2 penalty, 2 * l2 * w.
    if config and config.weight_l2:
        regularized_gradients = []
        for param, grad in zip(params, gradients):
            grad = grad + (2 * config.weight_l2 * param)
            regularized_gradients.append(grad)
        gradients = regularized_gradients

    # Avoid nan but not computing the norm
    # This is not recommended
    if config and config.avoid_nan and not config.gradient_clipping:
        logging.info("avoid NaN gradients")
        new_gradients = []
        for grad in gradients:
            new_grad = ifelse(T.isnan(grad).any(), T.zeros_like(grad) + EPSILON, grad)
            new_gradients.append(new_grad)
        gradients = new_gradients

    # Find method
    method = "SGD"
    if config:
        method = config.get("method", method).upper()
    # Get Function
    func = None
    if method in ["SGD", "ADAGRAD", "ADADELTA", "FINETUNING_ADAGRAD"]:
        from cores.ada_family import ada_family_core
        func = ada_family_core
    elif method == "ADAM":
        from cores.adam import adam_core
        func = adam_core
    elif method == "RMSPROP":
        from cores.rmsprop import rmsprop_core
        func = rmsprop_core
    elif method == "MOMENTUM":
        from cores.momentum import momentum_core
        func = momentum_core

    if not func:
        raise NotImplementedError("method '%s' is not supported" % method)

    logging.info("optimize method=%s parameters=%s" % (method, str(params)))

    free_parameters = []
    return_vals = wrap_core(func, config, params, gradients)
    # Cores either return just the updates, or [updates, free_parameters].
    # Use isinstance (not type ==) so subclasses of list are accepted.
    if isinstance(return_vals, list) and isinstance(return_vals[0], list):
        updates, free_parameters = return_vals
    else:
        updates = return_vals

    # No free param recording
    if config and not config.record_free_params:
        free_parameters = []

    # Weight bound: hard-clamp updated values into [-bound, bound].
    # BUGFIX: guard on `config` first — every other config access in this
    # function does, and `config` defaults to None.
    if config and config.weight_bound:
        logging.info("apply weight bound of %.2f" % config.weight_bound)
        new_updates = []
        for param, update_value in updates:
            bounded_value = (update_value * (T.abs_(update_value) <= config.weight_bound) +
                             config.weight_bound * (update_value > config.weight_bound) +
                             -config.weight_bound * (update_value < -config.weight_bound))
            new_updates.append((param, bounded_value))
        updates = new_updates
    return updates, free_parameters
||
124 | |||
137 |