Completed
Push — master ( 394368...090fba )
by Raphael
01:33
created

build_node_name()   B

Complexity

Conditions 7

Size

Total Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 7 (cyclomatic complexity)
c 1 (changes)
b 0 (bugs)
f 0 (features)
dl 0 (duplicated lines)
loc 12 (lines of code)
rs 7.3333
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import copy
5
import logging as loggers
6
import os
7
import re
8
9
import numpy as np
10
import theano
11
import theano.tensor as T
12
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
13
from theano.tensor.shared_randomstreams import RandomStreams as SharedRandomStreams
14
15
# NOTE: `logging` here is this module's logger object; the stdlib module is
# imported as `loggers`. The name is kept because other modules may import it.
logging = loggers.getLogger(__name__)
logging.setLevel(loggers.INFO)

# This file is deprecated.

# Float dtype configured for Theano; every symbolic constant below uses it.
FLOATX = theano.config.floatX
# Small symbolic constants typed as FLOATX.
EPSILON = T.constant(1.0e-15, dtype=FLOATX)
BIG_EPSILON = T.constant(1.0e-7, dtype=FLOATX)

# Seed shared by all module-level random generators. It defaults to 3 and can be
# overridden through the DEEPY_SEED environment variable (parsed with int(), so
# a non-integer value raises ValueError at import time).
if 'DEEPY_SEED' in os.environ:
    global_seed = int(os.environ['DEEPY_SEED'])
    # Lazy %-args: formatting is left to the logging framework.
    logging.info("set global random seed to %d", global_seed)
else:
    global_seed = 3
global_rand = np.random.RandomState(seed=global_seed)
global_theano_rand = RandomStreams(seed=global_seed)
global_shared_rand = SharedRandomStreams(seed=global_seed)
34
35
36
def make_float_matrices(*names):
    """Create one symbolic float matrix per given name.

    Each matrix is built with ``T.matrix(name, dtype=FLOATX)``, so its dtype
    follows ``theano.config.floatX``.

    :param names: variable names; one matrix is created per name.
    :return: list of symbolic matrices in the same order as ``names``
        (an empty list when no names are given).
    """
    return [T.matrix(name, dtype=FLOATX) for name in names]
41
42
43
def make_float_vectors(*names):
    """Create one symbolic float vector per given name.

    Each vector is built with ``T.vector(name, dtype=FLOATX)``, so its dtype
    follows ``theano.config.floatX``.

    :param names: variable names; one vector is created per name.
    :return: list of symbolic vectors in the same order as ``names``
        (an empty list when no names are given).
    """
    return [T.vector(name, dtype=FLOATX) for name in names]
48
49
50
def back_grad(jacob, err_g):
    """Return the dot product ``T.dot(jacob, err_g)``.

    Judging by the names, this presumably chains an upstream error gradient
    ``err_g`` through a Jacobian ``jacob``; shapes are not checked here.
    TODO confirm the expected shapes at call sites.

    :param jacob: left operand of the dot product.
    :param err_g: right operand of the dot product.
    :return: the symbolic result of ``T.dot(jacob, err_g)``.
    """
    return T.dot(jacob, err_g)
53
54
# Drops the text after the last comma inside a brace group, e.g.
# "Elemwise{add,no_inplace}" -> "Elemwise{add}" and "Op{a,b,c}" -> "Op{a,b}".
_OP_LAST_PARAM_RE = re.compile(r"\{([^}]+),[^}]+\}")
# Keeps only the text after the last underscore inside a brace group, e.g.
# "Elemwise{true_div}" -> "Elemwise{div}".
_OP_UNDERSCORE_PREFIX_RE = re.compile(r"\{[^}]+_([^_}]+)\}")


def _short_op_name(op):
    """Return a compact display name for the operation ``op``.

    ``str(op)`` is wrapped as ``Elemwise{...}`` when it contains no ``{``.
    Inside brace groups, the parameter after the last comma is then dropped,
    and afterwards everything up to the last underscore is dropped.
    """
    op_name = str(op)
    if "{" not in op_name:
        op_name = "Elemwise{%s}" % op_name
    # The substring checks only skip a substitution that could not match anyway.
    if "," in op_name:
        op_name = _OP_LAST_PARAM_RE.sub(r"{\1}", op_name)
    if "_" in op_name:
        op_name = _OP_UNDERSCORE_PREFIX_RE.sub(r"{\1}", op_name)
    return op_name


def build_node_name(n):
    """Render expression node ``n`` as a nested, human-readable string.

    A node is expanded when it has a non-``None`` ``owner`` that exposes
    ``inputs``; ``owner.op`` and ``owner.inputs`` are then read
    (presumably a Theano variable and its Apply node -- confirm at call sites).
    The result is ``"<short op name>(<input>,<input>,...)"``, built recursively
    over ``owner.inputs``. Anything else is a leaf and is rendered with ``str(n)``.

    NOTE: the expansion is recursive and writes shared sub-expressions out once
    per use. Very deep graphs can hit Python's recursion limit, and graphs with
    heavy sharing produce very long strings.

    :param n: the node to render.
    :return: the rendered string.
    """
    owner = getattr(n, "owner", None)
    # Leaves: no owner at all, owner set to None, or an owner without inputs.
    if owner is None or not hasattr(owner, "inputs"):
        return str(n)
    op_name = _short_op_name(owner.op)
    args = ",".join(build_node_name(m) for m in owner.inputs)
    return "%s(%s)" % (op_name, args)
66
67
68
class VarMap:
    """Dictionary-backed variable store.

    All values live in the plain dict ``self.varmap``. It is accessed by name
    through ``get``/``set``/``update_if_not_existing`` and ``in``. The class
    also implements ``__get__``/``__set__``, which key the same dict by the
    owning instance when a ``VarMap`` is used as a class attribute.

    NOTE(review): under Python 2 this is an old-style class, so those
    descriptor hooks are likely inert there -- confirm before relying on them.
    """

    def __init__(self):
        # Single backing store shared by the descriptor hooks and the name API.
        self.varmap = {}

    def __get__(self, instance, owner):
        """Return the value stored under ``instance``, or ``None`` if absent."""
        return self.varmap.get(instance)

    def __set__(self, instance, value):
        """Store ``value`` under the key ``instance``."""
        self.varmap[instance] = value

    def __contains__(self, item):
        """Return whether ``item`` is a key of the store."""
        return item in self.varmap

    def update_if_not_existing(self, name, value):
        """Store ``value`` under ``name`` only if ``name`` is not present yet."""
        self.varmap.setdefault(name, value)

    def get(self, name):
        """Return the value under ``name``; raise ``KeyError`` if missing."""
        return self.varmap[name]

    def set(self, name, value):
        """Store ``value`` under ``name``, replacing any previous value."""
        self.varmap[name] = value