Test Failed
Pull Request — master (#806)
by Nicola
03:35
created

savu.plugins.reshape.sum_dimension   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 75
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 38
dl 0
loc 75
rs 10
c 0
b 0
f 0
wmc 6

6 Methods

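| Rating | Name                            | Duplication | Size | Complexity |
|--------|---------------------------------|-------------|------|------------|
| A      | SumDimension.nInput_datasets()  | 0           | 2    | 1          |
| A      | SumDimension.__init__()         | 0           | 2    | 1          |
| A      | SumDimension.process_frames()   | 0           | 2    | 1          |
| A      | SumDimension.pre_process()      | 0           | 4    | 1          |
| A      | SumDimension.nOutput_datasets() | 0           | 2    | 1          |
| A      | SumDimension.setup()            | 0           | 22   | 1          |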
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: sum_dimension
17
   :platform: Unix
18
   :synopsis: Sum a chosen dimension of the data.
19
20
.. moduleauthor:: Nicola Wadeson <[email protected]>
21
22
"""
23
24
import copy
25
import logging
26
import numpy as np
27
28
from savu.plugins.plugin import Plugin
29
from savu.plugins.utils import register_plugin
30
from savu.plugins.driver.cpu_plugin import CpuPlugin
31
32
33
@register_plugin
class SumDimension(Plugin, CpuPlugin):
    # Reduces one input dataset to one output dataset by summing over a
    # single axis. The axis is chosen by name through the 'axis_label'
    # parameter. The output dataset carries the input's axis labels and shape
    # with that axis removed. Frames are delivered in 'multiple' mode using
    # the pattern named by the 'pattern' parameter.
    # NOTE: documented with comments rather than a class docstring on
    # purpose; the plugin framework may parse class docstrings.

    def __init__(self):
        """Initialise the plugin under the name 'SumDimension'."""
        super(SumDimension, self).__init__('SumDimension')

    def pre_process(self):
        """Look up, and cache on ``self.sum_dim``, the index of the axis to sum.

        The lookup uses the first plugin-level input dataset and the
        'axis_label' parameter.
        """
        plugin_in = self.get_plugin_in_datasets()[0]
        label = self.parameters['axis_label']
        self.sum_dim = plugin_in.get_data_dimension_by_axis_label(label)

    def process_frames(self, data):
        """Return the first input frame summed along ``self.sum_dim``.

        :param data: list of input frames; only ``data[0]`` is used.
        :returns: the result of ``np.sum`` over axis ``self.sum_dim``.
        """
        # NOTE(review): assumes the frame's axis numbering matches the
        # dataset-level index found in pre_process for the chosen pattern --
        # confirm against how frames are laid out for 'multiple'.
        frame = data[0]
        return np.sum(frame, axis=self.sum_dim)

    def setup(self):
        """Create the reduced output dataset and configure frame delivery.

        The output dataset gets the input's axis labels and shape with the
        'axis_label' dimension removed. Both plugin datasets are then set up
        with the 'pattern' parameter in 'multiple' mode.
        """
        in_datasets, out_datasets = self.get_datasets()
        source = in_datasets[0]

        label = self.parameters['axis_label']
        dim = source.get_data_dimension_by_axis_label(label)
        # Presumably '*.<dim>' asks the framework to carry over the input's
        # patterns with dimension <dim> dropped -- confirm in create_dataset.
        pattern_spec = ['*.' + str(dim)]

        reduced_labels = copy.copy(source.get_axis_labels())
        del reduced_labels[dim]

        reduced_shape = list(source.get_shape())
        del reduced_shape[dim]

        out_datasets[0].create_dataset(
            patterns={source: pattern_spec},
            axis_labels=reduced_labels,
            shape=tuple(reduced_shape))

        frame_pattern = self.parameters['pattern']
        plugin_in, plugin_out = self.get_plugin_datasets()
        for plugin_data in (plugin_in[0], plugin_out[0]):
            plugin_data.plugin_data_setup(frame_pattern, 'multiple')

    def nInput_datasets(self):
        """Return the number of input datasets (one)."""
        return 1

    def nOutput_datasets(self):
        """Return the number of output datasets (one)."""
        return 1
75
76