DimShuffle.__init__()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
dl 0
loc 6
rs 9.4285
c 0
b 0
f 0
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from . import NeuralLayer
5
6
class DimShuffle(NeuralLayer):
    """
    DimShuffle layer.

    Stores a dimshuffle pattern and, in ``compute_tensor``, applies it to the
    input via ``x.dimshuffle(*pattern)``. (Under Theano semantics this
    permutes axes and may insert broadcastable ``'x'`` dimensions -- NOTE:
    assumes ``x`` is a Theano-style tensor; confirm against callers.)
    """

    def __init__(self, *pattern):
        """
        :param pattern: the dimshuffle pattern, given either as separate
            positional arguments (``DimShuffle(1, 0)``) or as a single list
            or tuple (``DimShuffle([1, 0])`` / ``DimShuffle((1, 0))``).
        """
        super(DimShuffle, self).__init__("dimshuffle")
        # Unpack a single list/tuple argument so both call styles produce the
        # same pattern. isinstance (not an exact type() comparison) also
        # accepts list subclasses and tuples; the container is kept as given.
        if len(pattern) == 1 and isinstance(pattern[0], (list, tuple)):
            self.pattern = pattern[0]
        else:
            self.pattern = pattern

    def compute_tensor(self, x):
        """
        Apply the stored pattern to ``x``.

        :param x: an object exposing a ``dimshuffle`` method -- presumably a
            Theano tensor variable (TODO confirm).
        :return: the result of ``x.dimshuffle(*self.pattern)``.
        """
        return x.dimshuffle(*self.pattern)