Test Failed
Pull Request — master (#916)
by Daniil
05:04
created

MinAndMaxDeprecated.setup()   A

Complexity

Conditions 3

Size

Total Lines 26
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 21
nop 1
dl 0
loc 26
rs 9.376
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""
15
.. module:: min_and_max
16
   :platform: Unix
17
   :synopsis: A plugin to calculate the min and max of each frame
18
.. moduleauthor:: Nicola Wadeson <[email protected]>
19
"""
20
import logging
21
import numpy as np
22
23
from scipy.ndimage import gaussian_filter
24
from savu.plugins.plugin import Plugin
25
from savu.plugins.utils import register_plugin
26
from savu.plugins.driver.cpu_plugin import CpuPlugin
27
import savu.core.utils as cu
28
29
30
@register_plugin
class MinAndMaxDeprecated(Plugin, CpuPlugin):
    """Deprecated plugin that records a low/high value for every frame.

    For each frame it computes either the extrema (``method='extrema'``) or
    two percentiles (``method='percentile'``, bounds taken from ``p_range``).
    The values are written to two METADATA output datasets. In
    ``post_process`` they are copied into the input dataset's metadata under
    ``['stats', 'min', pattern]`` and ``['stats', 'max', pattern]``.
    """

    def __init__(self):
        # NOTE(review): the plugin name passed up is "MinAndMax", not the class
        # name -- presumably so existing process lists keep resolving; confirm.
        super(MinAndMaxDeprecated, self).__init__("MinAndMax")

    def circle_mask(self, width, ratio):
        """Return a centred circular mask.

        :param width: side length of the square mask.
        :param ratio: circle radius as a fraction of ``width // 2``.
        :returns: ``(width, width)`` float32 array, 1.0 inside the circle and
            0.0 outside.
        """
        mask = np.zeros((width, width), dtype=np.float32)
        center = width // 2
        radius = ratio * center
        y, x = np.ogrid[-center:width - center, -center:width - center]
        mask_check = x * x + y * y <= radius * radius
        mask[mask_check] = 1.0
        return mask

    def pre_process(self):
        """Read the plugin parameters and prepare the optional circular mask.

        :raises ValueError: if ``method`` is neither 'percentile' nor
            'extrema'.
        """
        in_meta_data = self.get_in_meta_data()[0]
        width = self.get_in_datasets()[0].get_shape()[0]
        self.use_mask = self.parameters['masking']
        self.data_pattern = self.parameters['pattern']
        # Default: an all-ones mask, i.e. effectively no masking.
        self.mask = np.ones((width, width), dtype=np.float32)
        if self.use_mask is True:
            ratio = self.parameters['ratio']
            if ratio is None:
                # Derive the radius from the centre of rotation when it is
                # available in the metadata; otherwise use the full width.
                try:
                    cor = np.min(in_meta_data.get('centre_of_rotation'))
                    ratio = (min(cor, abs(width - cor))) / (width * 0.5)
                except KeyError:
                    ratio = 1.0
            self.mask = self.circle_mask(width, ratio)
        self.method = self.parameters['method']
        if self.method not in ('percentile', 'extrema'):
            msg = "\n***********************************************\n" \
                  "!!! ERROR !!! -> Wrong method. Please use only one of " \
                  "the provided options \n" \
                  "***********************************************\n"
            logging.warning(msg)
            cu.user_message(msg)
            raise ValueError(msg)
        # Clip to the valid percentile range [0, 100] and order as (low, high).
        self.p_min, self.p_max = np.sort(np.clip(np.asarray(
            self.parameters['p_range'], dtype=np.float32), 0.0, 100.0))

    def process_frames(self, data):
        """Compute the low/high statistics of a single frame.

        NaN/inf values are replaced via ``np.nan_to_num``. The frame is
        optionally Gaussian-smoothed. For the VOLUME_XZ pattern it is
        multiplied by the circular mask when the shapes match.

        :param data: list whose first element is the frame.
        :returns: ``[low, high]`` as two 1-element float32 arrays. These are
            the ``p_min``/``p_max`` percentiles for ``method='percentile'``,
            otherwise the frame minimum/maximum.
        """
        frame = np.nan_to_num(data[0])
        if self.parameters['smoothing'] is True:
            frame = gaussian_filter(frame, (3, 3))
        if (self.use_mask is True) and (self.data_pattern == 'VOLUME_XZ') \
                and (self.mask.shape == frame.shape):
            # NOTE(review): masked-out pixels become 0 and still take part in
            # the statistics below, so they can pull the min/max towards 0.
            frame = frame * self.mask
        if self.method == 'percentile':
            # One call computes both percentiles from a single partition
            # instead of partitioning the frame twice.
            low, high = np.percentile(frame, [self.p_min, self.p_max])
        else:
            low, high = np.min(frame), np.max(frame)
        return [np.array([low], dtype=np.float32),
                np.array([high], dtype=np.float32)]

    def post_process(self):
        """Store the collected per-frame values in the input metadata.

        The values are stored under ``['stats', 'min', pattern]`` and
        ``['stats', 'max', pattern]``.
        """
        in_datasets, out_datasets = self.get_datasets()
        the_min = np.squeeze(out_datasets[0].data[...])
        the_max = np.squeeze(out_datasets[1].data[...])
        pattern = self._get_pattern()
        in_datasets[0].meta_data.set(['stats', 'min', pattern], the_min)
        in_datasets[0].meta_data.set(['stats', 'max', pattern], the_max)

    def setup(self):
        """Set up single-frame input access and create the two outputs.

        :raises ValueError: if the input cannot be accessed with the
            configured ``pattern``.
        """
        in_datasets, out_datasets = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        pattern = self._get_pattern()
        try:
            in_pData[0].plugin_data_setup(pattern, 'single')
        except Exception:
            # Per the message below, the usual cause is a pattern that the
            # previous plugin does not provide. Catch Exception, not
            # everything, so Ctrl-C / SystemExit still propagate.
            msg = "\n***************************************************" \
                  "**********\nCan't find the data pattern: {}.\nThe pattern " \
                  "parameter of this plugin must be relevant to its \n" \
                  "previous plugin\n****************************************" \
                  "*********************\n".format(pattern)
            logging.warning(msg)
            cu.user_message(msg)
            raise ValueError(msg)

        # One row per frame: (product of the sliced dimension sizes, 1).
        slice_dirs = list(in_datasets[0].get_slice_dimensions())
        orig_shape = in_datasets[0].get_shape()
        new_shape = (np.prod(np.array(orig_shape)[slice_dirs]), 1)

        labels = ['x.pixels', 'y.pixels']
        for out_data, out_pdata in zip(out_datasets, out_pData):
            out_data.create_dataset(shape=new_shape, axis_labels=labels,
                                    remove=True, transport='hdf5')
            out_data.add_pattern(
                "METADATA", core_dims=(1,), slice_dims=(0,))
            out_pdata.plugin_data_setup('METADATA', 'single')

    def _get_pattern(self):
        """Return the data-access pattern name from the plugin parameters."""
        return self.parameters['pattern']

    def nOutput_datasets(self):
        """Two outputs: the per-frame low values and the high values."""
        return 2