| Total Complexity | 6 |
| Total Lines | 24 |
| Duplicated Lines | 0 % |
| 1 | from . import NeuralLayer |
||
class Combine(NeuralLayer):
    """
    Layer that combines its input tensors by delegating to a user-supplied
    callable.
    """

    def __init__(self, func, dim=0):
        """
        :param func: callable invoked with the input tensors; whatever it
            returns is used as the output of this layer
        :param dim: explicit output dimension; only applied when it is
            greater than 0
        """
        super(Combine, self).__init__("combine")
        # A non-positive ``dim`` leaves ``output_dim`` untouched here, and
        # prepare() then falls back to ``input_dim`` when ``output_dim`` is 0.
        # NOTE(review): presumably NeuralLayer.__init__ initialises
        # ``output_dim`` to 0 -- confirm against the base class.
        if dim > 0:
            self.output_dim = dim
        self.func = func

    def prepare(self):
        """
        Default the output dimension to the input dimension when no explicit
        dimension has been set (i.e. ``output_dim`` is still 0).
        """
        if self.output_dim != 0:
            # An explicit dimension is already in place; keep it.
            return
        self.output_dim = self.input_dim

    def output(self, *tensors):
        """
        Apply the stored callable to the given tensors and return its result.
        """
        combiner = self.func
        return combiner(*tensors)

    def test_output(self, *tensors):
        """
        Apply the stored callable to the given tensors and return its result;
        performs the same call as ``output``.
        """
        combiner = self.func
        return combiner(*tensors)
||
| 28 |