Test Failed
Pull Request — master (#708)
by Daniil
03:14
created

savu.plugins.reconstructions.base_recon (rating: F)

Complexity

Total Complexity 90

Size/Duplication

Total Lines 545
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 312
dl 0
loc 545
rs 2
c 0
b 0
f 0
wmc 90

42 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseRecon.br_array_pad() 0 9 3
A BaseRecon.get_vol_shape() 0 2 1
B BaseRecon.base_process_frames_before() 0 37 6
A BaseRecon.get_reconstruction_alg() 0 2 1
A BaseRecon.__init__() 0 16 1
A BaseRecon._get_volume_dimensions() 0 2 1
B BaseRecon.__get_outer_pad() 0 19 7
A BaseRecon.get_padding_algorithms() 0 3 1
A BaseRecon._get_shape() 0 14 3
A BaseRecon.get_centre_shift() 0 3 1
A BaseRecon.base_process_frames_after() 0 7 3
A BaseRecon.base_pre_process() 0 27 3
A BaseRecon.__set_padding_alg() 0 6 3
A BaseRecon._get_detX_dim() 0 3 1
B BaseRecon.setup() 0 43 6
A BaseRecon.get_angles() 0 7 1
A BaseRecon.get_frame_params() 0 4 1
A BaseRecon.set_centre_of_rotation() 0 21 5
A BaseRecon._set_volume_dimensions() 0 14 1
A BaseRecon.__get_pad_values() 0 6 1
A BaseRecon.get_max_frames() 0 11 1
A BaseRecon.nOutput_datasets() 0 2 1
A BaseRecon.get_sino_centre_method() 0 10 5
A BaseRecon.pad_sino() 0 9 1
A BaseRecon.get_centre_offset() 0 7 1
A BaseRecon.nInput_datasets() 0 14 5
A BaseRecon.get_initial_data() 0 10 1
A BaseRecon.reconstruct_pre_process() 0 5 1
A BaseRecon.get_pad_amount() 0 2 1
A BaseRecon.__set_cor_from_meta_data() 0 7 2
A BaseRecon.get_fov_fraction() 0 8 1
A BaseRecon.set_function() 0 14 4
A BaseRecon.get_cors() 0 8 1
A BaseRecon.get_slice_axis() 0 12 1
A BaseRecon.__set_pad_amount() 0 2 1
A BaseRecon.executive_summary() 0 8 3
A BaseRecon.keep_sino() 0 3 1
A BaseRecon.set_mask() 0 2 1
A BaseRecon.crop_sino() 0 10 1
A BaseRecon.__polyfit_cor() 0 11 2
A BaseRecon._get_axis_labels() 0 22 2
A BaseRecon.__make_lambda() 0 7 3

How to fix: Complexity

Complex classes like BaseRecon (in savu.plugins.reconstructions.base_recon) often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
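Applying the prefix/suffix heuristic to the methods listed above, the padding-related members of BaseRecon (__get_outer_pad(), __get_pad_values(), __set_pad_amount(), get_pad_amount(), pad_sino(), br_array_pad(), plus the sino_pad and pad_dim fields) look like one cohesive component. Below is a minimal sketch of an Extract Class step for the outer-pad part only. The class name SinogramPadder and its method names are hypothetical, not part of Savu. The sketch omits the cu.user_message() warnings, and it returns the clamped factor when outer_pad exceeds MAX_OUTER_PAD.

```python
import math

import numpy as np

MAX_OUTER_PAD = 2.1  # same cap as base_recon.MAX_OUTER_PAD


class SinogramPadder:
    """Hypothetical class extracted from BaseRecon to own the outer-pad state.

    Groups what BaseRecon currently spreads over __get_outer_pad(),
    __get_pad_values() and the sino_pad / pad_dim attributes.
    """

    def __init__(self, pad_dim, padding_alg, outer_pad=False):
        self.pad_dim = pad_dim          # index of the detector-x dimension
        self.padding_alg = padding_alg  # True if the algorithm supports padding
        self.outer_pad = outer_pad      # True, False or a number
        self.sino_pad = 0               # pixels added to each side of detector x

    def outer_pad_factor(self):
        """Fraction of the detector width to add on each side."""
        if self.outer_pad is not False and not self.padding_alg:
            return 0.0  # this algorithm cannot be padded
        if isinstance(self.outer_pad, bool):
            # True uses sqrt(2) - 1, as in BaseRecon.__get_outer_pad()
            return math.sqrt(2) - 1 if self.outer_pad else 0.0
        return min(float(self.outer_pad), MAX_OUTER_PAD)  # clamp, then use it

    def set_sino_pad(self, shape):
        width = shape[self.pad_dim]
        self.sino_pad = int(math.ceil(self.outer_pad_factor() * width))
        return self.sino_pad

    def pad_values(self, ndim):
        """np.pad widths: pad only the detector-x dimension, on both sides."""
        pad_tuples = [(0, 0)] * (ndim - 1)
        pad_tuples.insert(self.pad_dim, (self.sino_pad, self.sino_pad))
        return tuple(pad_tuples)

    def pad(self, sino):
        return np.pad(sino, self.pad_values(sino.ndim), mode='edge')


padder = SinogramPadder(pad_dim=1, padding_alg=True, outer_pad=True)
padder.set_sino_pad((4, 10))
padded = padder.pad(np.ones((4, 10)))  # shape (4, 20): 5 edge pixels per side
```

BaseRecon would then hold one SinogramPadder instead of separate sino_pad and pad_dim fields. The centre-pad methods (pad_sino(), crop_sino(), br_array_pad()) are likely candidates for the same component.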

1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_recon
17
   :platform: Unix
18
   :synopsis: A base class for all reconstruction methods
19
20
.. moduleauthor:: Mark Basham <[email protected]>
21
22
"""
23
import math
24
import numpy as np
25
np.seterr(divide='ignore', invalid='ignore')
26
27
import savu.core.utils as cu
28
from savu.plugins.plugin import Plugin
29
30
MAX_OUTER_PAD = 2.1
31
32
33
class BaseRecon(Plugin):
34
    """
35
    A base class for reconstruction plugins
36
37
    :u*param centre_of_rotation: Centre of rotation to use for the \
38
    reconstruction. Default: 0.0.
39
40
    :u*param init_vol: Dataset to use as volume initialiser \
41
    (doesn't currently work with preview). Default: None.
42
43
    :param centre_pad: Pad the sinogram to centre it in order to fill the \
44
    reconstructed volume ROI for aesthetic purposes. \
45
    NB: Only available for selected algorithms and will be ignored otherwise. \
46
    WARNING: This will significantly increase the size of the data and the \
47
    time to compute the reconstruction. Default: False.
48
49
    :param outer_pad: Pad the sinogram width to fill the reconstructed volume \
50
    for aesthetic purposes.\
51
    Choose from True (defaults to sqrt(2)), False or float <= 2.1. \
52
    NB: Only available for selected algorithms and will be ignored otherwise.\
53
    WARNING: This will increase the size of the data and the \
54
    time to compute the reconstruction. Default: False.
55
56
    :u*param log: Take the log of the data before reconstruction \
57
    (True or False). Default: True.
58
59
    :u*param preview: A slice list of required frames. Default: [].
60
61
    :param force_zero: Set any values in the reconstructed image outside of \
62
    this range to zero. Default: [None, None].
63
64
    :param ratio: Ratio between the diameter of a circle mask and the width of\
65
    a reconstructed image. If passed as a list or tuple, the second value is \
66
    assigned to the outer mask area, e.g. [0.95, 0.0]. Default: 0.95.
67
68
    :param log_func: Override the default log \
69
        function. Default: 'np.nan_to_num(-np.log(sino))'.
70
71
    :param vol_shape: Override the size of the reconstruction volume with an \
72
    integer value. Default: 'fixed'.
73
    """
74
75
    def __init__(self, name='BaseRecon'):
76
        super(BaseRecon, self).__init__(name)
77
        self.nOut = 1
78
        self.nIn = 1
79
        self.scan_dim = None
80
        self.rep_dim = None
81
        self.br_vol_shape = None
82
        self.frame_angles = None
83
        self.frame_cors = None
84
        self.frame_init_data = None
85
        self.centre = None
86
        self.base_pad_amount = None
87
        self.padding_alg = False
88
        self.cor_shift = 0
89
        self.init_vol = False
90
        self.cor_as_dataset = False
91
92
    def base_pre_process(self):
93
        in_data, out_data = self.get_datasets()
94
        in_pData, out_pData = self.get_plugin_datasets()
95
        self.pad_dim = \
96
            in_pData[0].get_data_dimension_by_axis_label('x', contains=True)
97
        in_meta_data = self.get_in_meta_data()[0]
98
        self.__set_padding_alg()
99
100
        self.exp.log(self.name + " End")
101
        self.br_vol_shape = out_pData[0].get_shape()
102
        self.set_centre_of_rotation(in_data[0], in_meta_data, in_pData[0])
103
104
        self.main_dir = in_data[0].get_data_patterns()['SINOGRAM']['main_dir']
105
        self.angles = in_meta_data.get('rotation_angle')
106
        if len(self.angles.shape) != 1:
107
            self.scan_dim = in_data[0].get_data_dimension_by_axis_label('scan')
108
        self.slice_dirs = out_data[0].get_slice_dimensions()
109
110
        shape = in_pData[0].get_shape()
111
        factor = self.__get_outer_pad()
112
        self.sino_pad = int(math.ceil(factor * shape[self.pad_dim]))
113
114
        self.sino_func, self.cor_func = self.set_function(shape) if \
115
            self.padding_alg else self.set_function(False)
116
117
        self.range = self.parameters['force_zero']
118
        self.fix_sino = self.get_sino_centre_method()
119
120
    def __get_outer_pad(self):
121
        # length of diagonal of square is side*sqrt(2)
122
        factor = math.sqrt(2) - 1
123
        pad = self.parameters['outer_pad'] if 'outer_pad' in \
124
            list(self.parameters.keys()) else False
125
126
        if pad is not False and not self.padding_alg:
127
            msg = 'This reconstruction algorithm cannot be padded.'
128
            cu.user_message(msg)
129
            return 0
130
131
        if isinstance(pad, bool):
132
            return factor if pad is True else 0
133
        factor = float(pad)
134
        if factor > MAX_OUTER_PAD:
135
            factor = MAX_OUTER_PAD
136
            msg = 'Maximum outer_pad value is 2.1, using this instead'
137
            cu.user_message(msg)
138
        return factor
139
140
    def __set_padding_alg(self):
141
        """ Determine if this is an algorithm that allows sinogram padding. """
142
        pad_algs = self.get_padding_algorithms()
143
        alg = self.parameters['algorithm'] if 'algorithm' in \
144
            list(self.parameters.keys()) else None
145
        self.padding_alg = True if alg in pad_algs else False
146
147
    def get_vol_shape(self):
148
        return self.br_vol_shape
149
150
    def set_centre_of_rotation(self, inData, mData, pData):
151
        # if cor has been passed as a dataset then do nothing
152
        if isinstance(self.parameters['centre_of_rotation'], str):
153
            return
154
        if 'centre_of_rotation' in list(mData.get_dictionary().keys()):
155
            cor = self.__set_cor_from_meta_data(mData, inData)
156
        else:
157
            val = self.parameters['centre_of_rotation']
158
            if isinstance(val, dict):
159
                cor = self.__polyfit_cor(val, inData)
160
            else:
161
                sdirs = inData.get_slice_dimensions()
162
                cor = np.ones(np.prod([inData.get_shape()[i] for i in sdirs]))
163
                # if centre of rotation has not been set then fix it in the
164
                # centre
165
                val = val if val != 0 else \
166
                    (self.get_vol_shape()[self._get_detX_dim()]) / 2.0
167
                cor *= val
168
                # mData.set('centre_of_rotation', cor) see Github ticket
169
        self.cor = cor
170
        self.centre = self.cor[0]
171
172
    def __set_cor_from_meta_data(self, mData, inData):
173
        cor = mData.get('centre_of_rotation')
174
        sdirs = inData.get_slice_dimensions()
175
        total_frames = np.prod([inData.get_shape()[i] for i in sdirs])
176
        if total_frames > len(cor):
177
            cor = np.tile(cor, int(total_frames / len(cor)))
178
        return cor
179
180
    def __polyfit_cor(self, cor_dict, inData):
181
        if 'detector_y' in list(inData.meta_data.get_dictionary().keys()):
182
            y = inData.meta_data.get('detector_y')
183
        else:
184
            yDim = inData.get_data_dimension_by_axis_label('detector_y')
185
            y = np.arange(inData.get_shape()[yDim])
186
187
        z = np.polyfit(list(map(int, list(cor_dict.keys()))), list(cor_dict.values()), 1)
188
        p = np.poly1d(z)
189
        cor = p(y)
190
        return cor
191
192
    def set_function(self, pad_shape):
193
        if not pad_shape:
194
            def cor_func(cor): return cor
195
            if self.parameters['log']:
196
                sino_func = self.__make_lambda()
197
            else:
198
                sino_func = self.__make_lambda(log=False)
199
        else:
200
            def cor_func(cor): return cor + self.sino_pad
201
            if self.parameters['log']:
202
                sino_func = self.__make_lambda(pad=pad_shape)
203
            else:
204
                sino_func = self.__make_lambda(pad=pad_shape, log=False)
205
        return sino_func, cor_func
206
207
    def __make_lambda(self, log=True, pad=False):
208
        log_func = 'np.nan_to_num(sino)' if not log else self.parameters['log_func']
209
        if pad:
210
            pad_tuples, mode = self.__get_pad_values(pad)
211
            log_func = log_func.replace(
212
                    'sino', 'np.pad(sino, %s, "%s")' % (pad_tuples, mode))
213
        return eval("lambda sino: " + log_func)
214
215
    def __get_pad_values(self, pad_shape):
216
        mode = 'edge'
217
        pad_tuples = [(0, 0)] * (len(pad_shape) - 1)
218
        pad_tuples.insert(self.pad_dim, (self.sino_pad, self.sino_pad))
219
        pad_tuples = tuple(pad_tuples)
220
        return pad_tuples, mode
221
222
    def base_process_frames_before(self, data):
223
        """
224
        Reconstruct a single sinogram with the provided centre of rotation
225
        """
226
        sl = self.get_current_slice_list()[0]
227
        init = data[1] if self.init_vol else None
228
        angles = \
229
            self.angles[:, sl[self.scan_dim]] if self.scan_dim else self.angles
230
231
        self.frame_angles = angles
232
233
        dim_sl = sl[self.main_dir]
234
235
        if self.cor_as_dataset:
236
            self.frame_cors = self.cor_func(data[len(data) - 1])
237
        else:
238
            frame_nos = \
239
                self.get_plugin_in_datasets()[0].get_current_frame_idx()
240
            a = self.cor[tuple([frame_nos])]
241
            self.frame_cors = self.cor_func(a)
242
243
        # for extra padded frames that make up the numbers
244
        if not self.frame_cors.shape:
245
            self.frame_cors = np.array([self.centre])
246
247
        len_data = len(np.arange(dim_sl.start, dim_sl.stop, dim_sl.step))
248
249
        missing = [self.centre] * (len(self.frame_cors) - len_data)
250
        self.frame_cors = np.append(self.frame_cors, missing)
251
252
        # fix to remove NaNs in the initialised image
253
        if init is not None:
254
            init[np.isnan(init)] = 0.0
255
        self.frame_init_data = init
256
257
        data[0] = self.fix_sino(self.sino_func(data[0]), self.frame_cors[0])
258
        return data
259
260
    def base_process_frames_after(self, data):
261
        lower_range, upper_range = self.range
262
        if lower_range is not None:
263
            data[data < lower_range] = 0
264
        if upper_range is not None:
265
            data[data > upper_range] = 0
266
        return data
267
268
    def get_padding_algorithms(self):
269
        """ A list of algorithms that allow the data to be padded. """
270
        return []
271
272
    def pad_sino(self, sino, cor):
273
        """  Pad the sinogram so the centre of rotation is at the centre. """
274
        detX = self._get_detX_dim()
275
        pad = self.get_centre_offset(sino, cor, detX)
276
        self.cor_shift = pad[0]
277
        pad_tuples = [(0, 0)] * (len(sino.shape) - 1)
278
        pad_tuples.insert(detX, tuple(pad))
279
        self.__set_pad_amount(max(pad))
280
        return np.pad(sino, tuple(pad_tuples), mode='edge')
281
282
    def _get_detX_dim(self):
283
        pData = self.get_plugin_in_datasets()[0]
284
        return pData.get_data_dimension_by_axis_label('x', contains=True)
285
286
    def get_centre_offset(self, sino, cor, detX):
287
        centre_pad = self.br_array_pad(cor, sino.shape[detX])
288
        sino_width = sino.shape[detX]
289
        new_width = sino_width + max(centre_pad)
290
        sino_pad = int(math.ceil(float(sino_width) / new_width * self.sino_pad) // 2)
291
        pad = np.array([sino_pad]*2) + centre_pad
292
        return pad
293
294
    def get_centre_shift(self, sino, cor):
295
        detX = self._get_detX_dim()
296
        return max(self.get_centre_offset(sino, self.centre, detX))
297
298
    def crop_sino(self, sino, cor):
299
        """  Crop the sinogram so the centre of rotation is at the centre. """
300
        detX = self._get_detX_dim()
301
        start, stop = self.br_array_pad(cor, sino.shape[detX])[::-1]
302
        self.cor_shift = -start
303
        sl = [slice(None)] * len(sino.shape)
304
        sl[detX] = slice(start, sino.shape[detX] - stop)
305
        sino = sino[tuple(sl)]
306
        self.set_mask(sino.shape)
307
        return sino
308
309
    def br_array_pad(self, ctr, nPixels):
310
        width = nPixels - 1.0
311
        alen = ctr
312
        blen = width - ctr
313
        mid = (width - 1.0) / 2.0
314
        shift = round(abs(blen - alen))
315
        p_low = 0 if (ctr > mid) else shift
316
        p_high = shift if (ctr > mid) else 0
317
        return np.array([int(p_low), int(p_high)])
318
319
    def keep_sino(self, sino, cor):
320
        """ No change to the sinogram """
321
        return sino
322
323
    def get_sino_centre_method(self):
324
        centre_pad = self.keep_sino
325
        if 'centre_pad' in list(self.parameters.keys()):
326
            cpad = self.parameters['centre_pad']
327
            if not (cpad is True or cpad is False):
328
                raise Exception('Unknown value for "centre_pad", please choose'
329
                                ' True or False.')
330
            centre_pad = self.pad_sino if cpad and self.padding_alg \
331
                else self.crop_sino
332
        return centre_pad
333
334
    def __set_pad_amount(self, pad_amount):
335
        self.base_pad_amount = pad_amount
336
337
    def get_pad_amount(self):
338
        return self.base_pad_amount
339
340
    def get_fov_fraction(self, sino, cor):
341
        """ Get the fraction of the original FOV that can be reconstructed due\
342
        to an offset centre. """
343
        pData = self.get_plugin_in_datasets()[0]
344
        detX = pData.get_data_dimension_by_axis_label('x', contains=True)
345
        original_length = sino.shape[detX]
346
        shift = self.get_centre_shift(sino, cor)
347
        return (original_length - shift) / float(original_length)
348
349
    def get_reconstruction_alg(self):
350
        return None
351
352
    def get_angles(self):
353
        """ Get the angles associated with the current sinogram(s).
354
355
        :returns: Angles of the current frames.
356
        :rtype: np.ndarray
357
        """
358
        return self.frame_angles
359
360
    def get_cors(self):
361
        """
362
        Get the centre of rotations associated with the current sinogram(s).
363
364
        :returns: Centre of rotation values for the current frames.
365
        :rtype: np.ndarray
366
        """
367
        return self.frame_cors + self.cor_shift
368
369
    def set_mask(self, shape):
370
        pass
371
372
    def get_initial_data(self):
373
        """
374
        Get the initial data (if it exists) associated with the current \
375
        sinogram(s).
376
377
        :returns: The section of the initialisation data associated with the \
378
            current frames.
379
        :rtype: np.ndarray or None
380
        """
381
        return self.frame_init_data
382
383
    def get_frame_params(self):
384
        params = [self.get_cors(), self.get_angles(), self.get_vol_shape(),
385
                  self.get_initial_data()]
386
        return params
387
388
    def setup(self):
389
        in_dataset, out_dataset = self.get_datasets()
390
        # reduce the data as per the preview parameter
391
        self.preview_flag = \
392
            self.set_preview(in_dataset[0], self.parameters['preview'])
393
394
        self._set_volume_dimensions(in_dataset[0])
395
        axis_labels = self._get_axis_labels(in_dataset[0])
396
        shape = self._get_shape(in_dataset[0])
397
398
        # output dataset
399
        out_dataset[0].create_dataset(axis_labels=axis_labels, shape=shape)
400
        out_dataset[0].add_volume_patterns(*self._get_volume_dimensions())
401
402
        # set information relating to the plugin data
403
        in_pData, out_pData = self.get_plugin_datasets()
404
405
        self.init_vol = 1 if 'init_vol' in list(self.parameters.keys()) and\
406
            self.parameters['init_vol'] else 0
407
        self.cor_as_dataset = 1 if isinstance(
408
            self.parameters['centre_of_rotation'], str) else 0
409
410
        for i in range(len(in_dataset) - self.init_vol - self.cor_as_dataset):
411
            in_pData[i].plugin_data_setup('SINOGRAM', self.get_max_frames(),
412
                                          slice_axis=self.get_slice_axis())
413
            idx = 1
414
415
        # initial volume dataset
416
        if self.init_vol:
417
#            from savu.data.data_structures.data_types import Replicate
418
#            if self.rep_dim:
419
#                in_dataset[idx].data = Replicate(
420
#                    in_dataset[idx], in_dataset[0].get_shape(self.rep_dim))
421
            in_pData[1].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
422
            idx += 1
The variable idx does not seem to be defined in case the for loop on line 410 is not entered. Are you sure this can never be the case?
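Here is a minimal sketch of one fix, written as a hypothetical standalone helper rather than Savu code. It derives the index from the number of sinogram datasets before any loop runs, so the index is always bound. The sketch assumes, as nInput_datasets() arranges, that the initial-volume and centre-of-rotation datasets are appended after the sinogram datasets. With a single sinogram dataset it reproduces the current idx = 1.

```python
def dataset_patterns(n_in, init_vol, cor_as_dataset):
    """Hypothetical helper mirroring the pattern assignment in setup().

    Returns (dataset index, pattern name) pairs. init_vol and cor_as_dataset
    are 0 or 1, as in setup().
    """
    n_sino = n_in - init_vol - cor_as_dataset
    patterns = [(i, 'SINOGRAM') for i in range(n_sino)]
    idx = n_sino  # bound before any branch uses it, even when n_sino == 0
    if init_vol:
        patterns.append((idx, 'VOLUME_XZ'))
        idx += 1
    if cor_as_dataset:
        patterns.append((idx, 'METADATA'))
    return patterns
```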
423
424
        # cor dataset
425
        if self.cor_as_dataset:
426
            self.cor_as_dataset = True
427
            in_pData[idx].plugin_data_setup('METADATA', self.get_max_frames())
428
429
        # set pattern_name and nframes to process for all datasets
430
        out_pData[0].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
431
432
    def _get_axis_labels(self, in_dataset):
433
        """
434
        Get the new axis labels for the output dataset - this is now a volume.
435
436
        Parameters
437
        ----------
438
        in_dataset : :class:`savu.data.data_structures.data.Data`
439
            The input dataset to the plugin.
440
441
        Returns
442
        -------
443
        labels : dict
444
            The axis labels for the dataset that is output from the plugin.
445
446
        """
447
        labels = in_dataset.data_info.get('axis_labels')[0]
448
        volX, volY, volZ = self._get_volume_dimensions()
449
        labels = [str(volX) + '.voxel_x.voxels', str(volZ) + '.voxel_z.voxels']
450
        if volY:
451
            labels.append(str(volY) + '.voxel_y.voxels')
452
        labels = {in_dataset: labels}
453
        return labels
454
455
    def _set_volume_dimensions(self, data):
456
        """
457
        Map the input dimensions to the output volume dimensions
458
459
        Parameters
460
        ----------
461
        data : :class:`savu.data.data_structures.data.Data`
462
            The input dataset to the plugin.
463
        """
464
        data._finalise_patterns()
465
        self.volX = data.get_data_dimension_by_axis_label("rotation_angle")
466
        self.volZ = data.get_data_dimension_by_axis_label("x", contains=True)
467
        self.volY = data.get_data_dimension_by_axis_label(
468
            "y", contains=True, exists=True)
469
470
    def _get_volume_dimensions(self):
471
        return self.volX, self.volY, self.volZ
472
473
    def _get_shape(self, in_dataset):
474
        shape = list(in_dataset.get_shape())
475
        volX, volY, volZ = self._get_volume_dimensions()
476
477
        if self.parameters['vol_shape'] in ('auto', 'fixed'):
478
            shape[volX] = shape[volZ]
479
        else:
480
            shape[volX] = self.parameters['vol_shape']
481
            shape[volZ] = self.parameters['vol_shape']
482
483
        if 'resolution' in self.parameters.keys():
484
            shape[volX] = int(shape[volX] / self.parameters['resolution'])
485
            shape[volZ] = int(shape[volZ] / self.parameters['resolution'])
486
        return tuple(shape)
487
488
    def get_max_frames(self):
489
        """
490
        Number of data frames to pass to each instance of process_frames func
491
492
        Returns
493
        -------
494
        str or int
495
            "single", "multiple" or integer (only to be used if the number of
496
                                             frames MUST be fixed.)
497
        """
498
        return 'multiple'
499
500
    def get_slice_axis(self):
501
        """
502
        Fix the fastest changing slice dimension
503
504
        Returns
505
        -------
506
        str or None
507
            str should be the axis_label corresponding to the fastest changing
508
            dimension
509
510
        """
511
        return None
512
513
    def nInput_datasets(self):
514
        nIn = 1
515
        if 'init_vol' in self.parameters.keys() and \
516
                self.parameters['init_vol']:
517
            if len(self.parameters['init_vol'].split('.')) == 3:
518
                name, temp, self.rep_dim = self.parameters['init_vol'].split('.')
519
                self.parameters['init_vol'] = name
520
            nIn += 1
521
            self.parameters['in_datasets'].append(self.parameters['init_vol'])
522
        if isinstance(self.parameters['centre_of_rotation'], str):
523
            self.parameters['in_datasets'].append(
524
                self.parameters['centre_of_rotation'])
525
            nIn += 1
526
        return nIn
527
528
    def nOutput_datasets(self):
529
        return self.nOut
530
531
    def reconstruct_pre_process(self):
532
        """
533
        Should be overridden to perform pre-processing in a child class
534
        """
535
        pass
536
537
    def executive_summary(self):
538
        summary = []
539
        if not self.preview_flag:
540
            summary.append(("WARNING: Ignoring preview parameters as a preview"
541
                            " has already been applied to the data."))
542
        if len(summary) > 0:
543
            return summary
544
        return ["Nothing to Report"]
545
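The centring arithmetic shared by br_array_pad() and crop_sino() is easier to follow on a single detector row. The sketch below copies br_array_pad() (with the redundant "+ 0" dropped) and wraps it in crop_row(), a hypothetical 1-D stand-in for crop_sino(). It assumes the detector axis is the only axis and ignores set_mask(). Cropping removes round(|blen - alen|) pixels from the side farther from the centre of rotation, so the centre ends up in the middle of the cropped row, and the returned centre is shifted by -start, as cor_shift does in crop_sino().

```python
import numpy as np


def br_array_pad(ctr, n_pixels):
    """Copy of BaseRecon.br_array_pad(): returns [p_low, p_high]."""
    width = n_pixels - 1.0
    alen = ctr               # distance from the first pixel to the centre
    blen = width - ctr       # distance from the centre to the last pixel
    mid = (width - 1.0) / 2.0
    shift = round(abs(blen - alen))
    p_low = 0 if ctr > mid else shift
    p_high = shift if ctr > mid else 0
    return np.array([int(p_low), int(p_high)])


def crop_row(row, ctr):
    """Hypothetical 1-D version of BaseRecon.crop_sino()."""
    start, stop = br_array_pad(ctr, row.shape[0])[::-1]
    cropped = row[start:row.shape[0] - stop]
    return cropped, ctr - start  # the centre moves left by `start`


cropped, new_ctr = crop_row(np.arange(10.0), 6)  # keeps pixels 3..9, centre 3
```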