OtsuThresh._crop()   F
last analyzed

Complexity

Conditions 23

Size

Total Lines 70
Code Lines 55

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 23
eloc 55
nop 4
dl 0
loc 70
rs 0
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

- Extract Method: move a cohesive part of the body into a new, well-named helper method.

Complexity

Complex code units such as the method savu.plugins.segmentation.thresholding.otsu_thresh.OtsuThresh._crop() often do a lot of different things. To break such a unit down, identify a cohesive component within it. A common approach is to look for fields, methods or local variables that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright 2019 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: otsu_thresh
17
   :platform: Unix
18
   :synopsis: Segmentation by thresholding based on Otsu's method.
19
              Optionally calculate cropping values based on the segmented image
20
21
.. moduleauthor:: Jacob Williamson <[email protected]>
22
"""
23
24
25
from savu.plugins.plugin import Plugin
26
from savu.plugins.driver.cpu_plugin import CpuPlugin
27
from savu.plugins.utils import register_plugin
28
29
import itertools
30
import numpy as np
31
from skimage.filters import threshold_otsu
32
from PIL import Image
33
import h5py as h5
34
import os
35
36
@register_plugin
class OtsuThresh(Plugin, CpuPlugin):
    """Segment each frame with an Otsu threshold and optionally derive a crop.

    Every frame is binarised as ``frame > threshold_otsu(frame)``. When the
    ``cropping`` parameter is set, each binarised frame is also scanned for
    the sample, and a bounding box enclosing every per-frame crop is turned
    into a preview string. During a pre-run that string is written to the
    pre-run file and no output dataset is produced.
    """

    # Axis passed to np.count_nonzero for each direction: per-column counts
    # (axis 0) for left/right, per-row counts (axis 1) for above/below.
    _DIRECTION_AXIS = {"left": 0, "right": 0, "above": 1, "below": 1}

    def __init__(self):
        super(OtsuThresh, self).__init__("OtsuThresh")

    def setup(self):
        """Configure the datasets and initialise the cropping state."""
        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        pattern = self.parameters['pattern']
        in_pData[0].plugin_data_setup(pattern, 'single')

        # nOutput_datasets() returns 0 during a pre-run, so this may not run.
        for out_data, out_plugin_data in zip(out_dataset, out_pData):
            out_data.create_dataset(in_dataset[0], dtype=np.uint8)
            out_plugin_data.plugin_data_setup(pattern, 'single')

        self.cropping = self.parameters["cropping"]
        self.buffer = self.parameters["buffer"]
        self.directions = self.parameters["directions"]

        # Only shape[1] (rows) and shape[2] (columns) are used, so 3D data is
        # assumed. NOTE(review): confirm the frame axis is always axis 0.
        self.shape = in_dataset[0].data_info.get("shape")
        self.fully_right, self.fully_below = self.shape[2] - 1, self.shape[1] - 1
        # Edges of the full frame, as inclusive indices.
        self.orig_edges = {"left": 0, "above": 0,
                           "right": self.fully_right, "below": self.fully_below}
        self.volume_crop = dict(self.orig_edges)
        # Minimum number of consecutive gap lines needed to crop at that gap.
        self.gap_size = 20

    def pre_process(self):
        """Start each configured edge at the opposite side of the frame.

        Per-frame crops can then only widen ``self.volume_crop`` (see
        ``_merge_into_volume_crop``).
        """
        opposite_side = {"left": self.fully_right, "above": self.fully_below,
                         "right": 0, "below": 0}
        for direction, start in opposite_side.items():
            if direction in self.directions:
                self.volume_crop[direction] = start

    def process_frames(self, data):
        """Binarise one frame and, if cropping is enabled, update the crop.

        :returns: the 0/1 thresholded frame, or None during a pre-run.
        """
        frame = data[0]
        thresh_result = (frame > threshold_otsu(frame)) * 1
        if self.cropping:
            self._crop(thresh_result, ["left", "above", "right", "below"])
        if self.exp.meta_data.get("pre_run"):
            return None
        return thresh_result

    def post_process(self):
        """Finalise the crop and, during a pre-run, store it in the pre-run file."""
        if self.cropping:
            preview = self._cropping_post_process()
            if self.exp.meta_data.get("pre_run"):
                self._write_preview_to_file(preview)

    def nInput_datasets(self):
        return 1

    def nOutput_datasets(self):
        # A pre-run only computes the crop; it produces no output data.
        return 0 if self.exp.meta_data.get("pre_run") else 1

    def _crop(self, binary_slice, directions, buffer=0):
        """Crop one binarised 2D frame towards the sample.

        Directions are visited cyclically. For each direction, the rows or
        columns of the current (already cropped) frame are split into runs of
        "fill" lines and "gap" lines. A fill line has more than ``buffer``
        zeros; the code treats zeros as sample. The largest fill run is taken
        as the sample. The frame is trimmed back to the nearest gap longer
        than ``self.gap_size`` on the requested side, keeping one gap line.
        The loop stops after ``len(directions)`` consecutive passes without a
        new crop, where the pass that made the last crop counts as the first.

        The resulting edges, as indices of the uncropped frame, are merged
        into ``self.volume_crop`` for each direction in ``self.directions``.

        :param binary_slice: 2D array of 0/1 values.
        :param directions: sequence of "left", "above", "right", "below".
        :param buffer: number of zeros a line may contain and still be a gap.
        :returns: the cropped view of ``binary_slice``.
        :raises ValueError: if ``directions`` contains an unknown name.
        """
        # left/above: first kept index; right/below: last kept index.
        total_crop = {"left": 0, "above": 0,
                      "right": self.fully_right, "below": self.fully_below}
        passes_since_crop = 0
        direction_cycle = itertools.cycle(directions)
        while passes_since_crop < len(directions):
            direction = next(direction_cycle)
            axis = self._direction_axis(direction)
            height, width = binary_slice.shape
            crops = {"left": 0, "above": 0, "right": width, "below": height}

            fills, gaps = self._split_runs(
                np.count_nonzero(binary_slice, axis=axis),
                binary_slice.shape[axis], buffer)
            edge = self._closest_gap_edge(direction, fills, gaps)
            if edge is not None:
                crops[direction] = edge
                passes_since_crop = 0

            if direction in ("left", "above"):
                total_crop[direction] += crops[direction]
            else:
                # crops[] holds a slice stop; subtract the lines removed
                # from the far side to keep an inclusive index.
                total_crop[direction] -= (binary_slice.shape[1 - axis]
                                          - crops[direction])

            passes_since_crop += 1
            binary_slice = binary_slice[crops["above"]:crops["below"],
                                        crops["left"]:crops["right"]]

        self._merge_into_volume_crop(total_crop)
        return binary_slice

    def _direction_axis(self, direction):
        """Return the np.count_nonzero axis for ``direction``.

        :raises ValueError: if ``direction`` is not a known name.
        """
        try:
            return self._DIRECTION_AXIS[direction]
        except KeyError:
            raise ValueError(
                f"Unknown crop direction {direction!r}; expected one of "
                f"{sorted(self._DIRECTION_AXIS)}") from None

    @staticmethod
    def _split_runs(non_zero_counts, line_length, buffer=0):
        """Split line indices into consecutive runs of fill and gap lines.

        A line is a fill line when ``count + buffer < line_length``, i.e. it
        has more than ``buffer`` zeros. Otherwise it is a gap line.

        :returns: ``(fills, gaps)``, each a list of ascending index lists.
        """
        fills, gaps = [], []
        runs = itertools.groupby(
            enumerate(non_zero_counts),
            key=lambda item: item[1] + buffer < line_length)
        for is_fill, run in runs:
            (fills if is_fill else gaps).append([index for index, _ in run])
        return fills, gaps

    def _closest_gap_edge(self, direction, fills, gaps):
        """Return the new edge for ``direction``, or None if nothing to crop.

        The largest fill run is taken as the sample; ties keep the first run.
        For "left"/"above" the edge is the last index of the nearest gap
        longer than ``self.gap_size`` before the sample. For "right"/"below"
        it is one past the first index of the nearest such gap after it
        (a slice stop).
        """
        largest_fill = max(fills, key=len, default=None)
        if largest_fill is None:
            # No line contains the sample, so there is nothing to crop towards.
            return None
        wide_gaps = [gap for gap in gaps if len(gap) > self.gap_size]
        if direction in ("left", "above"):
            before = [gap for gap in wide_gaps if gap[-1] < largest_fill[0]]
            return before[-1][-1] if before else None
        after = [gap for gap in wide_gaps if gap[0] > largest_fill[-1]]
        return after[0][0] + 1 if after else None

    def _merge_into_volume_crop(self, frame_crop):
        """Widen ``self.volume_crop`` so it encloses ``frame_crop``.

        Only the directions listed in ``self.directions`` are updated.
        """
        for direction in self.directions:
            if direction in ("left", "above"):
                self.volume_crop[direction] = min(self.volume_crop[direction],
                                                  frame_crop[direction])
            elif direction in ("right", "below"):
                self.volume_crop[direction] = max(self.volume_crop[direction],
                                                  frame_crop[direction])

    def _cropping_post_process(self):
        """Pad the volume crop by ``self.buffer``, clamp it to the frame and publish it.

        :returns: the preview string, also stored as "pre_run_preview" in
            the experiment metadata.
        """
        for direction in self.directions:
            if direction in ("left", "above"):
                self.volume_crop[direction] = max(
                    int(self.volume_crop[direction] - self.buffer),
                    self.orig_edges[direction])
            if direction in ("right", "below"):
                self.volume_crop[direction] = min(
                    int(self.volume_crop[direction] + self.buffer),
                    self.orig_edges[direction])
        preview = f":, {self.volume_crop['above']}:{self.volume_crop['below']}, {self.volume_crop['left']}:{self.volume_crop['right']}"
        self.exp.meta_data.set("pre_run_preview", preview)

        return preview

    def _write_preview_to_file(self, preview):
        """Write ``preview`` into the pre-run NeXus file.

        Only process rank 0 writes. The dataset goes into the parent group of
        the experiment's "data_path". Any "preview" dataset left by an
        earlier run is replaced.
        """
        if not self.exp.meta_data.get("pre_run"):
            return
        folder = self.exp.meta_data['out_path']
        fname = self.exp.meta_data.get('datafile_name') + '_pre_run.nxs'
        filename = os.path.join(folder, fname)
        comm = self.get_communicator()
        if comm.rank != 0:
            return
        data_path = self.exp.meta_data["data_path"]
        # Parent group path, keeping its trailing "/" ("" if there is no "/").
        stats_path = data_path[:data_path.rfind("/") + 1]
        with h5.File(filename, "a") as h5file:
            preview_group = h5file.require_group(stats_path)
            # Append mode: an earlier run may already have written "preview",
            # and create_dataset refuses to overwrite an existing name.
            if "preview" in preview_group:
                del preview_group["preview"]
            preview_group.create_dataset("preview", data=preview)

    def __save_image(self, binary_slice, name):
        """Save a 0/1 frame as a JPEG for manual inspection (testing only)."""
        binary_slice = binary_slice.astype(np.uint8) * 150
        im = Image.fromarray(binary_slice)

        im.save(f"/scratch/Savu/images/{name}.jpeg")