optimize_updates()   F

Complexity
    Conditions: 30

Size
    Total Lines: 105

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 5
    Bugs: 0
    Features: 0

Metric   Value
------   -----
cc       30
dl       0
loc      105
rs       2
c        5
b        0
f        0

How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. Moreover, if a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include Extract Method, as sketched below.
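As a concrete illustration, the gradient-clipping block inside optimize_updates() (see the listing further down) could be pulled out with Extract Method. This is only a sketch built from that listing; the helper name _clip_gradients is hypothetical and not part of deepy.

# Sketch only: _clip_gradients is a hypothetical helper, not part of deepy.
import numpy as np
import theano.tensor as T
from theano.ifelse import ifelse

from deepy.core.env import FLOATX, EPSILON
from deepy.trainers.util import multiple_l2_norm

def _clip_gradients(params, gradients, config):
    """Rescale gradients so their joint L2 norm stays within config.gradient_clipping."""
    clip_value = config.get("gradient_clipping", None)
    if not clip_value:
        return gradients
    clip_constant = T.constant(clip_value, dtype=FLOATX)
    grad_norm = multiple_l2_norm(gradients)
    isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
    multiplier = ifelse(grad_norm < clip_constant,
                        T.constant(1., dtype=FLOATX),
                        clip_constant / (grad_norm + EPSILON))
    clipped = []
    for param, g in zip(params, gradients):
        g = multiplier * g
        if config.avoid_nan:
            # Fall back to a small multiple of the parameter when the norm is NaN/Inf.
            g = T.switch(isnan, np.float32(0.1) * param, g)
        clipped.append(g)
    return clipped

The body of optimize_updates() then shrinks to a single call, for example gradients = _clip_gradients(params, gradients, config), and the helper's name takes over the job of the "# Clipping" comment.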

Complexity

Complex code units like optimize_updates() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach is to look for fields and methods (or, in a long function, groups of variables and statements) that share the same prefixes or suffixes.

Once you have determined the pieces that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
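In this file the cohesive component is the chain of gradient post-processing steps (clipping, L2 regularization, NaN handling) at the top of optimize_updates(). The following Extract Class sketch groups them behind one object; the class name GradientPostprocessor and its methods are illustrative only, and _clip delegates to the _clip_gradients sketch above.

# Sketch only: GradientPostprocessor is a hypothetical class, not part of deepy.
class GradientPostprocessor(object):

    def __init__(self, config):
        self.config = config

    def process(self, params, gradients):
        """Apply the configured gradient transformations in order."""
        gradients = self._clip(params, gradients)
        gradients = self._regularize(params, gradients)
        return gradients

    def _clip(self, params, gradients):
        # Delegates to the _clip_gradients sketch above; it returns the
        # gradients unchanged when no clipping threshold is configured.
        return _clip_gradients(params, gradients, self.config)

    def _regularize(self, params, gradients):
        # L2 weight decay: grad + 2 * weight_l2 * param, as in the listing below.
        if not self.config.weight_l2:
            return gradients
        return [g + 2 * self.config.weight_l2 * p for p, g in zip(params, gradients)]

With such a component, optimize_updates() would reduce its first half to something like gradients = GradientPostprocessor(config).process(params, gradients) whenever a config is given, leaving only the method dispatch and update wrapping behind.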

#!/usr/bin/env python
# -*- coding: utf-8 -*-


import logging as loggers

import numpy as np
import theano
import theano.tensor as T
from theano.ifelse import ifelse

from deepy.utils import dim_to_var
from deepy.core.env import FLOATX, EPSILON
from deepy.trainers.util import wrap_core, multiple_l2_norm
from deepy.conf import TrainerConfig

logging = loggers.getLogger(__name__)

def optimize_updates(params, gradients, config=None, shapes=None):
    """
    General optimization function for Theano.
    Parameters:
        params - parameters
        gradients - gradients
        config - training config
    Returns:
        Theano updates
    :type config: deepy.TrainerConfig or dict
    """
    if config and isinstance(config, dict):
        config = TrainerConfig(config)

    # Clipping
    if config:
        clip_value = config.get("gradient_clipping", None)

        if clip_value:
            clip_constant = T.constant(clip_value, dtype=FLOATX)

            if config.avoid_compute_embed_norm:
                grad_norm = multiple_l2_norm([t[1] for t in zip(params, gradients) if not t[0].name.startswith("W_embed")])
            else:
                grad_norm = multiple_l2_norm(gradients)
            isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
            multiplier = ifelse(grad_norm < clip_constant,
                                T.constant(1., dtype=FLOATX), clip_constant / (grad_norm + EPSILON))

            # Clip
            clipped_gradients = []
            for param, g in zip(params, gradients):
                g = multiplier * g
                if config.avoid_nan:
                    g = T.switch(isnan, np.float32(0.1) * param, g)
                if config.gradient_tolerance:
                    g = ifelse(grad_norm > config.gradient_tolerance, T.zeros_like(g) + EPSILON, g)
                clipped_gradients.append(g)

            gradients = clipped_gradients
    # Regularization
    if config and config.weight_l2:
        regularized_gradients = []
        for param, grad in zip(params, gradients):
            grad = grad + (2 * config.weight_l2 * param)
            regularized_gradients.append(grad)
        gradients = regularized_gradients

    # Avoid NaN without computing the norm
    # This is not recommended
    if config and config.avoid_nan and not config.gradient_clipping:
        logging.info("avoid NaN gradients")
        new_gradients = []
        for grad in gradients:
            new_grad = ifelse(T.isnan(grad).any(), T.zeros_like(grad) + EPSILON, grad)
            new_gradients.append(new_grad)
        gradients = new_gradients


    # Find method
    method = "SGD"
    if config:
        method = config.get("method", method).upper()
    # Get function
    func = None
    if method in ["SGD", "ADAGRAD", "ADADELTA", "FINETUNING_ADAGRAD"]:
        from cores.ada_family import ada_family_core
        func = ada_family_core
    elif method == "ADAM":
        from cores.adam import adam_core
        func = adam_core
    elif method == "RMSPROP":
        from cores.rmsprop import rmsprop_core
        func = rmsprop_core
    elif method == "MOMENTUM":
        from cores.momentum import momentum_core
        func = momentum_core

    if not func:
        raise NotImplementedError("method '%s' is not supported" % method)

    logging.info("optimize method=%s parameters=%s" % (method, str(params)))

    free_parameters = []
    return_vals = wrap_core(func, config, params, gradients)
    if isinstance(return_vals, list) and isinstance(return_vals[0], list):
        updates, free_parameters = return_vals
    else:
        updates = return_vals

    # No free param recording
    if config and not config.record_free_params:
        free_parameters = []

    # Weight bound
    if config and config.weight_bound:
        logging.info("apply weight bound of %.2f" % config.weight_bound)
        new_updates = []
        for param, update_value in updates:
            bounded_value = (update_value * (T.abs_(update_value) <= config.weight_bound) +
                             config.weight_bound * (update_value > config.weight_bound) +
                             -config.weight_bound * (update_value < -config.weight_bound))
            new_updates.append((param, bounded_value))
        updates = new_updates
    return updates, free_parameters

def optimize_function(params, config=None):
    """
    Create an optimizing function that receives gradients.
    Parameters:
        params - parameters
        config - training configuration
    Returns:
        updating function that receives gradients
    """
    gs = [dim_to_var(p.ndim) for p in params]
    updates, _ = optimize_updates(params, gs, config)
    return theano.function(gs, [], updates=updates)
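For reference, a hedged usage sketch of the two functions above. The parameter shapes and config values are made up for illustration, the functions are assumed to be in scope (imported from this module), and it is assumed that TrainerConfig supplies defaults for the remaining options that optimize_updates() reads (avoid_nan, weight_l2, weight_bound, and so on).

# Usage sketch only; shapes and config values are illustrative, and
# optimize_function is assumed to be importable from this module.
import numpy as np
import theano

# Two toy shared parameters standing in for a model's weights.
W = theano.shared(np.zeros((10, 5), dtype="float32"), name="W")
b = theano.shared(np.zeros(5, dtype="float32"), name="b")

# Methods dispatched by optimize_updates(): SGD, ADAGRAD, ADADELTA,
# FINETUNING_ADAGRAD, ADAM, RMSPROP, MOMENTUM.
update_fn = optimize_function([W, b], {"method": "ADAM", "gradient_clipping": 3})

# The compiled function takes one gradient array per parameter (matching
# shapes) and applies the corresponding updates to the shared variables.
grad_W = np.random.randn(10, 5).astype("float32")
grad_b = np.random.randn(5).astype("float32")
update_fn(grad_W, grad_b)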