Completed
Push — master ( f73e69...91b7c0 ) by Raphael, created 01:35

deepy.trainers.optimize_updates()    Grade: F

Complexity
    Conditions: 30

Size
    Total Lines: 105

Duplication
    Lines: 0
    Ratio: 0 %

Metric                       Value
cc  (cyclomatic complexity)  30
dl  (duplicated lines)       0
loc (lines of code)          105
rs                           2

How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name; and when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment as a starting point for the new method's name.

The most commonly applied refactoring here is Extract Method, sketched below.
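For illustration, here is how that advice could play out on the listing below: the block introduced by the "# Clipping" comment can be lifted into a helper named after the comment. This is a hedged sketch, not deepy code; clip_gradients() is an invented name, its body is the clipping block from the listing moved almost verbatim, and it assumes it is only called when a config object is present.

# Hypothetical Extract Method sketch: clip_gradients() is an invented name;
# the body is the "# Clipping" block from the listing below.
# Assumes config is not None (the caller keeps that check).
import numpy as np
import theano.tensor as T
from theano.ifelse import ifelse

from deepy.utils import FLOATX, EPSILON
from deepy.trainers.util import multiple_l2_norm

def clip_gradients(params, gradients, config):
    """Rescale gradients whose joint L2 norm exceeds config.gradient_clipping."""
    clip_value = config.get("gradient_clipping", None)
    if not clip_value:
        return gradients
    clip_constant = T.constant(clip_value, dtype=FLOATX)
    if config.avoid_compute_embed_norm:
        grad_norm = multiple_l2_norm(
            [g for p, g in zip(params, gradients) if not p.name.startswith("W_embed")])
    else:
        grad_norm = multiple_l2_norm(gradients)
    isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
    multiplier = ifelse(grad_norm < clip_constant,
                        T.constant(1., dtype=FLOATX),
                        clip_constant / (grad_norm + EPSILON))
    clipped_gradients = []
    for param, g in zip(params, gradients):
        g = multiplier * g
        if config.avoid_nan:
            # Fall back to a small fraction of the parameter when the norm is NaN/Inf.
            g = T.switch(isnan, np.float32(0.1) * param, g)
        if config.gradient_tolerance:
            g = ifelse(grad_norm > config.gradient_tolerance, T.zeros_like(g) + EPSILON, g)
        clipped_gradients.append(g)
    return clipped_gradients

Inside optimize_updates(), the whole clipping branch then collapses to gradients = clip_gradients(params, gradients, config), which both shortens the method and removes several levels of nesting.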

Complexity

Complex code like deepy.trainers.optimize_updates() often does a lot of different things. To break it down, we need to identify a cohesive component within it. A common approach is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster. For a free function like this one, a smaller first step is to reduce the number of branches directly; the sketch below shows one way to do that for the optimizer dispatch.
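As a concrete illustration (a hedged sketch, not deepy code): the optimizer-selection if/elif chain near the end of the listing is one of the larger contributors to the condition count, and it could be replaced by a lookup table in a helper of its own. select_optimizer_core() and the table are inventions for this sketch; the module and function names are the ones the listing already imports.

# Hypothetical helper replacing the "# Find method" / "# Get Function"
# if/elif chain in the listing below with a lookup table.
# select_optimizer_core() is an invented name; the imported core
# functions are exactly the ones the listing uses.
def select_optimizer_core(method):
    """Return the update-core function for an optimizer name (already upper-cased)."""
    from cores.ada_family import ada_family_core
    from cores.adam import adam_core
    from cores.momentum import momentum_core
    from cores.rmsprop import rmsprop_core

    cores = {
        "SGD": ada_family_core,
        "ADAGRAD": ada_family_core,
        "ADADELTA": ada_family_core,
        "FINETUNING_ADAGRAD": ada_family_core,
        "ADAM": adam_core,
        "RMSPROP": rmsprop_core,
        "MOMENTUM": momentum_core,
    }
    if method not in cores:
        raise NotImplementedError("method '%s' is not supported" % method)
    return cores[method]

One behavioral difference to note: this version imports every core module whenever a method is looked up, while the original imports only the selected one; if the lazy import matters, the table can map names to module paths instead.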

#!/usr/bin/env python
# -*- coding: utf-8 -*-


import logging as loggers

import numpy as np
import theano
import theano.tensor as T
from theano.ifelse import ifelse

from deepy.utils import FLOATX, dim_to_var, EPSILON
from deepy.trainers.util import wrap_core, multiple_l2_norm
from deepy.conf import TrainerConfig

logging = loggers.getLogger(__name__)

def optimize_updates(params, gradients, config=None, shapes=None):
    """
    General optimization function for Theano.
    Parameters:
        params - parameters
        gradients - gradients
        config - training config
    Returns:
        Theano updates
    :type config: deepy.TrainerConfig or dict
    """
    if config and isinstance(config, dict):
        config = TrainerConfig(config)

    # Clipping
    if config:
        clip_value = config.get("gradient_clipping", None)

        if clip_value:
            clip_constant = T.constant(clip_value, dtype=FLOATX)

            if config.avoid_compute_embed_norm:
                grad_norm = multiple_l2_norm([t[1] for t in zip(params, gradients) if not t[0].name.startswith("W_embed")])
            else:
                grad_norm = multiple_l2_norm(gradients)
            isnan = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
            multiplier = ifelse(grad_norm < clip_constant,
                                T.constant(1., dtype=FLOATX), clip_constant / (grad_norm + EPSILON))

            # Clip
            clipped_gradients = []
            for param, g in zip(params, gradients):
                g = multiplier * g
                if config.avoid_nan:
                    g = T.switch(isnan, np.float32(0.1) * param, g)
                if config.gradient_tolerance:
                    g = ifelse(grad_norm > config.gradient_tolerance, T.zeros_like(g) + EPSILON, g)
                clipped_gradients.append(g)

            gradients = clipped_gradients
    # Regularization
    if config and config.weight_l2:
        regularized_gradients = []
        for param, grad in zip(params, gradients):
            grad = grad + (2 * config.weight_l2 * param)
            regularized_gradients.append(grad)
        gradients = regularized_gradients

    # Avoid nan but not computing the norm
    # This is not recommended
    if config and config.avoid_nan and not config.gradient_clipping:
        logging.info("avoid NaN gradients")
        new_gradients = []
        for grad in gradients:
            new_grad = ifelse(T.isnan(grad).any(), T.zeros_like(grad) + EPSILON, grad)
            new_gradients.append(new_grad)
        gradients = new_gradients


    # Find method
    method = "SGD"
    if config:
        method = config.get("method", method).upper()
    # Get Function
    func = None
    if method in ["SGD", "ADAGRAD", "ADADELTA", "FINETUNING_ADAGRAD"]:
        from cores.ada_family import ada_family_core
        func = ada_family_core
    elif method == "ADAM":
        from cores.adam import adam_core
        func = adam_core
    elif method == "RMSPROP":
        from cores.rmsprop import rmsprop_core
        func = rmsprop_core
    elif method == "MOMENTUM":
        from cores.momentum import momentum_core
        func = momentum_core

    if not func:
        raise NotImplementedError("method '%s' is not supported" % method)

    logging.info("optimize method=%s parameters=%s" % (method, str(params)))

    free_parameters = []
    return_vals = wrap_core(func, config, params, gradients)
    if type(return_vals) == list and type(return_vals[0]) == list:
        updates, free_parameters = return_vals
    else:
        updates = return_vals

    # No free param recording
    if config and not config.record_free_params:
        free_parameters = []

    # Weight bound
    if config and config.weight_bound:
        logging.info("apply weight bound of %.2f" % config.weight_bound)
        new_updates = []
        for param, update_value in updates:
            bounded_value = (update_value * (T.abs_(update_value) <= config.weight_bound) +
                             config.weight_bound * (update_value > config.weight_bound) +
                             -config.weight_bound * (update_value < -config.weight_bound))
            new_updates.append((param, bounded_value))
        updates = new_updates
    return updates, free_parameters

def optimize_function(params, config=None):
    """
    Create an optimizing function that receives gradients.
    Parameters:
        params - parameters
        config - training configuration
    Returns:
        updating function that receives gradients
    """
    gs = [dim_to_var(p.ndim) for p in params]
    updates, _ = optimize_updates(params, gs, config)
    return theano.function(gs, [], updates=updates)
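For context, a minimal usage sketch of optimize_updates(). It is based only on the listing above plus stated assumptions: a working Theano installation, that the function is importable from deepy.trainers, and that TrainerConfig supplies defaults for every option the code reads (avoid_nan, weight_l2, weight_bound, and so on); only the "method" key is taken from the listing itself, and the toy loss is invented.

# Minimal usage sketch; import path, config defaults, and the toy loss are assumptions.
import numpy as np
import theano
import theano.tensor as T

from deepy.trainers import optimize_updates  # assumed import path

x = T.vector("x")
W = theano.shared(np.ones((2, 2), dtype="float32"), name="W")

# A toy quadratic loss so there is something to differentiate.
loss = T.sum(T.dot(W, x) ** 2)
grads = T.grad(loss, [W])

# The dict form is wrapped into a TrainerConfig inside optimize_updates().
updates, free_params = optimize_updates([W], grads, {"method": "ADAM"})

train = theano.function([x], loss, updates=updates)
print(train(np.ones(2, dtype="float32")))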