NeuralTensorNet — rating: A
Last analyzed: (date not shown)

Complexity

Total Complexity 2

Size/Duplication

Total Lines 7
Duplicated Lines 0 %

Importance

Changes: 3
Bugs: 0, Features: 0

| Metric | Value |
|--------|-------|
| c      | 3     |
| b      | 0     |
| f      | 0     |
| dl     | 0     |
| loc    | 7     |
| rs     | 10    |
| wmc    | 2     |

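2 methods

| Rating | Name            | Duplication | Size | Complexity |
|--------|-----------------|-------------|------|------------|
| A      | `__getattr__()` | 0           | 5    | 2          |
| A      | `wrapper()`     | 0           | 3    | 1          |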
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import theano.tensor as T
5
from theano import tensor as theano_tensor
6
7
from deepy.core.tensor_conversion import neural_computation
8
from deepy.core.neural_var import NeuralVariable
9
10
11
class NeuralTensorNet(object):
    """
    Proxy exposing ``theano.tensor.nnet`` functions as neural computations.

    Accessing ``deepy_nnet.<name>`` looks up ``<name>`` in
    ``theano.tensor.nnet`` and returns that function wrapped with
    ``neural_computation``.
    """

    def __getattr__(self, func_name):
        """
        Return ``theano.tensor.nnet.<func_name>`` wrapped with ``neural_computation``.

        Python only calls this for names that normal attribute lookup
        did not find.

        :param func_name: name of the ``theano.tensor.nnet`` function to expose
        :returns: a ``neural_computation``-decorated callable forwarding all
            positional and keyword arguments to the Theano function
        :raises AttributeError: if ``func_name`` is a dunder name, or if
            ``theano.tensor.nnet`` has no attribute ``func_name``
        """
        # copy, pickle, hasattr() and friends probe protocol hooks such as
        # __deepcopy__ or __len__ via getattr(). Answering them with a
        # wrapper makes those protocols misbehave, so report them as missing.
        if func_name.startswith("__") and func_name.endswith("__"):
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, func_name))

        # Resolve eagerly: unknown names fail here with AttributeError,
        # so hasattr()/getattr(obj, name, default) give truthful answers.
        theano_func = getattr(theano_tensor.nnet, func_name)

        @neural_computation
        def wrapper(*args, **kwargs):
            return theano_func(*args, **kwargs)

        return wrapper


# Shared module-level proxy instance (NeuralTensor.__getattr__ returns it for "nnet").
deepy_nnet = NeuralTensorNet()
20
21
class NeuralTensor(object):
    """
    A class for exporting Theano tensor operations to neural variables.

    Accessing ``deepy_tensor.<name>`` returns ``theano.tensor.<name>`` wrapped
    with ``neural_computation``. ``deepy_tensor.nnet`` returns the
    module-level ``deepy_nnet`` proxy instead.
    """

    def constant(self, value, dtype="float32", dim=None):
        """
        Build a Theano constant and wrap it in a ``NeuralVariable``.

        :param value: value passed to ``theano.tensor.constant``
        :param dtype: dtype passed to ``theano.tensor.constant`` (default ``"float32"``)
        :param dim: forwarded as ``dim`` to the resulting ``NeuralVariable``
        :returns: ``NeuralVariable`` wrapping the Theano constant
        """
        return NeuralVariable(T.constant(value, dtype=dtype), dim=dim)

    def __getattr__(self, func_name):
        """
        Resolve ``func_name`` to a wrapped ``theano.tensor`` function.

        Python only calls this for names that normal attribute lookup
        did not find.

        :param func_name: name of the ``theano.tensor`` function to expose,
            or ``"nnet"`` for the ``theano.tensor.nnet`` proxy
        :returns: ``deepy_nnet`` for ``"nnet"``; otherwise a
            ``neural_computation``-decorated callable forwarding all arguments
            to ``theano.tensor.<func_name>``
        :raises AttributeError: if ``func_name`` is a dunder name, or if
            ``theano.tensor`` has no attribute ``func_name``
        """
        # copy, pickle, hasattr() and friends probe protocol hooks such as
        # __deepcopy__ or __len__ via getattr(). Answering them with a
        # wrapper makes those protocols misbehave, so report them as missing.
        if func_name.startswith("__") and func_name.endswith("__"):
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, func_name))

        # Check "nnet" first so no throwaway wrapper is built for it.
        if func_name == 'nnet':
            return deepy_nnet

        # Resolve eagerly: unknown names fail here with AttributeError,
        # so hasattr()/getattr(obj, name, default) give truthful answers.
        theano_func = getattr(theano_tensor, func_name)

        @neural_computation
        def wrapper(*args, **kwargs):
            return theano_func(*args, **kwargs)

        return wrapper
39
40
41
# Single shared NeuralTensor instance, exported both as ``deepy_tensor``
# and under the shorter alias ``tensor``.
tensor = deepy_tensor = NeuralTensor()
43
44