Test Failed
Pull Request — master (#809)
by Daniil
03:38
created

savu.plugins.reconstructions.base_recon   F

Complexity

Total Complexity 85

Size/Duplication

Total Lines 508
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 314
dl 0
loc 508
rs 2
c 0
b 0
f 0
wmc 85

41 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseRecon.br_array_pad() 0 9 3
A BaseRecon.get_vol_shape() 0 2 1
B BaseRecon.base_process_frames_before() 0 38 6
A BaseRecon.__init__() 0 16 1
A BaseRecon.__get_outer_pad() 0 14 5
A BaseRecon.get_centre_shift() 0 3 1
A BaseRecon.base_process_frames_after() 0 8 3
A BaseRecon.base_pre_process() 0 25 2
A BaseRecon._get_detX_dim() 0 3 1
A BaseRecon.set_centre_of_rotation() 0 22 5
A BaseRecon.__get_pad_values() 0 6 1
A BaseRecon.pad_sino() 0 9 1
A BaseRecon.get_centre_offset() 0 7 1
A BaseRecon.__set_cor_from_meta_data() 0 7 2
A BaseRecon.set_function() 0 18 5
A BaseRecon.keep_sino() 0 3 1
A BaseRecon.populate_metadata_to_output() 0 8 1
A BaseRecon.crop_sino() 0 10 1
A BaseRecon.__polyfit_cor() 0 11 2
A BaseRecon.__make_lambda() 0 7 3
A BaseRecon.get_reconstruction_alg() 0 2 1
A BaseRecon._get_volume_dimensions() 0 2 1
A BaseRecon._get_shape() 0 14 3
B BaseRecon.setup() 0 46 6
A BaseRecon.get_angles() 0 7 1
A BaseRecon.get_frame_params() 0 4 1
A BaseRecon._set_volume_dimensions() 0 14 1
A BaseRecon.get_max_frames() 0 11 1
A BaseRecon.nOutput_datasets() 0 2 1
A BaseRecon.get_sino_centre_method() 0 9 5
A BaseRecon.nInput_datasets() 0 14 5
A BaseRecon.get_initial_data() 0 10 1
A BaseRecon.reconstruct_pre_process() 0 5 1
A BaseRecon.get_pad_amount() 0 2 1
A BaseRecon.get_fov_fraction() 0 8 1
A BaseRecon.get_cors() 0 8 1
A BaseRecon.get_slice_axis() 0 12 1
A BaseRecon.__set_pad_amount() 0 2 1
A BaseRecon.executive_summary() 0 8 3
A BaseRecon.set_mask() 0 2 1
A BaseRecon._get_axis_labels() 0 22 2

How to fix   Complexity   

Complexity

Complex classes such as BaseRecon (in savu.plugins.reconstructions.base_recon) often do many different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_recon
17
   :platform: Unix
18
   :synopsis: A base class for all reconstruction methods
19
20
.. moduleauthor:: Mark Basham <[email protected]>
21
22
"""
23
24
import math
25
import copy
26
import numpy as np
27
np.seterr(divide='ignore', invalid='ignore')
28
29
import savu.core.utils as cu
30
from savu.plugins.plugin import Plugin
31
32
# Upper limit for a numeric 'outer_pad' plugin parameter: larger values are
# clamped to this in BaseRecon.__get_outer_pad and the user is notified.
MAX_OUTER_PAD = 2.1
33
34
35
class BaseRecon(Plugin):
    """
    A base class for all reconstruction plugins.

    Provides the plumbing shared by reconstruction methods: centre of
    rotation handling (from parameters, input metadata or a separate
    dataset), the per-sinogram pre-processing function (``log_func``),
    optional padding/cropping so the centre of rotation is central, output
    volume shape and axis labels, and ``force_zero`` clamping of results.
    """

    def __init__(self, name='BaseRecon'):
        super(BaseRecon, self).__init__(name)
        self.nOut = 1
        self.nIn = 1
        # index of the 'scan' dimension when angles are multi-dimensional
        self.scan_dim = None
        # third component of an 'a.b.c' init_vol value (see nInput_datasets)
        self.rep_dim = None
        self.br_vol_shape = None
        # per-frame values populated by base_process_frames_before
        self.frame_angles = None
        self.frame_cors = None
        self.frame_init_data = None
        self.centre = None
        self.base_pad_amount = None
        self.padding_alg = False
        # shift applied to the centre of rotation by pad_sino / crop_sino
        self.cor_shift = 0
        # set to 1/0 in setup for the optional extra input datasets
        self.init_vol = False
        self.cor_as_dataset = False

    def base_pre_process(self):
        """
        Gather the state used by base_process_frames_before/after.

        Sets the padding dimension, output volume shape, centre of rotation,
        rotation angles (and scan dimension for multi-dimensional angles),
        the outer sinogram padding, the sinogram/centre functions, the
        ``force_zero`` range and the sinogram centring method.
        """
        in_data, out_data = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        self.pad_dim = \
            in_pData[0].get_data_dimension_by_axis_label('x', contains=True)
        in_meta_data = self.get_in_meta_data()[0]

        # NOTE(review): logs " End" at the start of pre-processing - confirm.
        self.exp.log(self.name + " End")
        self.br_vol_shape = out_pData[0].get_shape()
        self.set_centre_of_rotation(in_data[0], out_data[0], in_meta_data)

        self.main_dir = in_data[0].get_data_patterns()['SINOGRAM']['main_dir']
        self.angles = in_meta_data.get('rotation_angle')
        if len(self.angles.shape) != 1:
            # presumably one set of angles per scan - index them by 'scan'
            self.scan_dim = in_data[0].get_data_dimension_by_axis_label('scan')
        self.slice_dirs = out_data[0].get_slice_dimensions()

        shape = in_pData[0].get_shape()
        factor = self.__get_outer_pad()
        self.sino_pad = int(math.ceil(factor * shape[self.pad_dim]))

        self.sino_func, self.cor_func = self.set_function(shape)

        self.range = self.parameters['force_zero']
        self.fix_sino = self.get_sino_centre_method()

    def __get_outer_pad(self):
        """
        Return the outer padding as a fraction of the detector width.

        ``outer_pad`` may be a bool (True -> sqrt(2) - 1, False -> 0) or a
        number, which is clamped to MAX_OUTER_PAD.
        """
        pad = self.parameters.get('outer_pad', False)
        if isinstance(pad, bool):
            # length of diagonal of square is side*sqrt(2)
            return math.sqrt(2) - 1 if pad else 0

        factor = float(pad)
        if factor > MAX_OUTER_PAD:
            factor = MAX_OUTER_PAD
            cu.user_message('Maximum outer_pad value is %s, using this instead'
                            % MAX_OUTER_PAD)
        return factor

    def get_vol_shape(self):
        """ Return the output volume shape recorded in base_pre_process. """
        return self.br_vol_shape

    def set_centre_of_rotation(self, inData, outData, mData):
        """
        Determine the centre of rotation for every frame.

        Order of precedence: a string parameter (cor supplied as a dataset)
        means nothing is done here; otherwise a 'centre_of_rotation' entry
        in the input metadata; otherwise a dict parameter, fitted linearly
        against detector_y; otherwise a single value (0 meaning the middle
        of the detector) repeated for every frame. Sets ``self.cor`` and
        ``self.centre`` and copies the values to the output metadata.
        """
        val = self.parameters['centre_of_rotation']
        # if cor has been passed as a dataset then do nothing
        if isinstance(val, str):
            return
        if 'centre_of_rotation' in mData.get_dictionary():
            cor = self.__set_cor_from_meta_data(mData, inData)
        elif isinstance(val, dict):
            cor = self.__polyfit_cor(val, inData)
        else:
            sdirs = inData.get_slice_dimensions()
            cor = np.ones(np.prod([inData.get_shape()[i] for i in sdirs]))
            # if centre of rotation has not been set then fix it in the centre
            if val == 0:
                val = self.get_vol_shape()[self._get_detX_dim()] / 2.0
            cor *= val
            # mData.set('centre_of_rotation', cor) see Github ticket
        self.cor = cor
        outData.meta_data.set("centre_of_rotation", copy.deepcopy(self.cor))
        self.centre = self.cor[0]

    def populate_metadata_to_output(self, inData, outData, mData):
        """
        Copy the rotation angles and the detector x length into the output
        (reconstruction) metadata. Also stores the angles in ``self.angles``.
        """
        self.angles = mData.get('rotation_angle')
        outData.meta_data.set("rotation_angle", copy.deepcopy(self.angles))

        xDim = inData.get_data_dimension_by_axis_label('x', contains=True)
        det_length = inData.get_shape()[xDim]
        outData.meta_data.set("detector_x_length", copy.deepcopy(det_length))

    def __set_cor_from_meta_data(self, mData, inData):
        """
        Read the centre of rotation from metadata, tiling it when there are
        more frames than values.
        """
        cor = mData.get('centre_of_rotation')
        sdirs = inData.get_slice_dimensions()
        total_frames = np.prod([inData.get_shape()[i] for i in sdirs])
        if total_frames > len(cor):
            cor = np.tile(cor, total_frames // len(cor))
        return cor

    def __polyfit_cor(self, cor_dict, inData):
        """
        Linearly fit a ``{slice_index: cor}`` dict and evaluate it at every
        detector_y position (from metadata if present, else 0..n-1).
        """
        if 'detector_y' in inData.meta_data.get_dictionary():
            y = inData.meta_data.get('detector_y')
        else:
            yDim = inData.get_data_dimension_by_axis_label('detector_y')
            y = np.arange(inData.get_shape()[yDim])

        slices = [int(key) for key in cor_dict]
        coeffs = np.polyfit(slices, list(cor_dict.values()), 1)
        return np.poly1d(coeffs)(y)

    def set_function(self, pad_shape):
        """
        Build the sinogram pre-processing function and the matching centre
        of rotation function.

        :param pad_shape: sinogram shape, used for padding when
            ``centre_pad`` is set.
        :returns: ``(sino_func, cor_func)``
        """
        centre_pad = self.parameters.get('centre_pad', False)
        use_log = bool(self.parameters['log'])
        if centre_pad:
            def cor_func(cor):
                # the detector dimension is padded by sino_pad on each side
                return cor + self.sino_pad
            sino_func = self.__make_lambda(log=use_log, pad=pad_shape)
        else:
            def cor_func(cor):
                return cor
            sino_func = self.__make_lambda(log=use_log)
        return sino_func, cor_func

    def __make_lambda(self, log=True, pad=False):
        """
        Compile the function applied to each sinogram.

        :param log: use the ``log_func`` parameter expression if True,
            otherwise only ``np.nan_to_num``.
        :param pad: a sinogram shape to edge-pad the detector dimension by
            ``self.sino_pad``, or False for no padding.
        """
        expr = self.parameters['log_func'] if log else 'np.nan_to_num(sino)'
        if pad:
            pad_tuples, mode = self.__get_pad_values(pad)
            expr = expr.replace(
                'sino', 'np.pad(sino, %s, "%s")' % (pad_tuples, mode))
        # SECURITY NOTE: 'log_func' comes from the plugin parameters and is
        # evaluated as Python code - only run trusted process lists.
        return eval("lambda sino: " + expr)

    def __get_pad_values(self, pad_shape):
        """
        Return np.pad arguments that pad only ``self.pad_dim`` by
        ``self.sino_pad`` on both sides, repeating edge values.
        """
        mode = 'edge'
        pad_tuples = [(0, 0)] * (len(pad_shape) - 1)
        pad_tuples.insert(self.pad_dim, (self.sino_pad, self.sino_pad))
        return tuple(pad_tuples), mode

    def base_process_frames_before(self, data):
        """
        Prepare the current frames for reconstruction.

        Selects the rotation angles and centre(s) of rotation for the
        current frames, stores the (NaN-cleaned) initial volume if one is
        used, and applies the sinogram function and centring method to
        ``data[0]``.
        """
        sl = self.get_current_slice_list()[0]
        init = data[1] if self.init_vol else None
        # scan_dim is a dimension index, so 0 is valid; None means "no scan"
        if self.scan_dim is not None:
            angles = self.angles[:, sl[self.scan_dim]]
        else:
            angles = self.angles
        self.frame_angles = np.squeeze(angles)

        dim_sl = sl[self.main_dir]

        if self.cor_as_dataset:
            # the centre of rotation dataset is the last input dataset
            self.frame_cors = self.cor_func(data[-1])
        else:
            frame_nos = \
                self.get_plugin_in_datasets()[0].get_current_frame_idx()
            self.frame_cors = self.cor_func(self.cor[tuple([frame_nos])])

        # for extra padded frames that make up the numbers
        if not self.frame_cors.shape:
            self.frame_cors = np.array([self.centre])

        len_data = len(np.arange(dim_sl.start, dim_sl.stop, dim_sl.step))
        # give every frame without a cor value the default centre
        n_missing = max(len_data - len(self.frame_cors), 0)
        self.frame_cors = np.append(self.frame_cors, [self.centre] * n_missing)

        # fix to remove NaNs in the initialised image
        if init is not None:
            init[np.isnan(init)] = 0.0
        self.frame_init_data = init

        data[0] = self.fix_sino(self.sino_func(data[0]), self.frame_cors[0])
        return data

    def base_process_frames_after(self, data):
        """
        Replace NaN/inf and zero values outside the ``force_zero`` range
        (either bound may be None).
        """
        lower_range, upper_range = self.range
        data = np.nan_to_num(data)
        if lower_range is not None:
            data[data < lower_range] = 0.0
        if upper_range is not None:
            data[data > upper_range] = 0.0
        return data

    def pad_sino(self, sino, cor):
        """  Pad the sinogram so the centre of rotation is at the centre. """
        detX = self._get_detX_dim()
        pad = self.get_centre_offset(sino, cor, detX)
        self.cor_shift = pad[0]
        pad_tuples = [(0, 0)] * (len(sino.shape) - 1)
        pad_tuples.insert(detX, tuple(pad))
        self.__set_pad_amount(max(pad))
        return np.pad(sino, tuple(pad_tuples), mode='edge')

    def _get_detX_dim(self):
        """ Return the detector x dimension of the first input dataset. """
        pData = self.get_plugin_in_datasets()[0]
        return pData.get_data_dimension_by_axis_label('x', contains=True)

    def get_centre_offset(self, sino, cor, detX):
        """
        Return the ``[low, high]`` padding that centres ``cor``, plus an
        equal share of the outer padding scaled to the new width.
        """
        sino_width = sino.shape[detX]
        centre_pad = self.br_array_pad(cor, sino_width)
        new_width = sino_width + max(centre_pad)
        sino_pad = int(
            math.ceil(float(sino_width) / new_width * self.sino_pad) // 2)
        return np.array([sino_pad] * 2) + centre_pad

    def get_centre_shift(self, sino, cor):
        """ Return the largest padding needed to centre ``self.centre``. """
        # NOTE(review): 'cor' is unused; self.centre is always used - confirm.
        detX = self._get_detX_dim()
        return max(self.get_centre_offset(sino, self.centre, detX))

    def crop_sino(self, sino, cor):
        """  Crop the sinogram so the centre of rotation is at the centre. """
        detX = self._get_detX_dim()
        start, stop = self.br_array_pad(cor, sino.shape[detX])[::-1]
        self.cor_shift = -start
        sl = [slice(None)] * len(sino.shape)
        sl[detX] = slice(start, sino.shape[detX] - stop)
        sino = sino[tuple(sl)]
        self.set_mask(sino.shape)
        return sino

    def br_array_pad(self, ctr, nPixels):
        """
        Return ``[low, high]`` integer padding that centres ``ctr`` in a row
        of ``nPixels`` pixels.

        The amount is the rounded difference between the distances from
        ``ctr`` to either end. It is applied to the low side when ``ctr`` is
        at or below the midpoint and to the high side otherwise.
        """
        width = nPixels - 1.0
        shift = int(round(abs((width - ctr) - ctr)))
        # NOTE(review): the midpoint of [0, width] is width / 2; this uses
        # (width - 1) / 2, half a pixel lower - confirm intended.
        mid = (width - 1.0) / 2.0
        if ctr > mid:
            return np.array([0, shift])
        return np.array([shift, 0])

    def keep_sino(self, sino, cor):
        """ No change to the sinogram """
        return sino

    def get_sino_centre_method(self):
        """
        Return the sinogram centring method selected by ``centre_pad``:
        pad_sino (True), crop_sino (False) or keep_sino (parameter absent).

        :raises ValueError: if ``centre_pad`` is not a bool.
        """
        if 'centre_pad' not in self.parameters:
            return self.keep_sino
        cpad = self.parameters['centre_pad']
        if not isinstance(cpad, bool):
            raise ValueError('Unknown value for "centre_pad", please choose'
                             ' True or False.')
        return self.pad_sino if cpad else self.crop_sino

    def __set_pad_amount(self, pad_amount):
        self.base_pad_amount = pad_amount

    def get_pad_amount(self):
        """ Return the padding amount recorded by the last pad_sino call. """
        return self.base_pad_amount

    def get_fov_fraction(self, sino, cor):
        """ Get the fraction of the original FOV that can be reconstructed due\
        to offset centre """
        original_length = sino.shape[self._get_detX_dim()]
        shift = self.get_centre_shift(sino, cor)
        return (original_length - shift) / float(original_length)

    def get_reconstruction_alg(self):
        return None

    def get_angles(self):
        """ Get the angles associated with the current sinogram(s).

        :returns: Angles of the current frames.
        :rtype: np.ndarray
        """
        return self.frame_angles

    def get_cors(self):
        """
        Get the centre of rotations associated with the current sinogram(s).

        :returns: Centre of rotation values for the current frames, shifted
            by any padding/cropping applied to the sinogram.
        :rtype: np.ndarray
        """
        return self.frame_cors + self.cor_shift

    def set_mask(self, shape):
        """ Hook called by crop_sino with the cropped shape; no-op here. """
        pass

    def get_initial_data(self):
        """
        Get the initial data (if it is exists) associated with the current \
        sinogram(s).

        :returns: The section of the initialisation data associated with the \
            current frames.
        :rtype: np.ndarray or None
        """
        return self.frame_init_data

    def get_frame_params(self):
        """ Return [cors, angles, volume shape, initial data]. """
        return [self.get_cors(), self.get_angles(), self.get_vol_shape(),
                self.get_initial_data()]

    def setup(self):
        """
        Create the output volume and set up patterns for the input datasets:
        sinogram datasets first, then the optional initial volume and
        centre of rotation datasets.
        """
        in_dataset, out_dataset = self.get_datasets()
        # reduce the data as per data_subset parameter
        self.preview_flag = \
            self.set_preview(in_dataset[0], self.parameters['preview'])

        self._set_volume_dimensions(in_dataset[0])
        axis_labels = self._get_axis_labels(in_dataset[0])
        shape = self._get_shape(in_dataset[0])

        # output dataset
        out_dataset[0].create_dataset(axis_labels=axis_labels, shape=shape)
        out_dataset[0].add_volume_patterns(*self._get_volume_dimensions())

        # set information relating to the plugin data
        in_pData, out_pData = self.get_plugin_datasets()

        self.init_vol = 1 if self.parameters.get('init_vol') else 0
        self.cor_as_dataset = 1 if isinstance(
            self.parameters['centre_of_rotation'], str) else 0

        n_sinos = len(in_dataset) - self.init_vol - self.cor_as_dataset
        for i in range(n_sinos):
            in_pData[i].plugin_data_setup('SINOGRAM', self.get_max_frames(),
                                          slice_axis=self.get_slice_axis())

        # index of the next optional dataset; bound before the loop so it is
        # defined even when there are no sinogram datasets.
        # base_process_frames_before reads the initial volume from data[1].
        idx = 1

        # initial volume dataset
        if self.init_vol:
            # NOTE: replicating the initial volume along self.rep_dim (savu
            # Replicate data type) is currently disabled.
            in_pData[idx].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
            idx += 1

        # cor dataset
        if self.cor_as_dataset:
            self.cor_as_dataset = True
            in_pData[idx].plugin_data_setup('METADATA', self.get_max_frames())

        # set pattern_name and nframes to process for all datasets
        out_pData[0].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
        # metadata output populator
        in_meta_data = self.get_in_meta_data()[0]
        self.populate_metadata_to_output(
            in_dataset[0], out_dataset[0], in_meta_data)

    def _get_axis_labels(self, in_dataset):
        """
        Get the new axis labels for the output dataset - this is now a volume.

        Parameters
        ----------
        in_dataset : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.

        Returns
        -------
        labels : dict
            The axis labels for the dataset that is output from the plugin.

        """
        volX, volY, volZ = self._get_volume_dimensions()
        labels = [str(volX) + '.voxel_x.voxels', str(volZ) + '.voxel_z.voxels']
        # NOTE(review): a y dimension at index 0 would be skipped here.
        if volY:
            labels.append(str(volY) + '.voxel_y.voxels')
        return {in_dataset: labels}

    def _set_volume_dimensions(self, data):
        """
        Map the input dimensions to the output volume dimensions

        Parameters
        ----------
        data : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.
        """
        data._finalise_patterns()
        self.volX = data.get_data_dimension_by_axis_label("rotation_angle")
        self.volZ = data.get_data_dimension_by_axis_label("x", contains=True)
        self.volY = data.get_data_dimension_by_axis_label(
            "y", contains=True, exists=True)

    def _get_volume_dimensions(self):
        """ Return (volX, volY, volZ) set by _set_volume_dimensions. """
        return self.volX, self.volY, self.volZ

    def _get_shape(self, in_dataset):
        """
        Return the output volume shape: volX/volZ are the detector width
        ('auto'/'fixed') or ``vol_shape``, divided by ``resolution`` if set.
        """
        shape = list(in_dataset.get_shape())
        volX, volY, volZ = self._get_volume_dimensions()

        if self.parameters['vol_shape'] in ('auto', 'fixed'):
            shape[volX] = shape[volZ]
        else:
            shape[volX] = self.parameters['vol_shape']
            shape[volZ] = self.parameters['vol_shape']

        if 'resolution' in self.parameters:
            shape[volX] = int(shape[volX] // self.parameters['resolution'])
            shape[volZ] = int(shape[volZ] // self.parameters['resolution'])
        return tuple(shape)

    def get_max_frames(self):
        """
        Number of data frames to pass to each instance of process_frames func

        Returns
        -------
        str or int
            "single", "multiple" or integer (only to be used if the number of
                                             frames MUST be fixed.)
        """
        return 'multiple'

    def get_slice_axis(self):
        """
        Fix the fastest changing slice dimension

        Returns
        -------
        str or None
            str should be the axis_label corresponding to the fastest changing
            dimension

        """
        return None

    def nInput_datasets(self):
        """
        Return the number of input datasets, appending the optional initial
        volume and centre of rotation dataset names to ``in_datasets``.

        An ``init_vol`` of the form 'name.x.rep_dim' is split: 'name' is kept
        as the dataset and 'rep_dim' stored in ``self.rep_dim``.
        """
        # NOTE(review): mutates self.parameters; calling this twice appends
        # the extra dataset names twice.
        nIn = 1
        init_vol = self.parameters.get('init_vol')
        if init_vol:
            parts = init_vol.split('.')
            if len(parts) == 3:
                name, _, self.rep_dim = parts
                self.parameters['init_vol'] = name
            nIn += 1
            self.parameters['in_datasets'].append(self.parameters['init_vol'])
        if isinstance(self.parameters['centre_of_rotation'], str):
            self.parameters['in_datasets'].append(
                self.parameters['centre_of_rotation'])
            nIn += 1
        return nIn

    def nOutput_datasets(self):
        return self.nOut

    def reconstruct_pre_process(self):
        """
        Should be overridden to perform pre-processing in a child class
        """
        pass

    def executive_summary(self):
        """ Report a warning if the preview parameters were ignored. """
        if not self.preview_flag:
            return ["WARNING: Ignoring preview parameters as a preview"
                    " has already been applied to the data."]
        return ["Nothing to Report"]