Centering360.process_frames()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 37
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 13
nop 2
dl 0
loc 37
rs 9.75
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""
15
.. module:: centering360
16
   :platform: Unix
17
   :synopsis: A plugin to find the center of rotation per frame
18
.. moduleauthor:: Nghia Vo, <[email protected]>
19
"""
20
from savu.plugins.driver.cpu_plugin import CpuPlugin
21
from savu.plugins.utils import register_plugin
22
from savu.plugins.filters.base_filter import BaseFilter
23
import savu.core.utils as cu
24
25
import logging
26
import numpy as np
27
import scipy.ndimage as ndi
28
from scipy import stats
29
30
@register_plugin
class Centering360(BaseFilter, CpuPlugin):
    """
    Find the center-of-rotation (COR) of each sinogram of a 360-degree scan
    with an offset COR using the overlap-search method of Ref. [1], then
    broadcast the per-sinogram values to every slice of the original data.

    References
    ----------
    .. [1] https://doi.org/10.1364/OE.418448
    """

    def __init__(self):
        super(Centering360, self).__init__("Centering360")

    def find_overlap(self, mat1, mat2, win_width, side=None, denoise=True,
                     norm=False, use_overlap=False):
        """
        Find the overlap area and overlap side between two images (Ref. [1]),
        where the overlap side refers to the first image.

        Parameters
        ----------
        mat1 : array_like
            2D array. Projection image or sinogram image.
        mat2 : array_like
            2D array. Projection image or sinogram image.
        win_width : int
            Width of the searching window. It is clipped to a valid range
            for the image widths.
        side : {None, 0, 1}, optional
            Only three options: None, 0, or 1. "None" corresponds to fully
            automated determination. "0" corresponds to the left side. "1"
            corresponds to the right side. Any value other than 0 or 1 is
            handled like None.
        denoise : bool, optional
            Apply the Gaussian filter if True.
        norm : bool, optional
            Apply the normalization if True.
        use_overlap : bool, optional
            Use the combination of images in the overlap area for calculating
            correlation coefficients if True.

        Returns
        -------
        overlap : float
            Width of the overlap area between two images.
        side : int
            Overlap side between two images.
        overlap_position : float
            Position of the window in the first image giving the best
            correlation metric.

        References
        ----------
        .. [1] https://doi.org/10.1364/OE.418448
        """
        (_, ncol1) = mat1.shape
        (_, ncol2) = mat2.shape
        win_width = np.int16(np.clip(win_width, 6, min(ncol1, ncol2) // 2))
        if side == 1 or side == 0:
            (list_metric, offset) = self.search_overlap(
                mat1, mat2, win_width, side, denoise=denoise, norm=norm,
                use_overlap=use_overlap)
            (_, overlap_position) = self.calculate_curvature(list_metric)
            overlap_position = overlap_position + offset
        else:
            # Automatic mode: search both sides and keep the one whose metric
            # minimum is sharper (larger fitted curvature).
            (list_metric1, offset1) = self.search_overlap(
                mat1, mat2, win_width, 1, denoise=denoise, norm=norm,
                use_overlap=use_overlap)
            (list_metric2, offset2) = self.search_overlap(
                mat1, mat2, win_width, 0, denoise=denoise, norm=norm,
                use_overlap=use_overlap)
            (curvature1, position1) = self.calculate_curvature(list_metric1)
            (curvature2, position2) = self.calculate_curvature(list_metric2)
            if curvature1 > curvature2:
                side, offset = 1, offset1
                overlap_position = position1 + offset1
            else:
                side, offset = 0, offset2
                overlap_position = position2 + offset2
        # Use the half-window that search_overlap actually used (it clips the
        # window more tightly and makes it even), so the overlap width stays
        # consistent with overlap_position even when clipping applies.
        if side == 1:
            overlap = ncol1 - overlap_position + offset
        else:
            overlap = overlap_position + offset
        return overlap, side, overlap_position

    def search_overlap(self, mat1, mat2, win_width, side, denoise=True,
                       norm=False, use_overlap=False):
        """
        Calculate the correlation metrics between a rectangular region, defined
        by the window width, on the utmost left/right side of image 2 and the
        same size region in image 1, where the region is slid across image 1.

        Parameters
        ----------
        mat1 : array_like
            2D array. Projection image or sinogram image.
        mat2 : array_like
            2D array. Projection image or sinogram image.
        win_width : int
            Width of the searching window. It is clipped to
            [6, min(ncol1, ncol2) // 2 - 1] and rounded down to an even value.
        side : {0, 1}
            Only two options: 0 or 1. It is used to indicate the overlap side
            with respect to image 1. "0" corresponds to the left side. "1"
            corresponds to the right side.
        denoise : bool, optional
            Apply the Gaussian filter if True.
        norm : bool, optional
            Apply the normalization if True. Rows of the image-1 window whose
            mean absolute value is zero are left unscaled.
        use_overlap : bool, optional
            Use the combination of images in the overlap area for calculating
            correlation coefficients if True.

        Returns
        -------
        list_metric : array_like
            1D array. List of the correlation metrics, divided by their
            minimum when that minimum is non-zero.
        offset : int
            Initial position of the searching window where the position
            corresponds to the center of the window.

        Raises
        ------
        ValueError
            If the two images do not have the same number of rows.
        """
        if denoise:
            mat1 = ndi.gaussian_filter(mat1, (2, 2), mode='reflect')
            mat2 = ndi.gaussian_filter(mat2, (2, 2), mode='reflect')
        (nrow1, ncol1) = mat1.shape
        (nrow2, ncol2) = mat2.shape
        if nrow1 != nrow2:
            raise ValueError("Two images are not at the same height!!!")
        win_width = np.int16(np.clip(win_width, 6, min(ncol1, ncol2) // 2 - 1))
        offset = win_width // 2
        win_width = 2 * offset  # Make it even
        ramp_down = np.linspace(1.0, 0.0, win_width)
        ramp_up = 1.0 - ramp_down
        wei_down = np.tile(ramp_down, (nrow1, 1))
        wei_up = np.tile(ramp_up, (nrow1, 1))
        if side == 1:
            mat2_roi = mat2[:, 0:win_width]
            mat2_roi_wei = mat2_roi * wei_up
            wei1 = wei_down  # weight applied to the image-1 window
        else:
            mat2_roi = mat2[:, ncol2 - win_width:]
            mat2_roi_wei = mat2_roi * wei_down
            wei1 = wei_up
        list_mean2 = np.mean(np.abs(mat2_roi), axis=1)
        list_pos = np.arange(offset, ncol1 - offset)
        list_metric = np.ones(len(list_pos), dtype=np.float32)
        for i, pos in enumerate(list_pos):
            mat1_roi = mat1[:, pos - offset:pos + offset]
            if norm:
                list_mean1 = np.mean(np.abs(mat1_roi), axis=1)
                # Avoid inf/nan factors from all-zero rows; such rows would
                # otherwise turn every metric into nan after normalization.
                with np.errstate(divide='ignore', invalid='ignore'):
                    list_fact = np.where(list_mean1 != 0.0,
                                         list_mean2 / list_mean1, 1.0)
                # One scale factor per row, broadcast across the window.
                mat1_roi = mat1_roi * list_fact[:, np.newaxis]
            if use_overlap:
                mat_comb = mat1_roi * wei1 + mat2_roi_wei
                list_metric[i] = (self.correlation_metric(mat1_roi, mat2_roi)
                                  + self.correlation_metric(mat1_roi, mat_comb)
                                  + self.correlation_metric(mat2_roi, mat_comb)
                                  ) / 3.0
            else:
                list_metric[i] = self.correlation_metric(mat1_roi, mat2_roi)
        min_metric = np.min(list_metric)
        if min_metric != 0.0:
            list_metric = list_metric / min_metric
        return list_metric, offset

    def correlation_metric(self, mat1, mat2):
        """
        Calculate the correlation metric. Smaller metric corresponds to better
        correlation.

        Parameters
        ----------
        mat1 : array_like
        mat2 : array_like

        Returns
        -------
        float
            Correlation metric: |1 - Pearson coefficient| of the two arrays
            flattened in column-major order.
        """
        metric = np.abs(
            1.0 - stats.pearsonr(mat1.flatten('F'), mat2.flatten('F'))[0])
        return metric

    def calculate_curvature(self, list_metric):
        """
        Calculate the curvature of a fitted curve going through the minimum
        value of a metric list.

        Parameters
        ----------
        list_metric : array_like
            1D array. List of metrics. Expected to hold at least 5 values
            (a 5-point window around the minimum is fitted).

        Returns
        -------
        curvature : float
            Absolute quadratic coefficient of the 5-point parabola fitting.
        min_pos : float
            Position of the minimum value with sub-pixel accuracy.
        """
        radi = 2
        num_metric = len(list_metric)
        min_pos = np.clip(
            np.argmin(list_metric), radi, num_metric - radi - 1)
        list1 = list_metric[min_pos - radi:min_pos + radi + 1]
        (afact1, _, _) = np.polyfit(np.arange(0, 2 * radi + 1), list1, 2)
        list2 = list_metric[min_pos - 1:min_pos + 2]
        (afact2, bfact2, _) = np.polyfit(
            np.arange(min_pos - 1, min_pos + 2), list2, 2)
        curvature = np.abs(afact1)
        if afact2 != 0.0:
            # Vertex of the 3-point parabola; accepted only if it stays
            # within the fitted neighbourhood.
            num = - bfact2 / (2 * afact2)
            if (num >= min_pos - 1) and (num <= min_pos + 1):
                min_pos = num
        return curvature, np.float32(min_pos)

    def _downsample(self, image, dsp_fact0, dsp_fact1):
        """Downsample an image by averaging.

        Parameters
        ----------
        image : array_like
            2D array.
        dsp_fact0 : int
            Downsampling factor along axis 0.
        dsp_fact1 : int
            Downsampling factor along axis 1.

        Returns
        -------
        array_like
            Downsampled image. Trailing rows/columns that do not fill a whole
            block are dropped.
        """
        (height, width) = image.shape
        # The max(..., 1) keeps the factor >= 1 for images that are a single
        # pixel tall or wide (the upper bound would otherwise be 0).
        dsp_fact0 = np.clip(np.int16(dsp_fact0), 1, max(height // 2, 1))
        dsp_fact1 = np.clip(np.int16(dsp_fact1), 1, max(width // 2, 1))
        height_dsp = height // dsp_fact0
        width_dsp = width // dsp_fact1
        if dsp_fact0 == 1 and dsp_fact1 == 1:
            image_dsp = image
        else:
            image_dsp = image[0:dsp_fact0 * height_dsp, 0:dsp_fact1 * width_dsp]
            image_dsp = image_dsp.reshape(
                height_dsp, dsp_fact0, width_dsp, dsp_fact1).mean(-1).mean(1)
        return image_dsp

    def pre_process(self):
        """
        Read the plugin parameters and validate the broadcast method.

        The method falls back to 'median' with a warning if the broadcast
        method is unknown. It records the original and plugin preview indices
        along dimension 1 (``origin_prev`` and ``plugin_prev``) for use in
        ``post_process``. It warns when more than 20 sinograms are selected.
        """
        self.win_width = np.int16(self.parameters['win_width'])
        self.side = self.parameters['side']
        self.denoise = self.parameters['denoise']
        self.norm = self.parameters['norm']
        self.use_overlap = self.parameters['use_overlap']

        self.broadcast_method = str(self.parameters['broadcast_method'])
        self.error_msg_1 = ""
        self.error_msg_2 = ""
        self.error_msg_3 = ""
        if self.broadcast_method not in ('mean', 'median', 'linear_fit',
                                         'nearest'):
            self.error_msg_3 = "!!!WARNING!!! Selected broadcasting method" \
                               " is out of the list. Use the default option:" \
                               " 'median'"
            logging.warning(self.error_msg_3)
            cu.user_message(self.error_msg_3)
            self.broadcast_method = 'median'
        data = self.get_in_datasets()[0]
        starts, stops, steps = data.get_preview().get_starts_stops_steps()[0:3]
        start_ind = starts[1]
        stop_ind = stops[1]
        step_ind = steps[1]
        name = data.get_name()
        pre_start = self.exp.meta_data.get(name + '_preview_starts')[1]
        pre_stop = self.exp.meta_data.get(name + '_preview_stops')[1]
        pre_step = self.exp.meta_data.get(name + '_preview_steps')[1]
        self.origin_prev = np.arange(pre_start, pre_stop, pre_step)
        self.plugin_prev = self.origin_prev[start_ind:stop_ind:step_ind]
        num_sino = len(self.plugin_prev)
        if num_sino > 20:
            warning_msg = "\n!!!WARNING!!! You selected to calculate the " \
                          "center-of-rotation using '{}' sinograms.\n" \
                          "This is computationally expensive. Considering to " \
                          "adjust the preview parameter to use\na smaller " \
                          "number of sinograms (< 20).\n".format(num_sino)
            logging.warning(warning_msg)
            cu.user_message(warning_msg)

    def process_frames(self, data):
        """
        Find the center-of-rotation (COR) in a 360-degree scan with offset COR
        using the method presented in Ref. [1].

        Parameters
        ----------
        data : list
            ``data[0]`` is the 2D 360-degree sinogram.

        Returns
        -------
        list of ndarray
            Two one-element arrays, each holding the center-of-rotation, one
            per output dataset.

        References
        ----------
        .. [1] https://doi.org/10.1364/OE.418448
        """
        sino = data[0]
        (nrow, ncol) = sino.shape
        # Compare the first nrow // 2 + 1 rows with the last nrow // 2 + 1
        # rows flipped left-right.
        nrow_180 = nrow // 2 + 1
        sino_top = sino[0:nrow_180, :]
        sino_bot = np.fliplr(sino[-nrow_180:, :])
        # overlap: width of the overlap area between the two halves;
        # side: overlap side between the two halves.
        overlap, side, _ = self.find_overlap(
            sino_top, sino_bot, self.win_width, side=self.side,
            denoise=self.denoise, norm=self.norm,
            use_overlap=self.use_overlap)
        if side == 0:
            cor = overlap / 2.0 - 1.0
        else:
            cor = ncol - overlap / 2.0 - 1.0
        return [np.array([cor]), np.array([cor])]

    def post_process(self):
        """
        Broadcast the computed COR values to all slices of the original data.

        'median' (the default) and 'mean' assign one value to every slice.
        'linear_fit' and 'nearest' require more than one computed COR;
        otherwise the median is kept. The results are stored under the
        'cor_preview' and 'centre_of_rotation' metadata keys.
        """
        _, out_datasets = self.get_datasets()
        cor_prev = out_datasets[0].data[...]
        cor_broad = out_datasets[1].data[...]
        method = self.broadcast_method
        several = len(cor_prev) > 1
        # Median is the default and the fallback for the other methods.
        cor_broad[:] = np.median(np.squeeze(cor_prev))
        self.cor_for_executive_summary = np.median(cor_broad[:])
        if method == 'mean':
            cor_broad[:] = np.mean(np.squeeze(cor_prev))
            self.cor_for_executive_summary = np.mean(cor_broad[:])
        elif method == 'linear_fit' and several:
            # Fit COR against the processed indices and evaluate the line at
            # every original preview index.
            afact, bfact = np.polyfit(self.plugin_prev, cor_prev[:, 0], 1)
            cor_broad[:, 0] = self.origin_prev * afact + bfact
            self.cor_for_executive_summary = cor_broad[:]
        elif method == 'nearest' and several:
            # Each original index takes the COR of the closest processed one.
            for i, pos in enumerate(self.origin_prev):
                minpos = np.argmin(np.abs(pos - self.plugin_prev))
                cor_broad[i, 0] = cor_prev[minpos, 0]
            self.cor_for_executive_summary = cor_broad[:]
        out_datasets[1].data[:] = cor_broad[:]
        self.populate_meta_data('cor_preview', np.squeeze(cor_prev))
        self.populate_meta_data('centre_of_rotation',
                                out_datasets[1].data[:].squeeze(axis=1))

    def populate_meta_data(self, key, value):
        """Set ``key`` to ``value`` in the input metadata and in every
        dataset listed in the 'datasets_to_populate' parameter."""
        datasets = self.parameters['datasets_to_populate']
        in_meta_data = self.get_in_meta_data()[0]
        in_meta_data.set(key, value)
        for name in datasets:
            self.exp.index['in_data'][name].meta_data.set(key, value)

    def setup(self):
        """
        Configure the input as SINOGRAM frames and apply the 'preview'
        parameter.

        Two (N, 1) METADATA outputs are created. Output 0 has one row per
        processed slice. Output 1 has one row per slice of the original
        (un-previewed) data.
        """
        self.exp.log(self.name + " Start calculating center of rotation")
        # set up the output dataset that is created by the plugin
        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup('SINOGRAM', self.get_max_frames())
        slice_dirs = list(in_dataset[0].get_slice_dimensions())
        self.orig_full_shape = in_dataset[0].get_shape()

        # reduce the data as per data_subset parameter
        self.set_preview(in_dataset[0], self.parameters['preview'])
        total_frames = \
            self._calc_total_frames(in_dataset[0].get_preview(), slice_dirs)

        # copy all required information from in_dataset[0]
        full_data = in_dataset[0]
        new_shape = (np.prod(np.array(full_data.get_shape())[slice_dirs]), 1)
        self.orig_shape = \
            (np.prod(np.array(self.orig_full_shape)[slice_dirs]), 1)
        out_dataset[0].create_dataset(shape=new_shape,
                                      axis_labels=['x.pixels', 'y.pixels'],
                                      remove=True,
                                      transport='hdf5')
        out_dataset[0].add_pattern("METADATA", core_dims=(1,), slice_dims=(0,))

        out_dataset[1].create_dataset(shape=self.orig_shape,
                                      axis_labels=['x.pixels', 'y.pixels'],
                                      remove=True,
                                      transport='hdf5')
        out_dataset[1].add_pattern("METADATA", core_dims=(1,), slice_dims=(0,))
        out_pData[0].plugin_data_setup('METADATA', self.get_max_frames())
        out_pData[1].plugin_data_setup('METADATA', self.get_max_frames())
        out_pData[1].meta_data.set('fix_total_frames', total_frames)
        self.exp.log(self.name + " End")

    def _calc_total_frames(self, preview, slice_dims):
        """Return the number of previewed frames over the slice dimensions."""
        starts, stops, steps, _ = preview.get_starts_stops_steps()
        lengths = [len(np.arange(start, stop, step))
                   for start, stop, step in zip(starts, stops, steps)]
        return np.prod([lengths[i] for i in slice_dims])

    def nOutput_datasets(self):
        return 2

    def get_max_frames(self):
        return 'single'

    def fix_transport(self):
        return 'hdf5'

    def executive_summary(self):
        """Return a one-item list with the COR summary message; if an error
        message is set, also send the estimate to the user."""
        cor_text = str(self.cor_for_executive_summary)
        if self.error_msg_1 == "" and self.error_msg_2 == "":
            msg = "Centre of rotation is : %s" % cor_text
        else:
            msg = "\n" + self.error_msg_1 + "\n" + self.error_msg_2
            cu.user_message(
                "(Not well) estimated centre of rotation is : %s" % cor_text)
        return [msg]
446