Centering360.find_overlap()   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 73
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 33
nop 8
dl 0
loc 73
rs 9.0879
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

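Commonly applied refactorings include Extract Method; when many parameters or temporary variables are involved, Replace Method with Method Object can also help.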

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also often become inconsistent when you need more or different data.

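There are several approaches to avoid long parameter lists. For example, group parameters that always travel together into a single parameter object (Introduce Parameter Object). Alternatively, pass a whole object instead of several values extracted from it (Preserve Whole Object).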

1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""
15
.. module:: centering360
16
   :platform: Unix
17
   :synopsis: A plugin to find the center of rotation per frame
18
.. moduleauthor:: Nghia Vo, <[email protected]>
19
"""
20
from savu.plugins.driver.cpu_plugin import CpuPlugin
21
from savu.plugins.utils import register_plugin
22
from savu.plugins.filters.base_filter import BaseFilter
23
import savu.core.utils as cu
24
25
import logging
26
import numpy as np
27
import scipy.ndimage as ndi
28
from scipy import stats
29
30
@register_plugin
class Centering360(BaseFilter, CpuPlugin):
    # Estimate the centre of rotation (COR) of a 360-degree scan with an
    # offset COR, one sinogram per frame, by locating the overlap between the
    # first half of the sinogram and the left-right flipped second half
    # (Ref.: https://doi.org/10.1364/OE.418448). Output dataset 0 holds the
    # COR of each processed sinogram; output dataset 1 holds those values
    # broadcast over the original preview range (see post_process).

    def __init__(self):
        super(Centering360, self).__init__("Centering360")

    def find_overlap(self, mat1, mat2, win_width, side=None, denoise=True,
                     norm=False, use_overlap=False):
        """
        Find the overlap area and overlap side between two images (Ref. [1]),
        where the overlap side refers to the first image.

        Parameters
        ----------
        mat1 : array_like
            2D array. Projection image or sinogram image.
        mat2 :  array_like
            2D array. Projection image or sinogram image.
        win_width : int
            Width of the searching window. Clipped to
            [6, min(ncol1, ncol2) // 2].
        side : {None, 0, 1}, optional
            Only three options: None, 0, or 1. "None" corresponds to fully
            automated determination. "0" corresponds to the left side. "1"
            corresponds to the right side.
        denoise : bool, optional
            Apply the Gaussian filter if True.
        norm : bool, optional
            Apply the normalization if True.
        use_overlap : bool, optional
            Use the combination of images in the overlap area for calculating
            correlation coefficients if True.

        Returns
        -------
        overlap : float
            Width of the overlap area between two images.
        side : int
            Overlap side between two images (the given `side`, or the
            automatically selected 0/1 when `side` is None).
        overlap_position : float
            Position of the window in the first image giving the best
            correlation metric.

        References
        ----------
        .. [1] https://doi.org/10.1364/OE.418448
        """
        (_, ncol1) = mat1.shape
        (_, ncol2) = mat2.shape
        # Plain int rather than np.int16: avoids int16 overflow / NumPy 2
        # out-of-bounds errors in the index arithmetic for very wide images.
        win_width = int(np.clip(win_width, 6, min(ncol1, ncol2) // 2))
        # NOTE(review): search_overlap clips the window to
        # min(ncol1, ncol2) // 2 - 1 and makes it even, so when win_width is
        # clipped here at the upper bound, win_width // 2 below can exceed the
        # half-width actually searched by one pixel.
        if side in (0, 1):
            (_, overlap_position) = self._locate_overlap(
                mat1, mat2, win_width, side, denoise, norm, use_overlap)
        else:
            # Automatic mode: try both sides and keep the one whose metric
            # minimum is more sharply curved (i.e. better defined).
            (curvature1, overlap_position1) = self._locate_overlap(
                mat1, mat2, win_width, 1, denoise, norm, use_overlap)
            (curvature2, overlap_position2) = self._locate_overlap(
                mat1, mat2, win_width, 0, denoise, norm, use_overlap)
            if curvature1 > curvature2:
                side = 1
                overlap_position = overlap_position1
            else:
                side = 0
                overlap_position = overlap_position2
        if side == 1:
            overlap = ncol1 - overlap_position + win_width // 2
        else:
            overlap = overlap_position + win_width // 2
        return overlap, side, overlap_position

    def _locate_overlap(self, mat1, mat2, win_width, side, denoise, norm,
                        use_overlap):
        """Return (curvature, overlap_position) of the best match for one side.

        `overlap_position` is the sub-pixel minimum of the correlation metric,
        shifted by the search offset so that it is a column position in
        `mat1`.
        """
        (list_metric, offset) = self.search_overlap(
            mat1, mat2, win_width, side, denoise, norm, use_overlap)
        (curvature, min_pos) = self.calculate_curvature(list_metric)
        return curvature, min_pos + offset

    def search_overlap(self, mat1, mat2, win_width, side, denoise=True,
                       norm=False, use_overlap=False):
        """
        Calculate the correlation metrics between a rectangular region, defined
        by the window width, on the utmost left/right side of image 2 and the
        same size region in image 1, where the region is slid across image 1.

        Parameters
        ----------
        mat1 : array_like
            2D array. Projection image or sinogram image.
        mat2 : array_like
            2D array. Projection image or sinogram image.
        win_width : int
            Width of the searching window. Clipped to
            [6, min(ncol1, ncol2) // 2 - 1] and then rounded down to even.
        side : {0, 1}
            Only two options: 0 or 1. It is used to indicate the overlap side
            with respect to image 1. "0" corresponds to the left side. "1"
            corresponds to the right side.
        denoise : bool, optional
            Apply the Gaussian filter if True.
        norm : bool, optional
            Apply the normalization if True. Rows of the image-1 window whose
            mean absolute value is zero are left unscaled.
        use_overlap : bool, optional
            Use the combination of images in the overlap area for calculating
            correlation coefficients if True.

        Returns
        -------
        list_metric : array_like
            1D array. List of the correlation metrics, divided by their
            minimum when that minimum is non-zero.
        offset : int
            Initial position of the searching window where the position
            corresponds to the center of the window.

        Raises
        ------
        ValueError
            If the two images do not have the same number of rows.
        """
        if denoise is True:
            mat1 = ndi.gaussian_filter(mat1, (2, 2), mode='reflect')
            mat2 = ndi.gaussian_filter(mat2, (2, 2), mode='reflect')
        (nrow1, ncol1) = mat1.shape
        (nrow2, ncol2) = mat2.shape
        if nrow1 != nrow2:
            raise ValueError("Two images are not at the same height!!!")
        win_width = int(np.clip(win_width, 6, min(ncol1, ncol2) // 2 - 1))
        offset = win_width // 2
        win_width = 2 * offset  # Make it even
        ramp_down = np.linspace(1.0, 0.0, win_width)
        ramp_up = 1.0 - ramp_down
        # The 1D ramps broadcast along the rows of each (nrow, win_width) ROI.
        # Image 2's ROI is its left edge when side == 1 and its right edge
        # otherwise; the sliding window of image 1 gets the complementary
        # ramp, so the two weighted ROIs sum to a blended overlap image.
        if side == 1:
            mat2_roi = mat2[:, 0:win_width]
            mat2_roi_wei = mat2_roi * ramp_up
            ramp1 = ramp_down
        else:
            mat2_roi = mat2[:, ncol2 - win_width:]
            mat2_roi_wei = mat2_roi * ramp_down
            ramp1 = ramp_up
        list_mean2 = np.mean(np.abs(mat2_roi), axis=1)
        list_pos = np.arange(offset, ncol1 - offset)
        list_metric = np.ones(len(list_pos), dtype=np.float32)
        for i, pos in enumerate(list_pos):
            mat1_roi = mat1[:, pos - offset:pos + offset]
            # Weighted window is only needed (and only defined) when the
            # blended overlap image takes part in the metric.
            mat1_roi_wei = mat1_roi * ramp1 if use_overlap is True else None
            if norm is True:
                # Row-wise factor matching this window's mean absolute
                # intensity to that of image 2's ROI; zero-mean rows get a
                # factor of 1 instead of inf/nan.
                list_mean1 = np.mean(np.abs(mat1_roi), axis=1)
                list_fact = np.divide(list_mean2, list_mean1,
                                      out=np.ones_like(list_mean2),
                                      where=list_mean1 != 0.0)
                mat_fact = list_fact[:, np.newaxis]
                mat1_roi = mat1_roi * mat_fact
                if mat1_roi_wei is not None:
                    mat1_roi_wei = mat1_roi_wei * mat_fact
            if mat1_roi_wei is not None:
                mat_comb = mat1_roi_wei + mat2_roi_wei
                list_metric[i] = (self.correlation_metric(mat1_roi, mat2_roi)
                                  + self.correlation_metric(mat1_roi, mat_comb)
                                  + self.correlation_metric(mat2_roi, mat_comb)) / 3.0
            else:
                list_metric[i] = self.correlation_metric(mat1_roi, mat2_roi)
        min_metric = np.min(list_metric)
        if min_metric != 0.0:
            list_metric = list_metric / min_metric
        return list_metric, offset

    def correlation_metric(self, mat1, mat2):
        """
        Calculate the correlation metric |1 - r|, where r is the Pearson
        correlation coefficient of the two flattened arrays. A smaller metric
        corresponds to a better correlation.

        Parameters
        ----------
        mat1 : array_like
        mat2 : array_like
            Arrays with the same number of elements.

        Returns
        -------
        float
            Correlation metric.
        """
        metric = np.abs(
            1.0 - stats.pearsonr(mat1.flatten('F'), mat2.flatten('F'))[0])
        return metric

    def calculate_curvature(self, list_metric):
        """
        Calculate the curvature of a fitted curve going through the minimum
        value of a metric list.

        Parameters
        ----------
        list_metric : array_like
            1D array. List of metrics.

        Returns
        -------
        curvature : float
            Absolute quadratic coefficient of a parabola fitted to the five
            points around the minimum.
        min_pos : float
            Position of the minimum value with sub-pixel accuracy.
        """
        radi = 2
        num_metric = len(list_metric)
        # Keep the 5-point fitting window inside the list.
        min_pos = np.clip(
            np.argmin(list_metric), radi, num_metric - radi - 1)
        list1 = list_metric[min_pos - radi:min_pos + radi + 1]
        (afact1, _, _) = np.polyfit(np.arange(0, 2 * radi + 1), list1, 2)
        # Refine the minimum with a 3-point parabola; accept its vertex only
        # if it falls within one pixel of the integer minimum.
        list2 = list_metric[min_pos - 1:min_pos + 2]
        (afact2, bfact2, _) = np.polyfit(
            np.arange(min_pos - 1, min_pos + 2), list2, 2)
        curvature = np.abs(afact1)
        if afact2 != 0.0:
            num = - bfact2 / (2 * afact2)
            if (num >= min_pos - 1) and (num <= min_pos + 1):
                min_pos = num
        return curvature, np.float32(min_pos)

    def _downsample(self, image, dsp_fact0, dsp_fact1):
        """Downsample an image by averaging.

        Parameters
        ----------
        image : array_like
            2D array.
        dsp_fact0 : int
            Downsampling factor along axis 0, clipped to [1, height // 2].
        dsp_fact1 : int
            Downsampling factor along axis 1, clipped to [1, width // 2].

        Returns
        -------
        image_dsp : array_like
            Downsampled image. Trailing rows/columns that do not fill a whole
            block are dropped.
        """
        (height, width) = image.shape
        dsp_fact0 = np.clip(np.int16(dsp_fact0), 1, height // 2)
        dsp_fact1 = np.clip(np.int16(dsp_fact1), 1, width // 2)
        height_dsp = height // dsp_fact0
        width_dsp = width // dsp_fact1
        if dsp_fact0 == 1 and dsp_fact1 == 1:
            image_dsp = image
        else:
            image_dsp = image[0:dsp_fact0 * height_dsp, 0:dsp_fact1 * width_dsp]
            image_dsp = image_dsp.reshape(
                height_dsp, dsp_fact0, width_dsp, dsp_fact1).mean(-1).mean(1)
        return image_dsp

    def pre_process(self):
        """Read the plugin parameters and record the preview indices.

        Falls back to the 'median' broadcasting method (with a warning) when
        an unsupported method is requested, stores the original and the
        plugin-level sinogram preview indices, and warns when more than 20
        sinograms are selected.
        """
        self.win_width = int(self.parameters['win_width'])
        self.side = self.parameters['side']
        self.denoise = self.parameters['denoise']
        self.norm = self.parameters['norm']
        self.use_overlap = self.parameters['use_overlap']

        self.broadcast_method = str(self.parameters['broadcast_method'])
        self.error_msg_1 = ""
        self.error_msg_2 = ""
        self.error_msg_3 = ""
        if self.broadcast_method not in ('mean', 'median', 'linear_fit',
                                         'nearest'):
            self.error_msg_3 = "!!!WARNING!!! Selected broadcasting method" \
                               " is out of the list. Use the default option:" \
                               " 'median'"
            logging.warning(self.error_msg_3)
            cu.user_message(self.error_msg_3)
            self.broadcast_method = 'median'
        in_pData = self.get_plugin_in_datasets()[0]
        data = self.get_in_datasets()[0]
        starts, stops, steps = data.get_preview().get_starts_stops_steps()[0:3]
        start_ind = starts[1]
        stop_ind = stops[1]
        step_ind = steps[1]
        name = data.get_name()
        pre_start = self.exp.meta_data.get(name + '_preview_starts')[1]
        pre_stop = self.exp.meta_data.get(name + '_preview_stops')[1]
        pre_step = self.exp.meta_data.get(name + '_preview_steps')[1]
        # Sinogram indices of the original preview, and the subset this
        # plugin processes (its own preview applied on top of it).
        self.origin_prev = np.arange(pre_start, pre_stop, pre_step)
        self.plugin_prev = self.origin_prev[start_ind:stop_ind:step_ind]
        num_sino = len(self.plugin_prev)
        if num_sino > 20:
            warning_msg = "\n!!!WARNING!!! You selected to calculate the " \
                          "center-of-rotation using '{}' sinograms.\n" \
                          "This is computationally expensive. Considering to " \
                          "adjust the preview parameter to use\na smaller " \
                          "number of sinograms (< 20).\n".format(num_sino)
            logging.warning(warning_msg)
            cu.user_message(warning_msg)

    def process_frames(self, data):
        """
        Find the center-of-rotation (COR) in a 360-degree scan with offset COR
        using the method presented in Ref. [1].

        Parameters
        ----------
        data : list of array_like
            data[0] is a 2D 360-degree sinogram.

        Returns
        -------
        list of array_like
            Two one-element arrays, both holding the COR, one for each
            output dataset.

        References
        ----------
        .. [1] https://doi.org/10.1364/OE.418448
        """
        sino = data[0]
        (nrow, ncol) = sino.shape
        # Compare the first half of the sinogram with the flipped second half.
        nrow_180 = nrow // 2 + 1
        sino_top = sino[0:nrow_180, :]
        sino_bot = np.fliplr(sino[-nrow_180:, :])
        # overlap: width of the overlap area between the two halves;
        # side: overlap side; overlap_position: best-matching window position.
        overlap, side, overlap_position = \
            self.find_overlap(sino_top, sino_bot, self.win_width, self.side,
                              self.denoise, self.norm, self.use_overlap)
        if side == 0:
            cor = overlap / 2.0 - 1.0
        else:
            cor = ncol - overlap / 2.0 - 1.0
        return [np.array([cor]), np.array([cor])]

    def post_process(self):
        """Broadcast the per-sinogram CORs and publish them as metadata.

        Output dataset 1 is filled from output dataset 0 using the selected
        broadcasting method ('median' by default, 'mean', or — when more than
        one COR was computed — 'linear_fit' / 'nearest' over the preview
        indices), and the value reported by executive_summary is stored.
        """
        _, out_datasets = self.get_datasets()
        cor_prev = out_datasets[0].data[...]
        cor_broad = out_datasets[1].data[...]
        cor_broad[:] = np.median(np.squeeze(cor_prev))
        self.cor_for_executive_summary = np.median(cor_broad[:])
        if self.broadcast_method == 'mean':
            cor_broad[:] = np.mean(np.squeeze(cor_prev))
            self.cor_for_executive_summary = np.mean(cor_broad[:])
        if (self.broadcast_method == 'linear_fit') and (len(cor_prev) > 1):
            afact, bfact = np.polyfit(self.plugin_prev, cor_prev[:, 0], 1)
            list_cor = self.origin_prev * afact + bfact
            cor_broad[:, 0] = list_cor
            self.cor_for_executive_summary = cor_broad[:]
        if (self.broadcast_method == 'nearest') and (len(cor_prev) > 1):
            for i, pos in enumerate(self.origin_prev):
                minpos = np.argmin(np.abs(pos - self.plugin_prev))
                cor_broad[i, 0] = cor_prev[minpos, 0]
            self.cor_for_executive_summary = cor_broad[:]
        out_datasets[1].data[:] = cor_broad[:]
        self.populate_meta_data('cor_preview', np.squeeze(cor_prev))
        self.populate_meta_data('centre_of_rotation',
                                out_datasets[1].data[:].squeeze(axis=1))

    def populate_meta_data(self, key, value):
        """Set `key` to `value` on the input metadata and on every dataset
        listed in the 'datasets_to_populate' parameter."""
        datasets = self.parameters['datasets_to_populate']
        in_meta_data = self.get_in_meta_data()[0]
        in_meta_data.set(key, value)
        for name in datasets:
            self.exp.index['in_data'][name].meta_data.set(key, value)

    def setup(self):
        """Create the two METADATA output datasets (previewed and full COR)."""
        self.exp.log(self.name + " Start calculating center of rotation")
        # set up the output dataset that is created by the plugin
        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup('SINOGRAM', self.get_max_frames())
        slice_dirs = list(in_dataset[0].get_slice_dimensions())
        self.orig_full_shape = in_dataset[0].get_shape()

        # reduce the data as per data_subset parameter
        self.set_preview(in_dataset[0], self.parameters['preview'])
        total_frames = \
            self._calc_total_frames(in_dataset[0].get_preview(), slice_dirs)

        # copy all required information from in_dataset[0]
        fullData = in_dataset[0]
        new_shape = (np.prod(np.array(fullData.get_shape())[slice_dirs]), 1)
        self.orig_shape = \
            (np.prod(np.array(self.orig_full_shape)[slice_dirs]), 1)
        out_dataset[0].create_dataset(shape=new_shape,
                                      axis_labels=['x.pixels', 'y.pixels'],
                                      remove=True,
                                      transport='hdf5')
        out_dataset[0].add_pattern("METADATA", core_dims=(1,), slice_dims=(0,))

        out_dataset[1].create_dataset(shape=self.orig_shape,
                                      axis_labels=['x.pixels', 'y.pixels'],
                                      remove=True,
                                      transport='hdf5')
        out_dataset[1].add_pattern("METADATA", core_dims=(1,), slice_dims=(0,))
        out_pData[0].plugin_data_setup('METADATA', self.get_max_frames())
        out_pData[1].plugin_data_setup('METADATA', self.get_max_frames())
        out_pData[1].meta_data.set('fix_total_frames', total_frames)
        self.exp.log(self.name + " End")

    def _calc_total_frames(self, preview, slice_dims):
        """Return the number of previewed frames along `slice_dims`."""
        starts, stops, steps, _ = preview.get_starts_stops_steps()
        lengths = [len(np.arange(starts[i], stops[i], steps[i]))
                   for i in range(len(starts))]
        return np.prod([lengths[i] for i in slice_dims])

    def nOutput_datasets(self):
        return 2

    def get_max_frames(self):
        return 'single'

    def fix_transport(self):
        return 'hdf5'

    def executive_summary(self):
        """Return a one-item list with the COR summary message."""
        if ((self.error_msg_1 == "")
                and (self.error_msg_2 == "")):
            msg = "Centre of rotation is : %s" % (
                str(self.cor_for_executive_summary))
        else:
            msg = "\n" + self.error_msg_1 + "\n" + self.error_msg_2
            msg2 = "(Not well) estimated centre of rotation is : %s" % (str(
                self.cor_for_executive_summary))
            cu.user_message(msg2)
        return [msg]
446