# Copyright 2019 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: otsu_thresh
   :platform: Unix
   :synopsis: Segmentation by thresholding based on Otsu's method.
       Optionally calculate cropping values based on the segmented image.

.. moduleauthor:: Jacob Williamson <[email protected]>

"""


from savu.plugins.plugin import Plugin
from savu.plugins.driver.cpu_plugin import CpuPlugin
from savu.plugins.utils import register_plugin

import itertools
import numpy as np
from skimage.filters import threshold_otsu
from PIL import Image
import h5py as h5
import os


@register_plugin
class OtsuThresh(Plugin, CpuPlugin):
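    """Segment each frame by Otsu thresholding and, optionally, estimate how
    tightly the volume can be cropped around the sample. During a pre-run the
    plugin only records the suggested crop as a preview string instead of
    writing out a segmented dataset.
    """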

    def __init__(self):
        super(OtsuThresh, self).__init__("OtsuThresh")

    def setup(self):
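        """Create the output dataset(s) as uint8 copies of the input pattern
        and read the cropping-related parameters.
        """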

        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup(self.parameters['pattern'], 'single')

        for i in range(len(out_dataset)):
            out_dataset[i].create_dataset(in_dataset[0], dtype=np.uint8)
            out_pData[i].plugin_data_setup(self.parameters['pattern'], 'single')

        self.cropping = self.parameters["cropping"]
        self.buffer = self.parameters["buffer"]
        self.directions = self.parameters["directions"]

        self.shape = in_dataset[0].data_info.get("shape")
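        # Index of the right-most column and bottom row of a frame; frame axes
        # are assumed to be (row, column) within the last two volume dimensions.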
        self.fully_right, self.fully_below = self.shape[2] - 1, self.shape[1] - 1
        self.volume_crop = {"left": 0, "above": 0, "right": self.fully_right, "below": self.fully_below}
        self.orig_edges = {"left": 0, "above": 0, "right": self.fully_right, "below": self.fully_below}
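        # Minimum length a run of sample-free rows/columns must have before it
        # is treated as croppable background; 20 appears to be an empirically
        # chosen value.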
        self.gap_size = 20

    def pre_process(self):
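        """Initialise each requested crop bound at its fully-cropped extreme so
        that the min/max comparisons in _crop can only relax it outwards.
        """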

        if "left" in self.directions:
            self.volume_crop["left"] = self.fully_right
        if "above" in self.directions:
            self.volume_crop["above"] = self.fully_below
        if "right" in self.directions:
            self.volume_crop["right"] = 0
        if "below" in self.directions:
            self.volume_crop["below"] = 0

    def process_frames(self, data):
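        """Threshold a single frame with Otsu's method and, if cropping is
        enabled, update the volume-wide crop estimate from the binary frame.
        Returns the binary frame, or None during a pre-run.
        """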
        threshold = threshold_otsu(data[0])
        thresh_result = (data[0] > threshold) * 1
        if self.cropping:
            cropped_slice = self._crop(thresh_result, ["left", "above", "right", "below"])
            #if self.pcount % (self.shape[0]//10) == 0:
            #    self.__save_image(cropped_slice, f"cropped_slices/slice{self.pcount}-cropped")
        if self.exp.meta_data.get("pre_run"):
            return None
        else:
            return thresh_result

    def post_process(self):
        if self.cropping:
            preview = self._cropping_post_process()
            if self.exp.meta_data.get("pre_run"):
                self._write_preview_to_file(preview)

    def nInput_datasets(self):
        return 1

    def nOutput_datasets(self):
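        # A pre-run only estimates the crop, so no output dataset is created
        # for it.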
        if self.exp.meta_data.get("pre_run"):
            return 0
        else:
            return 1

    def _crop(self, binary_slice, directions, buffer=0):
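        """Tighten a binary frame around the sample, direction by direction.

        For each direction the frame is scanned for runs of lines ("gaps")
        containing no sample (following the in-code convention that zero
        pixels indicate sample), and the qualifying gap closest to the largest
        sample region is cropped away, provided it is longer than gap_size.
        The cycle stops once a full round of directions makes no progress; the
        accumulated offsets are then merged into the volume-wide crop.
        """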
        # For 3 dimensional volume and 2 dimensional slices

        total_crop = {"left": 0, "above": 0, "right": self.fully_right, "below": self.fully_below}
        total_counter = 0
        reset_counter = 0
        dir_cycle = itertools.cycle(directions)
        while reset_counter < 4:
            direction = next(dir_cycle)
            crops = {"left": 0, "above": 0, "right": binary_slice.shape[1], "below": binary_slice.shape[0]}
            if direction in ["above", "below"]:
                axis = 1
            elif direction in ["left", "right"]:
                axis = 0

            non_zeros_list = np.count_nonzero(binary_slice, axis=axis)

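            # Group consecutive line indices into "fills" (lines containing
            # sample, i.e. some zero pixels) and "gaps" (lines with no sample).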
            previous = None
            fills = []
            gaps = []
            for i, non_zeros in enumerate(non_zeros_list):
                if non_zeros + buffer < binary_slice.shape[axis]:  # check if some zeros (indicating sample)
                    if previous != "fill":
                        fills.append([])
                    fills[-1].append(i)
                    previous = "fill"
                else:
                    if previous != "gap":
                        gaps.append([])
                    gaps[-1].append(i)
                    previous = "gap"

            # find largest area of zeros (assume this is where sample is)
            largest_fill = []
            for fill in fills:
                if len(fill) > len(largest_fill):
                    largest_fill = fill

            # find a reasonably sized gap closest to the sample
            if direction in ["left", "above"]:
                for gap in gaps:
                    if gap[-1] < largest_fill[0]:
                        if len(gap) > self.gap_size:
                            crops[direction] = gap[-1]
                            reset_counter = 0
                total_crop[direction] += crops[direction]

            elif direction in ["right", "below"]:
                for gap in gaps[::-1]:
                    if gap[0] > largest_fill[-1]:
                        if len(gap) > self.gap_size:
                            crops[direction] = gap[0] + 1
                            reset_counter = 0
                total_crop[direction] -= (binary_slice.shape[1 - axis] - crops[direction])

            #if self.pcount == 0:
            #    self.__save_image(binary_slice, f"{total_counter}-{reset_counter}-{direction}")

            reset_counter += 1
            total_counter += 1

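            # Apply this iteration's crop to the working frame so offsets from
            # later iterations are measured relative to the already-cropped frame.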
            binary_slice = binary_slice[crops["above"]: crops["below"], crops["left"]: crops["right"]]

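        # Merge this frame's crop into the volume-wide crop, keeping the least
        # aggressive bound seen so far in each direction.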
        for direction in self.directions:
            if direction in ["left", "above"]:
                if total_crop[direction] < self.volume_crop[direction]:
                    self.volume_crop[direction] = total_crop[direction]
            if direction in ["right", "below"]:
                if total_crop[direction] > self.volume_crop[direction]:
                    self.volume_crop[direction] = total_crop[direction]
        return binary_slice

    def _cropping_post_process(self):
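        """Pad the volume-wide crop by the buffer parameter, clamp it to the
        original frame edges and store the result as a preview string in the
        experiment metadata.
        """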
        for direction in self.directions:
            if direction in ["left", "above"]:
                self.volume_crop[direction] = int(self.volume_crop[direction] - self.buffer)
                if self.volume_crop[direction] < self.orig_edges[direction]:
                    self.volume_crop[direction] = self.orig_edges[direction]
            if direction in ["right", "below"]:
                self.volume_crop[direction] = int(self.volume_crop[direction] + self.buffer)
                if self.volume_crop[direction] > self.orig_edges[direction]:
                    self.volume_crop[direction] = self.orig_edges[direction]
        preview = f":, {self.volume_crop['above']}:{self.volume_crop['below']}, {self.volume_crop['left']}:{self.volume_crop['right']}"
        self.exp.meta_data.set("pre_run_preview", preview)

        return preview

    def _write_preview_to_file(self, preview):
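        """Write the preview string into the parent group of the data path in
        the pre-run NeXus file; only rank 0 writes to avoid concurrent access.
        """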
        if self.exp.meta_data.get("pre_run"):
            folder = self.exp.meta_data['out_path']
            fname = self.exp.meta_data.get('datafile_name') + '_pre_run.nxs'
            filename = os.path.join(folder, fname)
            comm = self.get_communicator()
            if comm.rank == 0:
                with h5.File(filename, "a") as h5file:
                    fsplit = self.exp.meta_data["data_path"].split("/")
                    fsplit[-1] = ""
                    stats_path = "/".join(fsplit)
                    preview_group = h5file.require_group(stats_path)
                    preview_group.create_dataset("preview", data=preview)

    def __save_image(self, binary_slice, name):
        # just used for testing
        binary_slice = binary_slice.astype(np.uint8) * 150
        im = Image.fromarray(binary_slice)

        im.save(f"/scratch/Savu/images/{name}.jpeg")