Total Complexity | 4 |
Total Lines | 14 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | #!/usr/bin/env python |
||
class DimShuffle(NeuralLayer):
    """
    DimShuffle layer.

    Rearranges the dimensions of its input by forwarding a stored pattern to
    the input's ``dimshuffle`` method.

    The pattern can be supplied either as separate positional arguments,
    e.g. ``DimShuffle(1, 0)``, or as a single list/tuple,
    e.g. ``DimShuffle([1, 0])``; both forms store the same pattern.
    """

    def __init__(self, *pattern):
        """
        :param pattern: dimension pattern entries, or a single list/tuple
            holding them. Entries are passed unchanged to ``x.dimshuffle``
            (presumably Theano-style indices / ``'x'`` markers — TODO confirm).
        """
        super(DimShuffle, self).__init__("dimshuffle")
        # A lone sequence argument is the whole pattern, not a one-element
        # pattern. isinstance (instead of an exact type() comparison) also
        # accepts list subclasses and tuples.
        if len(pattern) == 1 and isinstance(pattern[0], (list, tuple)):
            self.pattern = pattern[0]
        else:
            self.pattern = pattern

    def compute_tensor(self, x):
        """
        Apply the stored dimension pattern to ``x``.

        :param x: input exposing a ``dimshuffle`` method (presumably a
            Theano tensor variable — TODO confirm).
        :return: the result of ``x.dimshuffle(*self.pattern)``.
        """
        return x.dimshuffle(*self.pattern)