Total Complexity | 6 |
Total Lines | 24 |
Duplicated Lines | 0 % |
1 | from . import NeuralLayer |
||
class Combine(NeuralLayer):
    """
    Combine input tensors into a single output using a user-supplied function.

    The layer has no parameters of its own: both the training-time and the
    test-time computation simply call ``func`` on all incoming tensors.
    """

    def __init__(self, func, dim=0):
        """
        :param func: callable invoked as ``func(*tensors)``; its return value
            is used as the layer output.
        :param dim: output dimension. Only a positive value is stored here;
            otherwise ``prepare`` falls back to the input dimension.
        :type dim: int
        """
        super(Combine, self).__init__("combine")
        self.func = func
        if dim > 0:
            self.output_dim = dim

    def prepare(self):
        """
        Default the output dimension to the input dimension when none was set.
        """
        # NOTE(review): assumes the NeuralLayer base initialises output_dim to
        # a falsy value (0) when no explicit dimension is given — confirm.
        # A falsiness check also covers an unset/None output dimension.
        if not self.output_dim:
            self.output_dim = self.input_dim

    def compute_tensor(self, *tensors):
        """
        Apply ``func`` to all input tensors and return its result.
        """
        return self.func(*tensors)

    def compute_test_tensor(self, *tensors):
        """
        Test-time computation; identical to ``compute_tensor`` for this layer.
        """
        return self.compute_tensor(*tensors)

    # Backward-compatible alias for the original misspelled method name, so
    # callers (or a base class) that dispatch to either spelling still work.
    compute_test_tesnor = compute_test_tensor
||
28 |