get_activation()   Rating: F

Complexity
  Conditions: 18

Size
  Total Lines: 40

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 1
  Bugs: 0
  Features: 0
Metric  Value   Meaning
cc      18      cyclomatic complexity (conditions)
c       1       changes
b       0       bugs
f       0       features
dl      0       duplicated lines
loc     40      total lines
rs      2.7087

1 Method

Rating  Name       Duplication  Size  Complexity
A       compose()  0            4     2

How to fix: Complexity

Complex functions like get_activation() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common way to find such a component is to look for statements and names that share the same prefixes or suffixes.

Once you have determined the pieces that belong together, you can apply the Extract Method refactoring (for classes, Extract Class or Extract Subclass are the analogous moves). Here, the activation table, the composition helper, and the name lookup are natural seams, as sketched below.
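As a rough illustration only (not the library's actual layout), the same behavior can be split so each piece stays small; the `_OPTIONS`, `_compose`, and `_lookup` names below are hypothetical, and the table is trimmed for brevity:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import functools

import theano.tensor as T

# Hypothetical module-level table (trimmed): moving it out of the function
# removes most of the size and branching from get_activation() itself.
_OPTIONS = {
    'tanh': T.tanh,
    'linear': lambda z: z,
    'sigmoid': T.nnet.sigmoid,
    'relu': lambda z: T.nnet.relu(z),
}

def _compose(a, b):
    # Same composition trick as the original compose(): apply a, then b.
    c = lambda z: b(a(z))
    c.__theanets_name__ = '%s(%s)' % (b.__theanets_name__, a.__theanets_name__)
    return c

def _lookup(act):
    # Single-name lookup, isolated from the composition logic.
    try:
        fn = _OPTIONS[act]
    except KeyError:
        raise KeyError('unknown activation %r' % act)
    fn.__theanets_name__ = act
    return fn

def get_activation(act):
    # Only the '+' composition handling remains here.
    if '+' in act:
        return functools.reduce(
            _compose, (get_activation(a) for a in act.split('+')))
    return _lookup(act)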

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import functools

import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from deepy.core.env import FLOATX

theano_rng = RandomStreams(seed=3)


def add_noise(x, sigma, rho):
    # Additive Gaussian noise (sigma) and/or a multiplicative dropout mask (rho).
    if sigma > 0 and rho > 0:
        noise = theano_rng.normal(size=x.shape, std=sigma, dtype=FLOATX)
        mask = theano_rng.binomial(size=x.shape, n=1, p=1-rho, dtype=FLOATX)
        return mask * (x + noise)
    if sigma > 0:
        return x + theano_rng.normal(size=x.shape, std=sigma, dtype=FLOATX)
    if rho > 0:
        mask = theano_rng.binomial(size=x.shape, n=1, p=1-rho, dtype=FLOATX)
        return mask * x
    return x


def softmax(x):
    # T.nnet.softmax doesn't work with the HF trainer.
    # Subtracting the per-column max keeps the exponentials numerically stable.
    z = T.exp(x.T - x.T.max(axis=0))
    return (z / z.sum(axis=0)).T


def get_activation(act=None):
    # NOTE: `act` must be an activation name such as "tanh", or a
    # '+'-separated composition such as "relu+norm:std"; the None default
    # would raise a TypeError on the `'+' in act` check below.
    def compose(a, b):
        c = lambda z: b(a(z))
        c.__theanets_name__ = '%s(%s)' % (b.__theanets_name__, a.__theanets_name__)
        return c
    if '+' in act:
        return functools.reduce(
            compose, (get_activation(a) for a in act.split('+')))
    options = {
        'tanh': T.tanh,
        'linear': lambda z: z,
        'logistic': T.nnet.sigmoid,
        'sigmoid': T.nnet.sigmoid,
        'hard_sigmoid': T.nnet.hard_sigmoid,
        'softplus': T.nnet.softplus,
        'softmax': softmax,
        'theano_softmax': T.nnet.softmax,

        # shorthands
        'relu': lambda z: T.nnet.relu(z),
        'leaky_relu': lambda z: T.nnet.relu(z, 0.01),
        'trel': lambda z: z * (z > 0) * (z < 1),
        'trec': lambda z: z * (z > 1),
        'tlin': lambda z: z * (abs(z) > 1),

        # modifiers
        'rect:max': lambda z: T.minimum(1, z),
        'rect:min': lambda z: T.maximum(0, z),

        # normalization
        'norm:dc': lambda z: (z.T - z.mean(axis=1)).T,
        'norm:max': lambda z: (z.T / T.maximum(1e-10, abs(z).max(axis=1))).T,
        'norm:std': lambda z: (z.T / T.maximum(1e-10, T.std(z, axis=1))).T,
    }
    for k, v in options.items():
        v.__theanets_name__ = k
    try:
        return options[act]
    except KeyError:
        raise KeyError('unknown activation %r' % act)
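For reference, composed activations apply left to right: the first name in the '+' chain runs first. A minimal usage sketch, assuming Theano is installed:

import theano.tensor as T

act = get_activation('relu+norm:std')
print(act.__theanets_name__)   # -> "norm:std(relu)": relu first, then norm:std

x = T.matrix('x')
y = act(x)                     # symbolic graph: row-normalized ReLU output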