Total Complexity | 4 |
Total Lines | 18 |
Duplicated Lines | 0 % |
Changes | 3 | ||
Bugs | 0 | Features | 0 |
1 | #!/usr/bin/env python |
||
class NeuralTensor(object):
    """
    Proxy that exposes Theano tensor operations to neural-variable code.

    Looking up a missing attribute ``obj.foo`` returns a callable that invokes
    ``theano_tensor.foo`` and is decorated with ``neural_computation``.
    Presumably that decorator converts between NeuralVariable and raw Theano
    variables; confirm against its definition. The special name ``nnet``
    returns the module-level ``deepy_nnet`` object instead. Dunder names are
    never forwarded.
    """

    def constant(self, value, dtype="float32", dim=None):
        """
        Wrap a Theano constant in a NeuralVariable.

        :param value: value passed to ``T.constant``.
        :param dtype: dtype passed to ``T.constant`` (default ``"float32"``).
        :param dim: forwarded unchanged as ``NeuralVariable(..., dim=dim)``.
        :return: ``NeuralVariable(T.constant(value, dtype=dtype), dim=dim)``.
        """
        return NeuralVariable(T.constant(value, dtype=dtype), dim=dim)

    def __getattr__(self, func_name):
        """
        Resolve attribute names not found on the instance or its class.

        :param func_name: the attribute name being looked up.
        :return: ``deepy_nnet`` for ``"nnet"``. For any other non-dunder name,
            a ``neural_computation``-decorated callable that forwards its
            arguments to ``theano_tensor.<func_name>``.
        :raises AttributeError: for dunder names (``__x__``), so Python
            protocols such as ``copy.deepcopy``, pickling and ``hasattr``
            probes see a normal missing attribute.
        """
        # Protocol lookups (e.g. copy.deepcopy's ``__deepcopy__``) must not
        # receive a Theano-forwarding wrapper. The data model expects
        # AttributeError for attributes that do not exist.
        if func_name.startswith("__") and func_name.endswith("__"):
            raise AttributeError("%r object has no attribute %r"
                                 % (type(self).__name__, func_name))

        # Check this before building a wrapper so that none is created and
        # decorated only to be discarded. ``deepy_nnet`` is only read, so it
        # resolves to the module global without a ``global`` statement.
        if func_name == 'nnet':
            return deepy_nnet

        @neural_computation
        def wrapper(*args, **kwargs):
            # The name is resolved at call time (unchanged behaviour), so a
            # name absent from theano_tensor only fails when called.
            return getattr(theano_tensor, func_name)(*args, **kwargs)

        return wrapper
||
37 | |||
42 |