| Total Complexity | 11 |
| Total Lines | 42 |
| Duplicated Lines | 0 % |
| 1 | #!/usr/bin/env python |
||
class NeuralVariable(NeuralLayer):
    """
    Create a constant layer with tensors.

    The layer holds a tensor, a separate test tensor and an output
    dimension, and returns the stored tensors from its compute hooks
    regardless of the input passed to them.
    """

    def __init__(self, tensor, test_tensor=None, dim=0):
        """
        Create a tensor layer.

        :param tensor: tensor returned by ``compute_tensor``.
        :param test_tensor: tensor returned by ``compute_test_tesnor``;
            when ``None``, ``tensor`` is used for both.
        :param dim: value stored as ``output_dim``.
        """
        super(NeuralVariable, self).__init__("const")
        self.output_dim = dim
        self.tensor = tensor
        # Compare with None explicitly instead of truth-testing the tensor:
        # a falsy-but-valid value (e.g. scalar 0) must not be discarded, and
        # array-like objects may not support boolean evaluation at all.
        self.test_tensor = tensor if test_tensor is None else test_tensor
        self.initialize(0)

    def __getitem__(self, index):
        """
        Index this variable via a ``neural_computation``-wrapped function.

        :param index: index or slice forwarded to ``t[index]``.
        :return: whatever the decorated wrapper returns for ``(self, index)``.
        """
        @neural_computation
        def getitem_wrapper(t, index):
            return t[index]

        return getitem_wrapper(self, index)

    def apply(self, func, dim=None):
        """
        Apply a function to tensors.

        ``func`` is called once on the tensor and once on the test tensor,
        and the two results are wrapped in a new ``NeuralVariable``.

        :param func: callable taking a tensor and returning a tensor.
        :param dim: output dimension of the new variable; any falsy value
            (``None`` or ``0``) falls back to this variable's ``output_dim``.
        :return: a new ``NeuralVariable``.
        """
        output_dim = dim if dim else self.output_dim
        return NeuralVariable(func(self.tensor), func(self.test_tensor), output_dim)

    def compute_tensor(self, x):
        """
        Return the stored tensor; the input ``x`` is ignored.
        """
        return self.tensor

    # NOTE(review): the misspelled name ("tesnor") is kept on purpose; it is
    # presumably an override of a hook with the same spelling in NeuralLayer.
    # Confirm against the base class before renaming.
    def compute_test_tesnor(self, x):
        """
        Return the stored test tensor; the input ``x`` is ignored.
        """
        return self.test_tensor

    def set_test_value(self, value):
        """
        Assign ``value`` to ``self.tensor.tag.test_value``.

        Only ``self.tensor`` is touched; ``self.test_tensor`` is not updated
        unless it is the same object.
        """
        self.tensor.tag.test_value = value

    def dim(self):
        """
        Return the stored output dimension.
        """
        return self.output_dim

    def shape(self, dim_index):
        """
        Return ``self.tensor.shape[dim_index]``.
        """
        return self.tensor.shape[dim_index]
||
| 50 |