ada_family_core()   F

Complexity
    Conditions   15

Size
    Total Lines  49

Duplication
    Lines        0
    Ratio        0 %

Importance
    Changes      3
    Bugs         0
    Features     0

Metric                        Value
cc  (cyclomatic complexity)   15
c   (changes)                 3
b   (bugs)                    0
f   (features)                0
dl  (duplicated lines)        0
loc (lines of code)           49
rs                            2.5785

How to fix: Complexity

Complex classes like ada_family_core() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
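Since ada_family_core() is a standalone function rather than a class, the closest analogue of Extract Class here is Extract Method: pulling each optimizer's update rule into its own small helper. The sketch below is one possible, hypothetical decomposition; the helper names (_adadelta_updates, _adagrad_updates, _sgd_updates) are illustrative and are not part of deepy.

# Hypothetical decomposition of ada_family_core(); these helpers do not exist in deepy.
from theano import tensor as T

def _adadelta_updates(updates, param, gparam, gsum, xsum, rho, eps, one_minus_beta):
    # Accumulate squared gradients (gsum) and squared steps (xsum), then scale the step.
    updates[gsum] = rho * gsum + (1. - rho) * (gparam ** 2)
    dparam = -T.sqrt((xsum + eps) / (updates[gsum] + eps)) * gparam
    updates[xsum] = rho * xsum + (1. - rho) * (dparam ** 2)
    updates[param] = param * one_minus_beta + dparam

def _adagrad_updates(updates, param, gparam, gsum, learning_rate, eps,
                     gsum_regularization, one_minus_beta):
    # Accumulate (slightly decayed) squared gradients and divide the learning rate by their root.
    updates[gsum] = gsum + (gparam ** 2) - gsum_regularization * gsum
    updates[param] = param * one_minus_beta - learning_rate * gparam / T.sqrt(updates[gsum] + eps)

def _sgd_updates(updates, param, gparam, learning_rate, one_minus_beta):
    # Plain gradient-descent step.
    updates[param] = param * one_minus_beta - learning_rate * gparam

With helpers like these, the main loop in ada_family_core() would shrink to a dispatch on the method argument, which should reduce the condition count that the complexity metric above is measuring.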

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import OrderedDict
import inspect
import numpy as np
import theano
from theano import tensor as T
from deepy.trainers.optimize import logging
from deepy.core.env import FLOATX


def ada_family_core(params, gparams, learning_rate=0.01, eps=1e-6, rho=0.95, method="ADADELTA",
                    beta=0.0, gsum_regularization=0.0001):
    """
    Optimize by SGD, AdaGrad, or AdaDelta.
    """
    # Log the arguments this call received (getargvalues returns the frame's locals).
    _, _, _, args = inspect.getargvalues(inspect.currentframe())
    logging.info("ada_family_core: %s" % str(args.items()))
    free_parameters = []

    if method == "FINETUNING_ADAGRAD":
        method = "ADAGRAD"
        gsum_regularization = 0

    oneMinusBeta = 1 - beta

    # Accumulators: gsums holds running sums of squared gradients (AdaGrad/AdaDelta),
    # xsums holds running sums of squared parameter updates (AdaDelta only).
    gsums = [theano.shared(np.zeros_like(param.get_value(borrow=True), dtype=FLOATX), name="gsum_%s" % param.name)
             if method in ('ADADELTA', 'ADAGRAD') else None for param in params]
    xsums = [theano.shared(np.zeros_like(param.get_value(borrow=True), dtype=FLOATX), name="xsum_%s" % param.name)
             if method == 'ADADELTA' else None for param in params]

    # Fix for AdaGrad: initialize gsum to ones (x ** 0 == 1 elementwise).
    if method == 'ADAGRAD':
        for gsum in gsums:
            gsum.set_value(gsum.get_value() ** 0)

    updates = OrderedDict()
    # Build the update expression for each parameter.
    for gparam, param, gsum, xsum in zip(gparams, params, gsums, xsums):
        if method == 'ADADELTA':
            updates[gsum] = rho * gsum + (1. - rho) * (gparam ** 2)
            dparam = -T.sqrt((xsum + eps) / (updates[gsum] + eps)) * gparam
            updates[xsum] = rho * xsum + (1. - rho) * (dparam ** 2)
            updates[param] = param * oneMinusBeta + dparam
        elif method == 'ADAGRAD':
            updates[gsum] = gsum + (gparam ** 2) - gsum_regularization * gsum
            updates[param] = param * oneMinusBeta - learning_rate * (gparam / T.sqrt(updates[gsum] + eps))
        else:
            # Plain SGD step.
            updates[param] = param * oneMinusBeta - gparam * learning_rate

    # Add free parameters (the optimizer's own accumulator state).
    if method == 'ADADELTA':
        free_parameters.extend(gsums + xsums)
    elif method == 'ADAGRAD':
        free_parameters.extend(gsums)

    # Check dtype
    for k in updates:
        if updates[k].dtype != FLOATX:
            updates[k] = updates[k].astype(FLOATX)

    return updates.items(), free_parameters
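
For context, here is a minimal sketch of how the returned updates might be consumed. The toy least-squares model and variable names below are illustrative assumptions (reusing the imports from the listing above), not code from deepy.

# Illustrative only: a toy model wired to ada_family_core().
x = T.vector("x")
y = T.scalar("y")
w = theano.shared(np.zeros(5, dtype=FLOATX), name="w")
cost = (T.dot(x, w) - y) ** 2
updates, free_parameters = ada_family_core([w], [T.grad(cost, w)], method="ADAGRAD")
train = theano.function([x, y], cost, updates=updates)
# Each call to train(...) now applies one AdaGrad step to w.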