Test Failed
Pull Request — master (#700)
by Daniil
04:25
created

savu.plugins.reconstructions.base_recon   F

Complexity

Total Complexity 87

Size/Duplication

Total Lines 506
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 87
eloc 313
dl 0
loc 506
rs 2
c 0
b 0
f 0

41 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseRecon.br_array_pad() 0 9 3
A BaseRecon.get_vol_shape() 0 2 1
B BaseRecon.base_process_frames_before() 0 37 6
A BaseRecon.get_reconstruction_alg() 0 2 1
A BaseRecon.__init__() 0 16 1
A BaseRecon._get_volume_dimensions() 0 2 1
A BaseRecon.__get_outer_pad() 0 14 5
A BaseRecon._get_shape() 0 14 3
A BaseRecon.get_centre_shift() 0 3 1
A BaseRecon.base_process_frames_after() 0 7 3
A BaseRecon.base_pre_process() 0 25 2
A BaseRecon._get_detX_dim() 0 3 1
B BaseRecon.setup() 0 46 6
A BaseRecon.get_angles() 0 7 1
A BaseRecon.get_frame_params() 0 4 1
A BaseRecon.set_centre_of_rotation() 0 22 5
A BaseRecon._set_volume_dimensions() 0 14 1
A BaseRecon.__get_pad_values() 0 6 1
A BaseRecon.get_max_frames() 0 11 1
A BaseRecon.nOutput_datasets() 0 2 1
B BaseRecon.get_sino_centre_method() 0 10 6
A BaseRecon.pad_sino() 0 9 1
A BaseRecon.get_centre_offset() 0 7 1
A BaseRecon.nInput_datasets() 0 14 5
A BaseRecon.get_initial_data() 0 10 1
A BaseRecon.reconstruct_pre_process() 0 5 1
A BaseRecon.get_pad_amount() 0 2 1
A BaseRecon.__set_cor_from_meta_data() 0 7 2
A BaseRecon.get_fov_fraction() 0 8 1
A BaseRecon.set_function() 0 16 5
A BaseRecon.get_cors() 0 8 1
A BaseRecon.get_slice_axis() 0 12 1
A BaseRecon.__set_pad_amount() 0 2 1
A BaseRecon.executive_summary() 0 8 3
A BaseRecon.keep_sino() 0 3 1
A BaseRecon.set_mask() 0 2 1
A BaseRecon.populate_metadata_to_output() 0 10 2
A BaseRecon.crop_sino() 0 10 1
A BaseRecon.__polyfit_cor() 0 11 2
A BaseRecon._get_axis_labels() 0 22 2
A BaseRecon.__make_lambda() 0 7 3

How to fix   Complexity   

Complexity

Complex classes like savu.plugins.reconstructions.base_recon often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_recon
   :platform: Unix
   :synopsis: A base class for all reconstruction methods

.. moduleauthor:: Mark Basham <[email protected]>

"""
import math
import copy

import numpy as np

import savu.core.utils as cu
from savu.plugins.plugin import Plugin

# Import-time side effect: numpy's floating-point error handling is changed
# so divide-by-zero and invalid-value operations do not warn. This is not
# scoped to this module. Issued after the import block so every import stays
# at the top of the file (PEP 8 / flake8 E402).
np.seterr(divide='ignore', invalid='ignore')

# Upper bound applied to a numeric 'outer_pad' parameter
# (see BaseRecon.__get_outer_pad).
MAX_OUTER_PAD = 2.1


# Base class shared by the reconstruction plugins. It creates the output
# volume (setup) and resolves the centre of rotation from metadata, a
# line-fitted dict, a scalar or a separate dataset. It can pad or crop
# sinograms around that centre, and it exposes per-frame angles, centres and
# initial volume to subclasses through get_frame_params().
class BaseRecon(Plugin):

    def __init__(self, name='BaseRecon'):
        """
        Set default state. Most attributes are filled in later by setup,
        base_pre_process and base_process_frames_before.

        :param name: plugin name forwarded to the Plugin constructor.
        """
        super(BaseRecon, self).__init__(name)
        # number of output / input datasets
        self.nOut = self.nIn = 1
        # dimension bookkeeping and output volume shape
        self.scan_dim = self.rep_dim = self.br_vol_shape = None
        # per-frame values exposed via get_angles / get_cors /
        # get_initial_data
        self.frame_angles = self.frame_cors = self.frame_init_data = None
        # centre-of-rotation and padding state
        self.centre = self.base_pad_amount = None
        self.cor_shift = 0
        # flags: padding algorithm, initial-volume input, CoR-as-dataset input
        self.padding_alg = self.init_vol = self.cor_as_dataset = False

    def base_pre_process(self):
        """
        Compute the state shared by every frame of the reconstruction.

        Records the following on self:

        - the padded ('x') dimension and the output volume shape;
        - the centre of rotation;
        - the sinogram main direction;
        - the rotation angles, plus the scan dimension when the angles are
          not 1-D;
        - the output slice dimensions and the outer-pad width;
        - the sinogram and centre transforms;
        - the 'force_zero' range and the centre-fixing method.
        """
        datasets_in, datasets_out = self.get_datasets()
        plugin_in, plugin_out = self.get_plugin_datasets()
        first_in, first_out = datasets_in[0], datasets_out[0]
        sino_pdata = plugin_in[0]

        self.pad_dim = sino_pdata.get_data_dimension_by_axis_label(
            'x', contains=True)
        meta = self.get_in_meta_data()[0]

        # NOTE(review): logs " End" at the start of pre-processing; confirm.
        self.exp.log(self.name + " End")
        # br_vol_shape must be set first: set_centre_of_rotation reads it
        # (through get_vol_shape) when no explicit centre is given
        self.br_vol_shape = plugin_out[0].get_shape()
        self.set_centre_of_rotation(first_in, first_out, meta)

        patterns = first_in.get_data_patterns()
        self.main_dir = patterns['SINOGRAM']['main_dir']
        self.angles = meta.get('rotation_angle')
        if len(self.angles.shape) != 1:
            # multi-dimensional angles are later indexed by the scan slice
            self.scan_dim = first_in.get_data_dimension_by_axis_label('scan')
        self.slice_dirs = first_out.get_slice_dimensions()

        sino_shape = sino_pdata.get_shape()
        outer = self.__get_outer_pad()
        # sino_pad must exist before set_function, which reads it when
        # building the padded sinogram function and the shifted-centre
        # function
        self.sino_pad = int(math.ceil(outer * sino_shape[self.pad_dim]))
        self.sino_func, self.cor_func = self.set_function(sino_shape)

        self.range = self.parameters['force_zero']
        self.fix_sino = self.get_sino_centre_method()

    def __get_outer_pad(self):
        """
        Return the fractional outer padding requested by 'outer_pad'.

        The result depends on the parameter:

        - missing or False: 0;
        - True: sqrt(2) - 1, the extra width needed to cover the diagonal
          of a square;
        - any other value: converted with float() and capped at
          MAX_OUTER_PAD, with a user message when the cap applies.
        """
        requested = False
        if 'outer_pad' in self.parameters:
            requested = self.parameters['outer_pad']

        if isinstance(requested, bool):
            # length of diagonal of square is side*sqrt(2)
            return math.sqrt(2) - 1 if requested else 0

        fraction = float(requested)
        if fraction > MAX_OUTER_PAD:
            msg = 'Maximum outer_pad value is 2.1, using this instead'
            cu.user_message(msg)
            return MAX_OUTER_PAD
        return fraction

    def get_vol_shape(self):
        """
        Return the output volume shape recorded by base_pre_process.

        :returns: the shape stored in self.br_vol_shape (None before
            base_pre_process has run).
        """
        vol_shape = self.br_vol_shape
        return vol_shape

    def set_centre_of_rotation(self, inData, outData, mData):
        """
        Resolve the centre of rotation (CoR) for every frame.

        Sources are tried in this order:

        1. a 'centre_of_rotation' entry in the input metadata;
        2. a dict-valued 'centre_of_rotation' parameter, line-fitted by
           __polyfit_cor;
        3. a scalar parameter applied to every frame, where 0 means half
           the volume size along the detector-x dimension.

        The result is stored on self.cor and a deep copy is written to the
        output metadata. self.centre becomes its first entry. Nothing is
        done when the parameter is a string, because the CoR is then
        supplied as a dataset.

        :param inData: input dataset.
        :param outData: output dataset.
        :param mData: metadata of the input dataset.
        """
        param = self.parameters['centre_of_rotation']
        if isinstance(param, str):
            # CoR comes from a dataset and is read per frame instead
            return

        if 'centre_of_rotation' in mData.get_dictionary():
            cor = self.__set_cor_from_meta_data(mData, inData)
        elif isinstance(param, dict):
            cor = self.__polyfit_cor(param, inData)
        else:
            slice_dims = inData.get_slice_dimensions()
            n_frames = np.prod([inData.get_shape()[d] for d in slice_dims])
            if param == 0:
                # centre not set: fix it in the middle of the volume
                param = self.get_vol_shape()[self._get_detX_dim()] / 2.0
            cor = np.ones(n_frames)
            cor *= param
            # mData.set('centre_of_rotation', cor) see Github ticket

        self.cor = cor
        outData.meta_data.set("centre_of_rotation", copy.deepcopy(self.cor))
        self.centre = self.cor[0]

    def populate_metadata_to_output(self, inData, outData, mData):
        """
        Copy reconstruction-relevant metadata onto the output dataset.

        Writes deep copies of 'rotation_angle' (also kept on self.angles)
        and 'detector_x_length' into the output metadata. The detector-x
        value is the input metadata 'detector_x' entry when present.
        Otherwise it is the size of the input's 'detector_x' dimension.

        :param inData: input dataset.
        :param outData: output (reconstruction) dataset.
        :param mData: metadata of the input dataset.
        """
        self.angles = mData.get('rotation_angle')
        outData.meta_data.set("rotation_angle", copy.deepcopy(self.angles))

        if 'detector_x' not in mData.get_dictionary():
            x_dim = inData.get_data_dimension_by_axis_label('detector_x')
            x_value = inData.get_shape()[x_dim]
        else:
            x_value = mData.get('detector_x')
        outData.meta_data.set("detector_x_length", copy.deepcopy(x_value))

    def __set_cor_from_meta_data(self, mData, inData):
        """
        Return the metadata CoR, tiled when it is shorter than the frames.

        :param mData: input metadata holding 'centre_of_rotation'.
        :param inData: input dataset. The product of its slice-dimension
            sizes is the total number of frames.
        :returns: the CoR values, repeated with np.tile if needed.
        """
        cor = mData.get('centre_of_rotation')
        frame_counts = [inData.get_shape()[d]
                        for d in inData.get_slice_dimensions()]
        total = np.prod(frame_counts)
        if total <= len(cor):
            return cor
        # NOTE(review): floor division. If total is not a multiple of
        # len(cor), the tiled array is still shorter than total. Confirm
        # callers always supply a CoR whose length divides the frame count.
        return np.tile(cor, total // len(cor))

    def __polyfit_cor(self, cor_dict, inData):
        """
        Fit a straight line through user-supplied CoR values.

        :param cor_dict: mapping whose keys (converted with int) are the
            fit positions and whose values are the CoR at those positions.
            Presumably the keys are detector_y indices; confirm.
        :param inData: input dataset. The fit is evaluated at its
            'detector_y' metadata when present, otherwise at 0..N-1 over
            its 'detector_y' dimension.
        :returns: the fitted line evaluated at those positions.
        """
        meta = inData.meta_data
        if 'detector_y' in meta.get_dictionary():
            y_positions = meta.get('detector_y')
        else:
            y_dim = inData.get_data_dimension_by_axis_label('detector_y')
            y_positions = np.arange(inData.get_shape()[y_dim])

        positions = [int(key) for key in cor_dict]
        values = list(cor_dict.values())
        line = np.poly1d(np.polyfit(positions, values, 1))
        return line(y_positions)

    def set_function(self, pad_shape):
        """
        Build the per-frame sinogram transform and centre transform.

        The sinogram function applies the 'log_func' parameter when 'log'
        is truthy, and only np.nan_to_num otherwise. When 'centre_pad' is
        truthy, every reference to the sinogram is first edge-padded by
        self.sino_pad along self.pad_dim. The centre function then adds
        self.sino_pad, read at call time. Without centre padding the
        centre function is the identity.

        :param pad_shape: shape of the sinogram frames (used for padding).
        :returns: (sino_func, cor_func)
        """
        centre_pad = False
        if 'centre_pad' in self.parameters:
            centre_pad = self.parameters['centre_pad']
        use_log = bool(self.parameters['log'])

        if centre_pad:
            sino_func = self.__make_lambda(log=use_log, pad=pad_shape)

            def cor_func(cor):
                return cor + self.sino_pad
        else:
            sino_func = self.__make_lambda(log=use_log)

            def cor_func(cor):
                return cor
        return sino_func, cor_func

    def __make_lambda(self, log=True, pad=False):
        """
        Compile the sinogram pre-processing expression into a function.

        :param log: when truthy, use the 'log_func' parameter (an
            expression in ``sino``). Otherwise use ``np.nan_to_num(sino)``.
        :param pad: falsy for no padding, otherwise the frame shape passed
            to __get_pad_values. Every occurrence of the text ``sino`` in
            the expression is then replaced by an np.pad call.
        :returns: a one-argument callable ``f(sino)``.
        """
        expr = self.parameters['log_func'] if log else 'np.nan_to_num(sino)'
        if pad:
            pad_tuples, mode = self.__get_pad_values(pad)
            padded = 'np.pad(sino, %s, "%s")' % (pad_tuples, mode)
            expr = expr.replace('sino', padded)
        # SECURITY: 'log_func' comes from the plugin parameters and is run
        # through eval(); only load process lists from trusted sources.
        return eval("lambda sino: " + expr)

    def __get_pad_values(self, pad_shape):
        """
        Return the np.pad widths and mode used for centre padding.

        Each dimension gets (0, 0), except self.pad_dim, which gets
        (self.sino_pad, self.sino_pad). This assumes pad_dim is a valid
        index for pad_shape. The mode is always 'edge'.

        :param pad_shape: shape of the array to pad (only its length is
            used).
        :returns: (pad_tuples, mode)
        """
        widths = [(0, 0)] * (len(pad_shape) - 1)
        widths.insert(self.pad_dim, (self.sino_pad,) * 2)
        return tuple(widths), 'edge'

    def base_process_frames_before(self, data):
        """
        Reconstruct a single sinogram with the provided centre of rotation

        Prepares the current frames: it stores their angles, centre(s) of
        rotation and optional initial volume on self, then applies the
        sinogram transform and the centre-fixing method to data[0].

        :param data: mutable sequence of frame arrays:

            - data[0]: the sinogram(s);
            - data[1]: the initial volume, when init_vol is set;
            - last entry: the CoR values, when the CoR is a dataset.

        :returns: data, with data[0] replaced by the transformed sinogram.
        """
        sl = self.get_current_slice_list()[0]
        init = data[1] if self.init_vol else None

        # Multi-dimensional angles are indexed by the current scan slice.
        # scan_dim can legitimately be 0, so test against None rather than
        # relying on truthiness (which silently used all angles for dim 0).
        if self.scan_dim is not None:
            self.frame_angles = self.angles[:, sl[self.scan_dim]]
        else:
            self.frame_angles = self.angles

        dim_sl = sl[self.main_dir]

        if self.cor_as_dataset:
            self.frame_cors = self.cor_func(data[len(data) - 1])
        else:
            frame_nos = \
                self.get_plugin_in_datasets()[0].get_current_frame_idx()
            self.frame_cors = self.cor_func(self.cor[tuple([frame_nos])])

        # for extra padded frames that make up the numbers
        if not self.frame_cors.shape:
            self.frame_cors = np.array([self.centre])

        len_data = len(np.arange(dim_sl.start, dim_sl.stop, dim_sl.step))

        # NOTE(review): this appends (len(frame_cors) - len_data) centres,
        # i.e. only when there are MORE centres than frames. If the intent
        # is to pad a short CoR array, the operands look reversed.
        missing = [self.centre] * (len(self.frame_cors) - len_data)
        self.frame_cors = np.append(self.frame_cors, missing)

        # Remove NaNs from the initial volume (in place). The original
        # used '==', a comparison whose result was discarded, so NaNs were
        # never replaced.
        if init is not None:
            init[np.isnan(init)] = 0.0
        self.frame_init_data = init

        data[0] = self.fix_sino(self.sino_func(data[0]), self.frame_cors[0])
        return data

    def base_process_frames_after(self, data):
        """
        Zero out reconstructed values outside the 'force_zero' range.

        self.range is a (lower, upper) pair. A None bound is skipped.

        :param data: reconstructed frame array, modified in place.
        :returns: data
        """
        low, high = self.range
        for bound, outside in ((low, np.less), (high, np.greater)):
            if bound is not None:
                data[outside(data, bound)] = 0
        return data

    def pad_sino(self, sino, cor):
        """
        Pad the sinogram so the centre of rotation is at the centre.

        Edge-pads along the detector-x dimension by the (low, high) amounts
        from get_centre_offset. The low pad is stored in self.cor_shift and
        the larger pad is recorded via __set_pad_amount.

        :returns: the padded sinogram.
        """
        x_dim = self._get_detX_dim()
        low_high = self.get_centre_offset(sino, cor, x_dim)
        self.cor_shift = low_high[0]
        widths = [(0, 0)] * (len(sino.shape) - 1)
        widths.insert(x_dim, tuple(low_high))
        self.__set_pad_amount(max(low_high))
        return np.pad(sino, tuple(widths), mode='edge')

    def _get_detX_dim(self):
        """Return the first input's dimension whose axis label contains 'x'."""
        first_in = self.get_plugin_in_datasets()[0]
        return first_in.get_data_dimension_by_axis_label('x', contains=True)

    def get_centre_offset(self, sino, cor, detX):
        """
        Return the (low, high) padding that centres ``cor`` in ``sino``.

        Adds the asymmetric centring pad from br_array_pad to an equal
        share of self.sino_pad on each side. That share is half of
        sino_pad, scaled by the ratio of the original width to the
        centred width.

        :param detX: index of the detector-x dimension of ``sino``.
        :returns: numpy array [low, high] of pad widths.
        """
        width = sino.shape[detX]
        centring = self.br_array_pad(cor, width)
        padded_width = width + max(centring)
        ratio = float(width) / padded_width
        half_outer = int(math.ceil(ratio * self.sino_pad) // 2)
        return np.array([half_outer, half_outer]) + centring

    def get_centre_shift(self, sino, cor):
        """
        Return the larger of the two pad widths from get_centre_offset.

        NOTE(review): ``cor`` is ignored. The shift is always computed for
        self.centre (the first frame's CoR). Confirm this is intended.
        """
        offsets = self.get_centre_offset(sino, self.centre,
                                         self._get_detX_dim())
        return max(offsets)

    def crop_sino(self, sino, cor):
        """
        Crop the sinogram so the centre of rotation is at the centre.

        The (low, high) amounts from br_array_pad are applied in reverse:
        the high amount is cropped from the start of the detector-x
        dimension and the low amount from its end. self.cor_shift becomes
        minus the start crop, and set_mask receives the cropped shape.

        :returns: the cropped sinogram.
        """
        x_dim = self._get_detX_dim()
        n_x = sino.shape[x_dim]
        p_low, p_high = self.br_array_pad(cor, n_x)
        start, stop = p_high, p_low
        self.cor_shift = -start
        index = [slice(None)] * len(sino.shape)
        index[x_dim] = slice(start, n_x - stop)
        cropped = sino[tuple(index)]
        self.set_mask(cropped.shape)
        return cropped

    def br_array_pad(self, ctr, nPixels):
        """
        Return [low, high] pad widths that would centre ``ctr``.

        The shift is |(nPixels - 1 - ctr) - ctr|, rounded. It goes on the
        high side when ``ctr`` exceeds (nPixels - 2) / 2, and on the low
        side otherwise.

        NOTE(review): the midpoint of indices 0..nPixels-1 is
        (nPixels - 1) / 2. The threshold used here is half a pixel lower,
        so centres just below the midpoint can be padded on the high side.
        Confirm before changing it, because it alters reconstructions.

        :param ctr: centre of rotation in pixels.
        :param nPixels: detector width in pixels.
        :returns: numpy int array [low, high].
        """
        last_index = nPixels - 1.0
        shift = int(round(abs((last_index - ctr) - ctr)))
        threshold = (last_index - 1.0) / 2.0
        if ctr > threshold:
            return np.array([0, shift])
        return np.array([shift, 0])

    def keep_sino(self, sino, cor):
        """Identity centre-fixing method: return ``sino`` unchanged."""
        unchanged = sino
        return unchanged

    def get_sino_centre_method(self):
        """
        Choose how sinograms are adjusted for the centre of rotation.

        :returns: self.keep_sino when 'centre_pad' is absent or False, and
            self.pad_sino when it is True.
        :raises Exception: if 'centre_pad' is present and not a bool.
        """
        if 'centre_pad' not in self.parameters.keys():
            return self.keep_sino
        cpad = self.parameters['centre_pad']
        if cpad is False:
            return self.keep_sino
        if cpad is not True:
            raise Exception('Unknown value for "centre_pad", please choose'
                            ' True or False.')
        # NOTE(review): crop_sino is never selected. True pads and False
        # keeps; confirm whether cropping should be reachable.
        return self.pad_sino

    def __set_pad_amount(self, pad_amount):
        """Store the padding applied by pad_sino (read by get_pad_amount)."""
        setattr(self, 'base_pad_amount', pad_amount)

    def get_pad_amount(self):
        """
        Return the larger pad width recorded by the last pad_sino call.

        :returns: the stored amount, or None if pad_sino has not run.
        """
        amount = self.base_pad_amount
        return amount

    def get_fov_fraction(self, sino, cor):
        """
        Get the fraction of the original FOV that can be reconstructed
        given an offset centre.

        :returns: (width - shift) / width. The width is ``sino``'s size
            along the first input's 'x' dimension, and the shift comes from
            get_centre_shift.
        """
        plugin_in = self.get_plugin_in_datasets()[0]
        x_dim = plugin_in.get_data_dimension_by_axis_label('x', contains=True)
        full = float(sino.shape[x_dim])
        shift = self.get_centre_shift(sino, cor)
        return (full - shift) / full

    def get_reconstruction_alg(self):
        """Return the reconstruction algorithm; this base class has none."""
        algorithm = None
        return algorithm

    def get_angles(self):
        """
        Get the angles associated with the current sinogram(s).

        The value is set by base_process_frames_before.

        :returns: Angles of the current frames.
        :rtype: np.ndarray
        """
        angles = self.frame_angles
        return angles

    def get_cors(self):
        """
        Get the centre of rotations associated with the current sinogram(s).

        The per-frame centres are offset by self.cor_shift. pad_sino and
        crop_sino set that shift when they change the sinogram width.

        :returns: Centre of rotation values for the current frames.
        :rtype: np.ndarray
        """
        shift = self.cor_shift
        return self.frame_cors + shift

    def set_mask(self, shape):
        """
        Hook called by crop_sino with the cropped sinogram shape.

        Does nothing here. Presumably subclasses override it to build a
        reconstruction mask; confirm.
        """

    def get_initial_data(self):
        """
        Get the initial data (if it exists) associated with the current
        sinogram(s).

        The value is set by base_process_frames_before from the optional
        initial-volume input.

        :returns: The section of the initialisation data associated with the
            current frames.
        :rtype: np.ndarray or None
        """
        initial = self.frame_init_data
        return initial

    def get_frame_params(self):
        """
        Return the per-frame reconstruction parameters.

        :returns: list [cors, angles, vol_shape, initial_data], built from
            get_cors, get_angles, get_vol_shape and get_initial_data.
        """
        return [self.get_cors(), self.get_angles(), self.get_vol_shape(),
                self.get_initial_data()]

    def setup(self):
        """
        Create the output volume and configure the plugin datasets.

        Input datasets follow the order built by nInput_datasets:

        1. the sinogram dataset(s), set up with the 'SINOGRAM' pattern;
        2. the optional initial volume ('init_vol'), set up with
           'VOLUME_XZ';
        3. the optional centre-of-rotation dataset (when
           'centre_of_rotation' is a string), set up with 'METADATA'.

        The output uses 'VOLUME_XZ'. Rotation-angle and detector-x metadata
        are then copied to the output.
        """
        in_dataset, out_dataset = self.get_datasets()
        # reduce the data as per data_subset parameter
        self.preview_flag = \
            self.set_preview(in_dataset[0], self.parameters['preview'])

        self._set_volume_dimensions(in_dataset[0])
        axis_labels = self._get_axis_labels(in_dataset[0])
        shape = self._get_shape(in_dataset[0])

        # output dataset
        out_dataset[0].create_dataset(axis_labels=axis_labels, shape=shape)
        out_dataset[0].add_volume_patterns(*self._get_volume_dimensions())

        # set information relating to the plugin data
        in_pData, out_pData = self.get_plugin_datasets()

        self.init_vol = 1 if 'init_vol' in self.parameters.keys() and \
            self.parameters['init_vol'] else 0
        self.cor_as_dataset = 1 if isinstance(
            self.parameters['centre_of_rotation'], str) else 0

        # sinogram datasets come first; the optional extras follow them
        n_sinos = len(in_dataset) - self.init_vol - self.cor_as_dataset
        for i in range(n_sinos):
            in_pData[i].plugin_data_setup('SINOGRAM', self.get_max_frames(),
                                          slice_axis=self.get_slice_axis())

        # Index of the next (extra) input dataset. It is derived from the
        # dataset count so it is always defined and stays correct with
        # several sinogram inputs.
        idx = n_sinos

        # initial volume dataset
        # (replicating init_vol along self.rep_dim is not implemented)
        if self.init_vol:
            in_pData[idx].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
            idx += 1

        # cor dataset
        if self.cor_as_dataset:
            self.cor_as_dataset = True
            in_pData[idx].plugin_data_setup('METADATA', self.get_max_frames())

        # set pattern_name and nframes to process for all datasets
        out_pData[0].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
        # metadata output populator
        in_meta_data = self.get_in_meta_data()[0]
        self.populate_metadata_to_output(
            in_dataset[0], out_dataset[0], in_meta_data)

    def _get_axis_labels(self, in_dataset):
        """
        Get the new axis labels for the output dataset - this is now a volume.

        Labels are written as "<input dim>.<name>.voxels":

        - the rotation-angle dimension (volX) becomes voxel_x;
        - the 'x' dimension (volZ) becomes voxel_z;
        - the 'y' dimension (volY), when truthy, becomes voxel_y.

        Parameters
        ----------
        in_dataset : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.

        Returns
        -------
        labels : dict
            Maps ``in_dataset`` to the list of axis labels for the dataset
            that is output from the plugin.
        """
        # (an unused read of the input's axis_labels has been removed)
        volX, volY, volZ = self._get_volume_dimensions()
        labels = [str(volX) + '.voxel_x.voxels',
                  str(volZ) + '.voxel_z.voxels']
        # NOTE(review): a volY of 0 is also skipped here; confirm that the
        # 'y' dimension can never be dimension 0.
        if volY:
            labels.append(str(volY) + '.voxel_y.voxels')
        return {in_dataset: labels}

    def _set_volume_dimensions(self, data):
        """
        Map the input dimensions to the output volume dimensions.

        Finalises the input's patterns, then records three dimensions:

        - volX: the 'rotation_angle' dimension;
        - volZ: the dimension whose label contains 'x';
        - volY: the dimension whose label contains 'y', looked up with
          exists=True.

        Parameters
        ----------
        data : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.
        """
        data._finalise_patterns()
        find = data.get_data_dimension_by_axis_label
        self.volX = find("rotation_angle")
        self.volZ = find("x", contains=True)
        self.volY = find("y", contains=True, exists=True)

    def _get_volume_dimensions(self):
        """Return (volX, volY, volZ) as set by _set_volume_dimensions."""
        dims = (self.volX, self.volY, self.volZ)
        return dims

    def _get_shape(self, in_dataset):
        """
        Return the output volume shape.

        Starts from the input shape. When 'vol_shape' is 'auto' or
        'fixed', volX takes the size of volZ; otherwise both volX and volZ
        take the 'vol_shape' value. If a 'resolution' parameter is
        present, both are then floor-divided by it and converted to int.

        :returns: the volume shape as a tuple.
        """
        dims = list(in_dataset.get_shape())
        vol_x, _, vol_z = self._get_volume_dimensions()
        requested = self.parameters['vol_shape']

        if requested not in ('auto', 'fixed'):
            dims[vol_x] = dims[vol_z] = requested
        else:
            dims[vol_x] = dims[vol_z]

        if 'resolution' in self.parameters.keys():
            res = self.parameters['resolution']
            for d in (vol_x, vol_z):
                dims[d] = int(dims[d] // res)
        return tuple(dims)

    def get_max_frames(self):
        """
        Number of data frames to pass to each process_frames call.

        Returns
        -------
        str or int
            "single", "multiple" or an integer. Use an integer only when
            the number of frames MUST be fixed. This base class returns
            'multiple'.
        """
        frames = 'multiple'
        return frames

    def get_slice_axis(self):
        """
        Fix the fastest changing slice dimension.

        Returns
        -------
        str or None
            The axis_label of the dimension that should change fastest.
            This base class returns None.
        """
        axis_label = None
        return axis_label

    def nInput_datasets(self):
        """
        Return the number of input datasets, registering the optional ones.

        Two optional datasets can be appended to
        self.parameters['in_datasets'], each adding one to the count:

        - a truthy 'init_vol';
        - a string 'centre_of_rotation' (a dataset name).

        An 'init_vol' of the form "<name>.<part>.<rep_dim>" is split. The
        name is kept as 'init_vol' and the third part is stored on
        self.rep_dim.

        NOTE(review): 'in_datasets' is appended to on every call, so
        calling this more than once duplicates the entries.

        :returns: number of input datasets.
        """
        nIn = 1
        if 'init_vol' in self.parameters.keys() and \
                self.parameters['init_vol']:
            parts = self.parameters['init_vol'].split('.')
            if len(parts) == 3:
                # Unpack the split parts. Unpacking the raw string (as
                # before) always raised "too many values to unpack".
                name, _, self.rep_dim = parts
                self.parameters['init_vol'] = name
            nIn += 1
            self.parameters['in_datasets'].append(self.parameters['init_vol'])
        if isinstance(self.parameters['centre_of_rotation'], str):
            self.parameters['in_datasets'].append(
                self.parameters['centre_of_rotation'])
            nIn += 1
        return nIn

    def nOutput_datasets(self):
        """Return the number of output datasets (self.nOut, 1 by default)."""
        count = self.nOut
        return count

    def reconstruct_pre_process(self):
        """
        Hook for child classes to perform pre-processing; a no-op here.
        """
        return None

    def executive_summary(self):
        """
        Return the list of summary messages for this plugin run.

        If self.preview_flag (set from set_preview in setup) is falsy, the
        list holds a warning that the preview parameters were ignored.
        Otherwise it is ["Nothing to Report"].
        """
        if self.preview_flag:
            return ["Nothing to Report"]
        return [("WARNING: Ignoring preview parameters as a preview"
                 " has already been applied to the data.")]