Completed
Push — master ( f73e69...91b7c0 )
by Raphael
01:35
created

KaimingHeInitializer.__init__()   A

Complexity

Conditions 1

Size

Total Lines 8

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 8
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import numpy as np
5
from deepy.utils import global_rand
6
7
def get_fans(shape):
    """
    Compute the (fan_in, fan_out) pair for a weight shape.

    A two-dimensional shape is read as (fan_in, fan_out). For any other
    rank, the first axis is taken as fan_out and the product of the
    remaining axes as fan_in.
    """
    if len(shape) == 2:
        return shape[0], shape[1]
    return np.prod(shape[1:]), shape[0]
11
12
class WeightInitializer(object):
    """
    Initializer for creating weights.

    Base class: stores the random number generator in ``self.rand`` and
    leaves ``sample`` to be implemented by subclasses.
    """

    def __init__(self, seed=None):
        """
        Parameters:
            seed - random seed; when None the shared ``global_rand``
                   generator is used, otherwise a private
                   ``np.random.RandomState`` seeded with this value
        """
        # Compare against None instead of relying on truthiness, so that
        # seed=0 (a valid seed) produces a reproducible private generator
        # rather than silently falling back to the shared one.
        if seed is None:
            self.rand = global_rand
        else:
            self.rand = np.random.RandomState(seed)

    def sample(self, shape):
        """
        Sample parameters with given shape.

        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError
28
29
class UniformInitializer(WeightInitializer):
    """
    Uniform weight sampler.
    """

    def __init__(self, scale=None, svd=False, seed=None):
        """
        Parameters:
            scale - half-width of the sampling interval; when falsy
                    (None or 0), sqrt(6 / (fan_in + fan_out)) is used
            svd - rescale the sampled weights so that their largest
                  singular value equals the scale
            seed - random seed
        """
        super(UniformInitializer, self).__init__(seed)
        self.scale = scale
        self.svd = svd

    def sample(self, shape):
        """
        Sample weights of the given shape uniformly from [-scale, scale).

        When ``svd`` is enabled the sample is normalised by its Frobenius
        norm and then rescaled by its largest singular value.
        """
        if not self.scale:
            scale = np.sqrt(6. / sum(get_fans(shape)))
        else:
            scale = self.scale
        weight = self.rand.uniform(-1, 1, size=shape) * scale
        if self.svd:
            # Keep the rescaled matrix in `weight`: the result used to be
            # stored in a separate local and dropped, so svd=True had no
            # effect on the returned value.
            norm = np.sqrt((weight ** 2).sum())
            weight = scale * weight / norm
            _, singular_values, _ = np.linalg.svd(weight)
            weight = scale * weight / singular_values[0]
        return weight
51
52
class GaussianInitializer(WeightInitializer):
    """
    Weight sampler drawing from a normal distribution.
    """

    def __init__(self, mean=0, deviation=0.01, seed=None):
        """
        Parameters:
            mean - mean of the distribution
            deviation - standard deviation of the distribution
            seed - random seed
        """
        super(GaussianInitializer, self).__init__(seed)
        self.deviation = deviation
        self.mean = mean

    def sample(self, shape):
        """
        Draw an array of the given shape from N(mean, deviation).
        """
        return self.rand.normal(self.mean, self.deviation, size=shape)
65
66
class IdentityInitializer(WeightInitializer):
    """
    Initialize weight as identity matrices.
    """

    def __init__(self, scale=1):
        """
        Parameters:
            scale - factor multiplied onto the identity matrix
        """
        super(IdentityInitializer, self).__init__()
        # Store the caller's value; it used to be overwritten with a
        # hard-coded 1, which made the `scale` argument a no-op.
        self.scale = scale

    def sample(self, shape):
        """
        Return a scaled identity matrix of the given two-dimensional shape.

        For a non-square shape the ones lie on the main diagonal and the
        remaining entries are zero.
        """
        assert len(shape) == 2
        return np.eye(*shape) * self.scale
78
79
class XavierGlorotInitializer(WeightInitializer):
    """
    Xavier Glorot's weight initializer.
    See http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
    """

    def __init__(self, uniform=False, seed=None):
        """
        Parameters:
            uniform - uniform distribution, default Gaussian
            seed - random seed
        """
        super(XavierGlorotInitializer, self).__init__(seed)
        self.uniform = uniform

    def sample(self, shape):
        """
        Sample weights with variance 2 / (fan_in + fan_out).

        Both distributions target the same variance: the Gaussian branch
        uses it as the squared standard deviation, and the uniform branch
        uses the bound that yields it.
        """
        fan_sum = sum(get_fans(shape))
        if self.uniform:
            # U(-a, a) has variance a^2 / 3, so a = sqrt(6 / fan_sum) is
            # needed to reach 2 / fan_sum; reusing the Gaussian scale as
            # the bound would give only a third of the intended variance.
            limit = np.sqrt(6. / fan_sum)
            return self.rand.uniform(-1, 1, size=shape) * limit
        std = np.sqrt(2. / fan_sum)
        return self.rand.randn(*shape) * std
100
101
class KaimingHeInitializer(WeightInitializer):
    """
    Kaiming He's initialization scheme, especially made for ReLU.
    See http://arxiv.org/abs/1502.01852.
    """

    def __init__(self, uniform=False, seed=None):
        """
        Parameters:
            uniform - uniform distribution, default Gaussian
            seed - random seed
        """
        super(KaimingHeInitializer, self).__init__(seed)
        self.uniform = uniform

    def sample(self, shape):
        """
        Sample weights with variance 2 / fan_in.

        Both distributions target the same variance: the Gaussian branch
        uses it as the squared standard deviation, and the uniform branch
        uses the bound that yields it.
        """
        fan_in, _ = get_fans(shape)
        if self.uniform:
            # U(-a, a) has variance a^2 / 3, so a = sqrt(6 / fan_in) is
            # needed to reach 2 / fan_in; reusing the Gaussian scale as
            # the bound would give only a third of the intended variance.
            limit = np.sqrt(6. / fan_in)
            return self.rand.uniform(-1, 1, size=shape) * limit
        std = np.sqrt(2. / fan_in)
        return self.rand.randn(*shape) * std
122
123
class OrthogonalInitializer(WeightInitializer):
    """
    Orthogonal weight initializer.
    """

    def __init__(self, scale=1.1, seed=None):
        """
        Parameters:
            scale - factor multiplied onto the orthogonal matrix
            seed - random seed
        """
        super(OrthogonalInitializer, self).__init__(seed)
        self.scale = scale

    def sample(self, shape):
        """
        Sample a scaled orthogonal weight array of the given shape.

        The shape is flattened to (shape[0], prod(shape[1:])), a Gaussian
        matrix of that size is decomposed with SVD, and the factor whose
        shape matches the flattened shape is reshaped back and scaled.
        """
        flat_shape = (shape[0], np.prod(shape[1:]))
        # Draw from this instance's generator, not the module-level
        # np.random, so that the `seed` argument is honoured.
        a = self.rand.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # With full_matrices=False exactly one of u / v has the flattened
        # shape (both do when the matrix is square).
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        return self.scale * q[:shape[0], :shape[1]]
143