Test Failed
Pull Request — master (#708)
by Daniil · created 03:14

VoCenteringIterative.get_max_frames()   A

Complexity:   Conditions 1
Size:         Total Lines 2 · Code Lines 2
Duplication:  Lines 0 · Ratio 0 %
Importance:   Changes 0

Metric:  cc  eloc  nop  dl  loc  rs  c  b  f
Value:    1     2    1   0    2  10  0  0  0
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
.. module:: vo_centering_iterative
   :platform: Unix
   :synopsis: A plugin to find the center of rotation per frame
.. moduleauthor:: Mark Basham <[email protected]>
"""

import math
import logging
import numpy as np
import scipy.ndimage as ndi
import scipy.ndimage.filters as filter
import pyfftw.interfaces.scipy_fftpack as fft

from scipy import signal

from savu.plugins.utils import register_plugin
from savu.data.plugin_list import CitationInformation
from savu.plugins.filters.base_filter import BaseFilter
from savu.plugins.driver.iterative_plugin import IterativePlugin

#    :u*param search_area: Search area in pixels from horizontal approximate \
#        centre of the image. Default: (-50, 50).
#    Deprecated !!!

@register_plugin
class VoCenteringIterative(BaseFilter, IterativePlugin):
    """
    A plugin to calculate the centre of rotation using the Vo method

    :param ratio: The ratio between the size of the object and the FOV of \
        the camera. Default: 0.5.
    :param row_drop: Drop lines around the vertical centre of the \
        mask. Default: 20.
    :param search_radius: Used for fine searching. Default: 6.
    :param step: Step of the fine search. Default: 0.5.
    :param expand_by: The number of pixels to expand the search region by \
        on each iteration. Default: 5.
    :param boundary_distance: Accepted distance of minima from the boundary \
        of the listshift in the coarse search. Default: 3.
    :u*param preview: A slice list of required frames (sinograms) to use in \
        the calculation of the centre of rotation (this will not reduce the \
        data size for subsequent plugins). Default: [].
    :param datasets_to_populate: A list of datasets which require this \
        information. Default: [].
    :param out_datasets: The default \
        names. Default: ['cor_raw','cor_fit', 'reliability'].
    :u*param start_pixel: The approximate centre. If the value is None it is \
        taken from the .nxs file where available, otherwise the image centre \
        is used. Default: None.
    """

    def __init__(self):
        super(VoCenteringIterative, self).__init__("VoCenteringIterative")
        self.search_area = (-20, 20)
        self.peak_height_min = 50000  # arbitrary
        self.min_dist = 3  # minimum distance deemed acceptable from boundary
        self.expand_by = 5  # expand the search region by this amount
        self.list_shift = None
        self.warning_level = 0
        self.final = False
        self.at_boundary = False
        self.list_metric = []
        self.expand_direction = None

    def _create_mask(self, Nrow, Ncol, obj_radius):
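        """
        Build the binary mask applied in the 2D Fourier domain of the
        combined sinogram: a double-wedge band whose opening is set by
        ``obj_radius``, with ``row_drop`` rows around the vertical centre
        and the three central columns blanked out.
        """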
        du, dv = 1.0/Ncol, (Nrow-1.0)/(Nrow*2.0*math.pi)
        cen_row = int(np.ceil(Nrow/2.0) - 1)
        cen_col = int(np.ceil(Ncol/2.0) - 1)
        drop = self.parameters['row_drop']
        mask = np.zeros((Nrow, Ncol), dtype=np.float32)
        for i in range(Nrow):
            num1 = np.round(((i-cen_row)*dv/obj_radius)/du)
            p1, p2 = (np.clip(np.sort((-num1+cen_col, num1+cen_col)),
                              0, Ncol-1)).astype(int)
            mask[i, p1:p2+1] = np.ones(p2-p1+1, dtype=np.float32)

        if drop < cen_row:
            mask[cen_row-drop:cen_row+drop+1, :] = \
                np.zeros((2*drop + 1, Ncol), dtype=np.float32)
        mask[:, cen_col-1:cen_col+2] = np.zeros((Nrow, 3), dtype=np.float32)
        return mask

    def _get_start_shift(self, centre):
        if self.parameters['start_pixel'] is not None:
            shift = centre - int(self.parameters['start_pixel']/self.downlevel)
        else:
            in_mData = self.get_in_meta_data()[0]
            shift = centre - in_mData['centre'] if 'centre' in \
                list(in_mData.get_dictionary().keys()) else 0
        return int(shift)

    def _coarse_search(self, sino, list_shift):
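        """
        Evaluate the centring metric for every whole-pixel shift in
        ``list_shift``: the 180-degree sinogram is mirrored to synthesise a
        full 360-degree sinogram and the double-wedge-masked magnitude of
        its 2D FFT is summed. The minimum of the returned metric marks the
        best shift.
        """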
        # search minsearch to maxsearch in 1 pixel steps
        list_metric = np.zeros(len(list_shift), dtype=np.float32)
        (Nrow, Ncol) = sino.shape
        # Copy the sinogram and flip it left-right to synthesise the full
        # [0, 2pi] sinogram
        sino2 = np.fliplr(sino[1:])
        # This image is used for compensating the shift of sino2
        compensateimage = np.zeros((Nrow-1, Ncol), dtype=np.float32)
        # Start the coarse search, in which the shift step is 1
        compensateimage[:] = np.flipud(sino)[1:]
        mask = self._create_mask(2*Nrow-1, Ncol,
                                 0.5*self.parameters['ratio']*Ncol)
        count = 0
        for i in list_shift:
            sino2a = np.roll(sino2, i, axis=1)
            if i >= 0:
                sino2a[:, 0:i] = compensateimage[:, 0:i]
            else:
                sino2a[:, i:] = compensateimage[:, i:]
            list_metric[count] = np.sum(
                np.abs(fft.fftshift(fft.fft2(np.vstack((sino, sino2a)))))*mask)
            count += 1
        return list_metric

    def _fine_search(self, sino, raw_cor):
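        """
        Refine the coarse centre of rotation around ``raw_cor`` by shifting
        the mirrored sinogram in sub-pixel increments (the ``step``
        parameter, via spline interpolation) over ``search_radius`` pixels,
        minimising the same masked-FFT metric on the overlapping columns.
        """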
        (Nrow, Ncol) = sino.shape
        centerfliplr = (Ncol + 1.0) / 2.0 - 1.0
        # Use to shift sino2 to the raw CoR
        shiftsino = np.int16(2*(raw_cor-centerfliplr))
        sino2 = np.roll(np.fliplr(sino[1:]), shiftsino, axis=1)
        lefttake = 0
        righttake = Ncol-1
        search_rad = self.parameters['search_radius']

        if raw_cor <= centerfliplr:
            lefttake = np.int16(np.ceil(search_rad+1))
            righttake = np.int16(np.floor(2*raw_cor-search_rad-1))
        else:
            lefttake = np.int16(np.ceil(raw_cor-(Ncol-1-raw_cor)+search_rad+1))
            righttake = np.int16(np.floor(Ncol-1-search_rad-1))

        Ncol1 = righttake-lefttake + 1
        mask = self._create_mask(2*Nrow-1, Ncol1,
                                 0.5*self.parameters['ratio']*Ncol)
        numshift = np.int16((2*search_rad)/self.parameters['step'])+1
        listshift = np.linspace(-search_rad, search_rad, num=numshift)
        listmetric = np.zeros(len(listshift), dtype=np.float32)
        num1 = 0
        factor1 = np.mean(sino[-1, lefttake:righttake])
        for i in listshift:
            sino2a = ndi.interpolation.shift(sino2, (0, i), prefilter=False)
            factor2 = np.mean(sino2a[0, lefttake:righttake])
            sino2a = sino2a*factor1/factor2
            sinojoin = np.vstack((sino, sino2a))
            listmetric[num1] = np.sum(np.abs(fft.fftshift(
                fft.fft2(sinojoin[:, lefttake:righttake + 1])))*mask)
            num1 += 1
        minpos = np.argmin(listmetric)
        rotcenter = raw_cor + listshift[minpos] / 2.0
        return rotcenter

    def _get_listshift(self):
        smin, smax = self.search_area if self.get_iteration() == 0 \
            else self._expand_search()
        list_shift = np.arange(smin, smax+2, 2) - self.start_shift
        logging.debug('list shift is %s', list_shift)
        return list_shift

    def _expand_search(self):
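        """
        Widen the coarse search window when the metric minimum falls too
        close to one of its ends, growing it by ``expand_by`` coarse steps
        on the side chosen in the post process.
        """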
        if self.expand_direction == 'left':
            return self._expand_left()
        elif self.expand_direction == 'right':
            return self._expand_right()
        else:
            raise Exception('Unknown expand direction.')

    def _expand_left(self):
        smax = self.list_shift[0] - 2
        smin = smax - self.expand_by*2

        if smin <= -self.boundary:
            smin = -self.boundary
            self.at_boundary = True
        return smin, smax

    def _expand_right(self):
        smin = self.list_shift[-1] + 2
        smax = self.list_shift[-1] + self.expand_by*2

        # clamp at the outer search boundary (mirrors _expand_left)
        if smax >= self.boundary:
            smax = self.boundary
            self.at_boundary = True

        return smin, smax

    def pre_process(self):
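        """
        Set up the coarse search: downsample wide sinograms (every 4th
        column when detector_x > 1800), locate the flip centre, convert the
        ``start_pixel`` parameter into a start shift and fix the maximum
        search boundary at a quarter of the detector width.
        """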
        pData = self.get_plugin_in_datasets()[0]
        label = pData.get_data_dimension_by_axis_label
        Ncol = pData.get_shape()[label('detector_x')]
        self.downlevel = 4 if Ncol > 1800 else 1
        self.downsample = slice(0, Ncol, self.downlevel)
        Ncol_downsample = len(np.arange(0, Ncol, self.downlevel))
        self.centre_fliplr = (Ncol_downsample - 1.0) / 2.0
        self.start_shift = self._get_start_shift(self.centre_fliplr)*2
        self.boundary = int(np.ceil(Ncol/4.0))

    def process_frames(self, data):
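        """
        On each iteration run the coarse search on a smoothed, downsampled
        sinogram and score the result; once the search is finalised, run a
        single fine search on the median-filtered full-resolution sinogram.
        """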
        if not self.final:
            logging.debug('performing coarse search for iteration %s',
                          self.get_iteration())
            sino = filter.gaussian_filter(data[0][:, self.downsample], (3, 1))
            list_shift = self._get_listshift()
            list_metric = self._coarse_search(sino, list_shift)
            self._update_lists(list(list_shift), list(list_metric))

            self.coarse_cor, dist, reliability_metrics = \
                self._analyse_result(self.list_metric, self.list_shift)

            return [np.array([self.coarse_cor]), np.array([dist]),
                    np.array([reliability_metrics]),
                    np.array([self.list_metric])]
        else:
            logging.debug("performing fine search")
            sino = filter.median_filter(data[0], (2, 2))
            cor = self._fine_search(sino, self.coarse_cor)
            self.set_processing_complete()
            return [np.array([cor]), np.array([self.list_metric])]

    def _update_lists(self, shift, metric):
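        """
        Merge the latest coarse-search results into the running history,
        prepending when the window grew to the left and appending when it
        grew to the right.
        """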
        if self.expand_direction == 'left':
            self.list_shift = shift + self.list_shift
            self.list_metric = metric + self.list_metric
        elif self.expand_direction == 'right':
            self.list_shift += shift
            self.list_metric += metric
        else:
            self.list_shift = shift
            self.list_metric = metric

    def _analyse_result(self, metric, shift):
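        """
        Locate the metric minimum and convert it to a centre estimate.
        Returns the centre, the signed distance of the minimum from the
        nearest end of the search window, and a reliability score (1.0, 0.5
        or 0.0) based on the number and height of peaks in the metric curve.
        """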
        minpos = np.argmin(metric)
        # signed distance of the minimum from the nearest end of the search
        # window: negative means the left boundary, positive the right
        dist = -minpos if minpos < len(shift) - minpos else len(shift) - minpos

        rot_centre = (self.centre_fliplr + shift[minpos] / 2.0)*self.downlevel
        peaks = self._find_peaks(metric)

        good_nPeaks = (len(peaks) == 1)
        good_peak_height = bool(np.any(peaks) and
                                max(peaks) > self.peak_height_min)

        metric = 0.0
        if good_peak_height and good_nPeaks:
            metric = 1.0
        elif good_peak_height or good_nPeaks:
            metric = 0.5

        return rot_centre, dist, metric

    def _find_peaks(self, metric):
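        """
        Return the sorted heights of the peaks in the second derivative of
        the metric curve (negative curvature is clipped to zero before the
        peakutils search).
        """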
        import peakutils
        grad2 = np.gradient(np.gradient(metric))
        grad2[grad2 < 0] = 0
        index = peakutils.indexes(grad2, thres=0.5, min_dist=3)
        return np.sort(grad2[index])

    def post_process(self):
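        """
        Decide what happens after each coarse iteration: if most centre
        estimates sit near a boundary of the search window, expand the
        window (choosing a direction from the reliability scores when the
        estimates disagree); otherwise, or once the window limit is
        reached, switch to the final fine search.
        """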
        logging.debug("in the post process function")
        in_datasets, out_datasets = self.get_datasets()

        # =====================================================================
        # Analyse the distance of the centre values from the boundary of the
        # search region
        dist_from_boundary = np.squeeze(out_datasets[1].data[...])
        near_boundary = np.where(abs(dist_from_boundary) < self.min_dist)[0]
        nEntries = len(dist_from_boundary)

        # Case 1: more than half of the results are near the boundary
        if (len(near_boundary)/float(nEntries)) > 0.5:
            # find which boundary
            signs = np.sign(dist_from_boundary[near_boundary])
            left, right = len(signs[signs < 0]), len(signs[signs > 0])

            logging.debug("res: results are near boundary")
            if not self.at_boundary:
                # if they are all at the same boundary, expand the search
                # region
                if not (left and right):
                    logging.debug("res: expanding")
                    self.expand_direction = 'left' if left else 'right'
                # if they are at different boundaries, analyse the
                # reliability of the entries and choose which boundary to
                # expand towards
                else:
                    logging.debug("res: choosing a boundary")
                    self.expand_direction = \
                        self._choose_boundary(near_boundary, signs)
            else:
                logging.debug("res: at the edge of the boundary")
                # Move on to the fine search
                self._set_final_process()
                self.warning_level = 1  # TODO: make this more descriptive
        else:
            logging.debug("result is not near the boundary")
            # Move on to the fine search
            self._set_final_process()
        # =====================================================================

    def _choose_boundary(self, idx, signs):
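        """
        Pick the expansion direction from the most reliable group of entries
        available (good, then maybe, then bad), raising the warning level as
        the quality of the deciding group drops.
        """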
        good, maybe, bad = self._get_reliability_levels()
        sign = self._check_entries(good, signs[good])
        self.warning_level = 0
        if not sign:
            sign = self._check_entries(maybe, signs[maybe])
            self.warning_level = 1
        if not sign:
            sign = self._check_entries(bad, signs[bad])
            self.warning_level = 2
        return sign

    def _check_entries(self, idx, signs):
        if np.any(idx):
            # count the entries on each side: boolean tests on the raw
            # arrays would be ambiguous for more than one element
            left, right = len(signs[signs < 0]), len(signs[signs > 0])
            if not (left and right):
                # use all the good ones
                return 'left' if left else 'right'
        return None

    def _get_reliability_levels(self, final=False):
        in_datasets, out_datasets = \
            self.get_datasets() if not final else self.get_original_datasets()
        reliability = np.squeeze(out_datasets[2].data[...])
        logging.debug('reliability is %s', reliability)
        good = np.where(reliability == 1.0)[0]
        maybe = np.where(reliability == 0.5)[0]
        bad = np.where(reliability == 0.0)[0]
        return good, maybe, bad

    def final_post_process(self):
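        """
        Combine the per-sinogram results: keep only the most reliable
        entries, broadcast the median of their raw centre values as the
        fitted centre, and publish 'cor_raw' and 'centre_of_rotation' to
        the metadata.
        """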
        # choose which values to include
        good, maybe, bad = self._get_reliability_levels(final=True)
        # Do I need to change the warning levels here?
        entries = good if np.any(good) else maybe if np.any(maybe) else bad
        self.warning_level = 0 if np.any(good) else 1 if np.any(maybe) else 2
        logging.debug('sinograms used in final calculations are %s', entries)

        # fit the chosen raw centre values: take their median and broadcast
        # it to the full output shape of the original datasets
        in_dataset, out_dataset = self.get_original_datasets()
        cor_raw = np.squeeze(out_dataset[0].data[...])[entries]
        cor_fit = out_dataset[1].data[...]
        fit = np.zeros(cor_fit.shape)
        fit[:] = np.median(cor_raw)
        cor_fit = fit
        out_dataset[1].data[:] = cor_fit[:]

        self.populate_meta_data('cor_raw', cor_raw)
        self.populate_meta_data('centre_of_rotation',
                                out_dataset[1].data[:].squeeze(axis=1))

    def _set_final_process(self):
        self.final = True
        self.post_process = self.final_post_process
        in_dataset, out_dataset = self.get_datasets()
        self.set_iteration_datasets(
                self.get_iteration()+1, [in_dataset[0]], [out_dataset[0]])

    def populate_meta_data(self, key, value):
        datasets = self.parameters['datasets_to_populate']
        in_meta_data = self.get_in_meta_data()[0]
        in_meta_data.set(key, value)
        for name in datasets:
            self.exp.index['in_data'][name].meta_data.set(key, value)

    def setup(self):
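        """
        Wire the plugin into the framework: one SINOGRAM input and four
        small HDF5-backed METADATA outputs holding, per sinogram, the raw
        centre, the boundary distance (re-used for the fitted centre), the
        reliability score and the metric curve.
        """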
        # set up the output datasets that are created by the plugin
        in_dataset, out_dataset = self.get_datasets()

        self.orig_full_shape = in_dataset[0].get_shape()

        # reduce the data as per the preview parameter
        self.set_preview(in_dataset[0], self.parameters['preview'])

        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup('SINOGRAM', self.get_max_frames())
        # copy all required information from in_dataset[0]
        fullData = in_dataset[0]

        slice_dirs = np.array(in_dataset[0].get_slice_dimensions())
        new_shape = (np.prod(np.array(fullData.get_shape())[slice_dirs]), 1)
        self.orig_shape = \
            (np.prod(np.array(self.orig_full_shape)[slice_dirs]), 1)

        self._create_metadata_dataset(out_dataset[0], new_shape)
        self._create_metadata_dataset(out_dataset[1], self.orig_shape)
        self._create_metadata_dataset(out_dataset[2], new_shape)

        # output metric
        new_shape = (np.prod(np.array(fullData.get_shape())[slice_dirs]), 21)
        self._create_metadata_dataset(out_dataset[3], new_shape)

        out_pData[0].plugin_data_setup('METADATA', self.get_max_frames())
        out_pData[1].plugin_data_setup('METADATA', self.get_max_frames())
        out_pData[2].plugin_data_setup('METADATA', self.get_max_frames())
        out_pData[3].plugin_data_setup('METADATA', self.get_max_frames())

    def _create_metadata_dataset(self, data, shape):
        data.create_dataset(shape=shape,
                            axis_labels=['x.pixels', 'y.pixels'],
                            remove=True,
                            transport='hdf5')
        data.add_pattern("METADATA", core_dims=(1,), slice_dims=(0,))

    def nOutput_datasets(self):
        return 4

    def get_max_frames(self):
        return 'single'

    def fix_transport(self):
        # This plugin requires communication between processes in the post
        # process, which it does via files
        return 'hdf5'

    def executive_summary(self):
        if self.warning_level == 0:
            msg = "Confidence in the centre value is high."
        elif self.warning_level == 1:
            msg = "Confidence in the centre value is average."
        else:
            msg = "Confidence in the centre value is low."
        return [msg]

    # review note: this code seems to be duplicated in the project
    def get_citation_information(self):
        cite_info = CitationInformation()
        cite_info.description = \
            ("The center of rotation for this reconstruction was calculated " +
             "automatically using the method described in this work")
        cite_info.bibtex = \
            ("@article{vo2014reliable,\n" +
             "title={Reliable method for calculating the center of rotation " +
             "in parallel-beam tomography},\n" +
             "author={Vo, Nghia T and Drakopoulos, Michael and Atwood, " +
             "Robert C and Reinhard, Christina},\n" +
             "journal={Optics Express},\n" +
             "volume={22},\n" +
             "number={16},\n" +
             "pages={19078--19086},\n" +
             "year={2014},\n" +
             "publisher={Optical Society of America}\n" +
             "}")
        cite_info.endnote = \
            ("%0 Journal Article\n" +
             "%T Reliable method for calculating the center of rotation in " +
             "parallel-beam tomography\n" +
             "%A Vo, Nghia T\n" +
             "%A Drakopoulos, Michael\n" +
             "%A Atwood, Robert C\n" +
             "%A Reinhard, Christina\n" +
             "%J Optics Express\n" +
             "%V 22\n" +
             "%N 16\n" +
             "%P 19078-19086\n" +
             "%@ 1094-4087\n" +
             "%D 2014\n" +
             "%I Optical Society of America")
        cite_info.doi = "https://doi.org/10.1364/OE.22.019078"
        return cite_info
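As a sanity check of the metric this plugin minimises, here is a minimal, self-contained reviewer sketch (not part of the PR): it rebuilds the double-wedge mask, mirrors a synthetic 180-degree sinogram of a single off-centre Gaussian trace, and scans the masked FFT magnitude over coarse shifts. numpy.fft stands in for pyfftw, the phantom and all names are assumptions for demonstration only, and the wrap-around compensation step of the real _coarse_search is omitted for brevity.

# reviewer sketch: coarse Vo-metric scan on a synthetic sinogram
import math
import numpy as np

def wedge_mask(nrow, ncol, radius, drop=20):
    # same construction as VoCenteringIterative._create_mask
    du, dv = 1.0/ncol, (nrow-1.0)/(nrow*2.0*math.pi)
    cen_row, cen_col = int(np.ceil(nrow/2.0)-1), int(np.ceil(ncol/2.0)-1)
    mask = np.zeros((nrow, ncol), dtype=np.float32)
    for i in range(nrow):
        num1 = np.round(((i-cen_row)*dv/radius)/du)
        p1, p2 = np.clip(np.sort((-num1+cen_col, num1+cen_col)),
                         0, ncol-1).astype(int)
        mask[i, p1:p2+1] = 1.0
    if drop < cen_row:
        mask[cen_row-drop:cen_row+drop+1, :] = 0.0
    mask[:, cen_col-1:cen_col+2] = 0.0
    return mask

def coarse_metric(sino, shift, mask):
    # mirror the [0, pi) half and roll it by the candidate shift;
    # wrapped columns are left as-is (the plugin compensates them)
    sino2 = np.roll(np.fliplr(sino[1:]), shift, axis=1)
    combined = np.vstack((sino, sino2))
    return np.sum(np.abs(np.fft.fftshift(np.fft.fft2(combined)))*mask)

if __name__ == '__main__':
    nang, ncol, true_cor = 181, 128, 70.0
    angles = np.linspace(0, np.pi, nang)
    x = np.arange(ncol) - true_cor
    # crude phantom: one off-centre point traces a cosine in the sinogram
    trace = 20.0*np.cos(angles)
    sino = np.exp(-((x[None, :] - trace[:, None])**2)/8.0)
    mask = wedge_mask(2*nang-1, ncol, 0.5*0.5*ncol)  # ratio = 0.5
    shifts = np.arange(-20, 22, 2)
    metrics = [coarse_metric(sino, int(s), mask) for s in shifts]
    best = shifts[int(np.argmin(metrics))]
    centre_flip = (ncol - 1.0)/2.0
    # coarse estimate: CoR = flip centre + best shift / 2
    print('estimated CoR: %.1f (true %.1f)' % (centre_flip + best/2.0,
                                               true_cor))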