# Copyright 2019 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
14
|
|
|
|
|
15
|
|
|
""" |
|
16
|
|
|
.. module:: otsu_thresh
   :platform: Unix
   :synopsis: Segmentation by thresholding based on Otsu's method.
       Optionally calculates cropping values (a preview region) from the
       segmented image.

.. moduleauthor:: Jacob Williamson <[email protected]>
|
22
|
|
|
""" |
|
23
|
|
|
|
|
24
|
|
|
|
|
25
|
|
|
from savu.plugins.plugin import Plugin |
|
26
|
|
|
from savu.plugins.driver.cpu_plugin import CpuPlugin |
|
27
|
|
|
from savu.plugins.utils import register_plugin |
|
28
|
|
|
|
|
29
|
|
|
import itertools |
|
30
|
|
|
import numpy as np |
|
31
|
|
|
from skimage.filters import threshold_otsu |
|
32
|
|
|
from PIL import Image |
|
33
|
|
|
import h5py as h5 |
|
34
|
|
|
import os |
|
35
|
|
|
|
|
36
|
|
|
@register_plugin
class OtsuThresh(Plugin, CpuPlugin):
    """Binarise frames with Otsu's method and optionally derive a crop box.

    Each frame is thresholded with ``threshold_otsu``: pixels above the
    threshold become 1 and all others 0. When the ``cropping`` parameter is
    set, each binary frame is also trimmed by ``_crop``. The box enclosing
    the trimmed region of every frame is accumulated in ``self.volume_crop``.
    In a pre-run no data is output. Instead the box is converted into a
    preview string, stored in the experiment metadata as
    ``pre_run_preview`` and written to ``<datafile_name>_pre_run.nxs``.
    """

    def __init__(self):
        super(OtsuThresh, self).__init__("OtsuThresh")

    def setup(self):
        """Set up the datasets and initialise the crop bookkeeping."""
        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup(self.parameters['pattern'], 'single')

        # nOutput_datasets() returns 0 during a pre-run, so the binary output
        # dataset is only created in a normal run.
        for out_data, out_plugin_data in zip(out_dataset, out_pData):
            out_data.create_dataset(in_dataset[0], dtype=np.uint8)
            out_plugin_data.plugin_data_setup(self.parameters['pattern'], 'single')

        self.cropping = self.parameters["cropping"]
        self.buffer = self.parameters["buffer"]
        self.directions = self.parameters["directions"]

        # Initial edges span the whole frame. right/below hold the last
        # column/row index. This assumes a frame is dimensions 1 and 2 of
        # the input -- TODO confirm against the 'pattern' parameter.
        self.shape = in_dataset[0].data_info.get("shape")
        self.fully_right, self.fully_below = self.shape[2] - 1, self.shape[1] - 1
        self.volume_crop = {"left": 0, "above": 0,
                            "right": self.fully_right, "below": self.fully_below}
        self.orig_edges = dict(self.volume_crop)
        # Minimum run of all-ones rows/columns treated as croppable space.
        self.gap_size = 20

    def pre_process(self):
        """Start each requested edge of the volume crop at its opposite extreme.

        Each frame then only widens the box (``min`` for left/above, ``max``
        for right/below). The final box therefore encloses every frame's crop.
        """
        if "left" in self.directions:
            self.volume_crop["left"] = self.fully_right
        if "above" in self.directions:
            self.volume_crop["above"] = self.fully_below
        if "right" in self.directions:
            self.volume_crop["right"] = 0
        if "below" in self.directions:
            self.volume_crop["below"] = 0

    def process_frames(self, data):
        """Threshold one frame and optionally update the running crop box.

        :param data: list whose first element is the frame to threshold.
        :returns: the binary frame (uint8 0/1), or None during a pre-run.
        """
        frame = data[0]
        threshold = threshold_otsu(frame)
        # uint8 matches the dtype the output dataset is created with in setup.
        thresh_result = (frame > threshold).astype(np.uint8)
        if self.cropping:
            # Only the side effect on self.volume_crop is needed here.
            self._crop(thresh_result, ["left", "above", "right", "below"])
        if self.exp.meta_data.get("pre_run"):
            return None
        return thresh_result

    def post_process(self):
        """Finalise the crop and, during a pre-run, save it as a preview."""
        if not self.cropping:
            return
        preview = self._cropping_post_process()
        if self.exp.meta_data.get("pre_run"):
            self._write_preview_to_file(preview)

    def nInput_datasets(self):
        return 1

    def nOutput_datasets(self):
        # A pre-run only computes the preview; it writes no data.
        if self.exp.meta_data.get("pre_run"):
            return 0
        else:
            return 1

    @staticmethod
    def _find_runs(non_zero_counts, length, buffer):
        """Split consecutive row/column indices into "fill" and "gap" runs.

        An index is part of a fill when its row/column holds more than
        ``buffer`` zeros (``count + buffer < length``). Otherwise it is part
        of a gap (all, or nearly all, ones).

        :param non_zero_counts: number of non-zero values per row/column.
        :param length: number of values in each row/column.
        :param buffer: zeros tolerated in a row/column still counted as gap.
        :returns: ``(fills, gaps)``, each a list of runs, where a run is a
            list of ascending indices.
        """
        fills, gaps = [], []
        previous = None
        for index, count in enumerate(non_zero_counts):
            kind = "fill" if count + buffer < length else "gap"
            runs = fills if kind == "fill" else gaps
            if kind != previous:
                runs.append([])
            runs[-1].append(index)
            previous = kind
        return fills, gaps

    def _crop(self, binary_slice, directions, buffer=0):
        """Trim empty margins from one binary frame and widen the volume crop.

        The edges in ``directions`` are visited cyclically. Rows (above/below)
        or columns (left/right) are split into fill and gap runs. The longest
        fill is assumed to be the sample. On the edge being visited, the gap
        longer than ``self.gap_size`` that lies closest to the sample becomes
        the new edge. After the loop, ``self.volume_crop`` is widened for the
        edges in ``self.directions`` so that it encloses this frame's crop.

        :param binary_slice: 2D array of 0/1 values.
        :param directions: edge names ("left", "above", "right", "below").
        :param buffer: zeros tolerated in a row/column still counted as gap.
        :returns: the trimmed frame.
        """
        axis_for = {"left": 0, "right": 0, "above": 1, "below": 1}
        # Running crop in full-frame coordinates: left/above only grow,
        # right/below only shrink.
        total_crop = {"left": 0, "above": 0,
                      "right": self.fully_right, "below": self.fully_below}
        # Reset whenever a visit trims the frame. The loop ends when it
        # reaches 4, i.e. after the remaining edges are revisited without
        # further trimming (or after four visits if nothing is ever trimmed).
        reset_counter = 0
        dir_cycle = itertools.cycle(directions)
        while reset_counter < 4:
            direction = next(dir_cycle)
            crops = {"left": 0, "above": 0,
                     "right": binary_slice.shape[1], "below": binary_slice.shape[0]}
            axis = axis_for[direction]
            fills, gaps = self._find_runs(
                np.count_nonzero(binary_slice, axis=axis),
                binary_slice.shape[axis], buffer)

            # Without any run containing zeros there is no sample to anchor
            # a crop to, so this edge is left untouched.
            if fills:
                largest_fill = max(fills, key=len)
                if direction in ("left", "above"):
                    # The last qualifying gap is the one closest to the sample.
                    for gap in gaps:
                        if gap[-1] < largest_fill[0] and len(gap) > self.gap_size:
                            crops[direction] = gap[-1]
                            reset_counter = 0
                    total_crop[direction] += crops[direction]
                else:
                    for gap in reversed(gaps):
                        if gap[0] > largest_fill[-1] and len(gap) > self.gap_size:
                            crops[direction] = gap[0] + 1
                            reset_counter = 0
                    # Subtracts the number of rows/columns removed on this
                    # side (0 when no gap qualified).
                    total_crop[direction] -= binary_slice.shape[1 - axis] - crops[direction]

            reset_counter += 1
            binary_slice = binary_slice[crops["above"]:crops["below"],
                                        crops["left"]:crops["right"]]

        for direction in self.directions:
            if direction in ("left", "above"):
                self.volume_crop[direction] = min(self.volume_crop[direction],
                                                  total_crop[direction])
            elif direction in ("right", "below"):
                self.volume_crop[direction] = max(self.volume_crop[direction],
                                                  total_crop[direction])
        return binary_slice

    def _merge_volume_crop_across_processes(self):
        """Give every process the box that encloses all processes' crops.

        Presumably frames are shared between processes, so each process has
        only seen part of the volume, while only rank 0 writes the preview.
        """
        # NOTE(review): allgather is collective. This assumes post_process
        # runs on every rank of the plugin communicator, as the rank-0 check
        # in _write_preview_to_file suggests -- confirm.
        gathered = self.get_communicator().allgather(dict(self.volume_crop))
        for direction in self.directions:
            values = [crop[direction] for crop in gathered]
            if direction in ("left", "above"):
                self.volume_crop[direction] = min(values)
            elif direction in ("right", "below"):
                self.volume_crop[direction] = max(values)

    def _cropping_post_process(self):
        """Finalise ``self.volume_crop`` and publish it as a preview string.

        The per-process boxes are first merged. Each requested edge is then
        pushed outwards by ``self.buffer`` and clamped to the original frame
        edges. The preview keeps all of dimension 0 and takes ``above:below``
        of dimension 1 and ``left:right`` of dimension 2. It is stored in the
        experiment metadata as ``pre_run_preview``.

        :returns: the preview string.
        """
        self._merge_volume_crop_across_processes()
        for direction in self.directions:
            if direction in ("left", "above"):
                padded = int(self.volume_crop[direction] - self.buffer)
                self.volume_crop[direction] = max(padded, self.orig_edges[direction])
            if direction in ("right", "below"):
                padded = int(self.volume_crop[direction] + self.buffer)
                self.volume_crop[direction] = min(padded, self.orig_edges[direction])
        preview = f":, {self.volume_crop['above']}:{self.volume_crop['below']}, {self.volume_crop['left']}:{self.volume_crop['right']}"
        self.exp.meta_data.set("pre_run_preview", preview)

        return preview

    def _write_preview_to_file(self, preview):
        """Store *preview* as a "preview" dataset in the pre-run NeXus file.

        The file is ``<out_path>/<datafile_name>_pre_run.nxs``. The dataset is
        written in the parent group of ``data_path``, and only rank 0 writes.
        """
        if not self.exp.meta_data.get("pre_run"):
            return
        folder = self.exp.meta_data['out_path']
        fname = self.exp.meta_data.get('datafile_name') + '_pre_run.nxs'
        filename = os.path.join(folder, fname)
        comm = self.get_communicator()
        if comm.rank != 0:
            return
        with h5.File(filename, "a") as h5file:
            # Drop the last path component, keeping the trailing "/".
            fsplit = self.exp.meta_data["data_path"].split("/")
            fsplit[-1] = ""
            stats_path = "/".join(fsplit)
            preview_group = h5file.require_group(stats_path)
            # "a" mode reuses an existing file. Replace a stale preview
            # rather than letting create_dataset fail on the existing name.
            if "preview" in preview_group:
                del preview_group["preview"]
            preview_group.create_dataset("preview", data=preview)

    def __save_image(self, binary_slice, name):
        """Debug helper: save *binary_slice* as a JPEG (just used for testing)."""
        # Scale 0/1 to 0/150 so the mask is visible.
        binary_slice = binary_slice.astype(np.uint8) * 150
        im = Image.fromarray(binary_slice)

        im.save(f"/scratch/Savu/images/{name}.jpeg")