Test Failed
Pull Request — master (#700)
by Daniil
03:41 queued 12s
created

BaseRecon._get_volume_dimensions()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_recon
17
   :platform: Unix
18
   :synopsis: A base class for all reconstruction methods
19
20
.. moduleauthor:: Mark Basham <[email protected]>
21
22
"""
23
import math
24
import copy
25
import numpy as np
26
# Silence NumPy divide-by-zero and invalid-value floating-point warnings.
# NOTE(review): this changes NumPy's global (per-thread) error state as a
# side effect of importing this module, not just for code in this file.
# Presumably it is meant to quiet the default log_func, -np.log(sino), on
# zero/negative values; confirm before relying on it elsewhere.
np.seterr(divide='ignore', invalid='ignore')
27
28
import savu.core.utils as cu
29
from savu.plugins.plugin import Plugin
30
31
# Maximum allowed value for a numeric 'outer_pad' plugin parameter (documented
# as "float <= 2.1"); BaseRecon emits a user message when it is exceeded.
MAX_OUTER_PAD = 2.1
32
33
34
class BaseRecon(Plugin):
    """
    A base class for reconstruction plugins

    :u*param centre_of_rotation: Centre of rotation to use for the \
    reconstruction. Default: 0.0.

    :u*param init_vol: Dataset to use as volume initialiser \
    (doesn't currently work with preview). Default: None.

    :param centre_pad: Pad the sinogram to centre it in order to fill the \
    reconstructed volume ROI for asthetic purposes.\
    NB: Only available for selected algorithms and will be ignored otherwise. \
    WARNING: This will significantly increase the size of the data and the \
    time to compute the reconstruction). Default: False.

    :param outer_pad: Pad the sinogram width to fill the reconstructed volume \
    for asthetic purposes.\
    Choose from True (defaults to sqrt(2)), False or float <= 2.1. \
    NB: Only available for selected algorithms and will be ignored otherwise.\
    WARNING: This will increase the size of the data and the \
    time to compute the reconstruction). Default: False.

    :u*param log: Take the log of the data before reconstruction \
    (True or False). Default: True.

    :u*param preview: A slice list of required frames. Default: [].

    :param force_zero: Set any values in the reconstructed image outside of \
    this range to zero. Default: [None, None].

    :param ratio: Ratio between the diameter of a circle mask and the width of\
    a reconstructed image. If passed as a list or tuple, the second value is \
    assigned to the outer mask area, e.g [0.95, 0.0]. Default: 0.95.

    :param log_func: Override the default log \
        function. Default: 'np.nan_to_num(-np.log(sino))'.

    :param vol_shape: Override the size of the reconstruction volume with an \
    integer value. Default: 'fixed'.
    """
    # NOTE: the class docstring above is the plugin parameter specification
    # and is kept verbatim (including line continuations).

    def __init__(self, name='BaseRecon'):
        super(BaseRecon, self).__init__(name)
        self.nOut = 1
        self.nIn = 1
        # index of the 'scan' dimension; only set for multi-dimensional angles
        self.scan_dim = None
        # replication dimension parsed from a 3-part 'init_vol' value
        self.rep_dim = None
        # output (volume) plugin-data shape, set in base_pre_process
        self.br_vol_shape = None
        # per-frame values exposed through get_angles/get_cors/
        # get_initial_data
        self.frame_angles = None
        self.frame_cors = None
        self.frame_init_data = None
        # first centre-of-rotation value, used as a fallback for padded frames
        self.centre = None
        self.base_pad_amount = None
        # True when the chosen algorithm supports sinogram padding
        self.padding_alg = False
        # shift applied to the centre of rotation by pad_sino/crop_sino
        self.cor_shift = 0
        self.init_vol = False
        self.cor_as_dataset = False

    def base_pre_process(self):
        """
        Set up state shared by all reconstruction plugins before frames are
        processed: padding support, centre of rotation, rotation angles, the
        sinogram/centre transform functions, the force_zero range and the
        sinogram centring method.
        """
        in_data, out_data = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        # dimension along which sinograms are padded (the detector x axis)
        self.pad_dim = \
            in_pData[0].get_data_dimension_by_axis_label('x', contains=True)
        in_meta_data = self.get_in_meta_data()[0]
        # must run before __get_outer_pad/set_function, which read padding_alg
        self.__set_padding_alg()

        self.exp.log(self.name + " End")
        # set before set_centre_of_rotation, which reads it via get_vol_shape
        self.br_vol_shape = out_pData[0].get_shape()
        self.set_centre_of_rotation(in_data[0], out_data[0], in_meta_data)

        self.main_dir = in_data[0].get_data_patterns()['SINOGRAM']['main_dir']
        self.angles = in_meta_data.get('rotation_angle')
        # multi-dimensional angle arrays are indexed by the 'scan' dimension
        if len(self.angles.shape) != 1:
            self.scan_dim = in_data[0].get_data_dimension_by_axis_label('scan')
        self.slice_dirs = out_data[0].get_slice_dimensions()

        shape = in_pData[0].get_shape()
        factor = self.__get_outer_pad()
        # number of pixels added to EACH side of the sinogram along pad_dim
        self.sino_pad = int(math.ceil(factor * shape[self.pad_dim]))

        self.sino_func, self.cor_func = self.set_function(shape) if \
            self.padding_alg else self.set_function(False)

        self.range = self.parameters['force_zero']
        self.fix_sino = self.get_sino_centre_method()

    def __get_outer_pad(self):
        """
        Get the outer padding factor requested by the 'outer_pad' parameter.

        The factor is multiplied by the sinogram width to give the number of
        pixels padded on each side.

        :returns: 0 when padding is disabled or unsupported by the algorithm,
            sqrt(2) - 1 when 'outer_pad' is True, otherwise float(outer_pad)
            capped at MAX_OUTER_PAD.
        :rtype: float
        """
        # length of diagonal of square is side*sqrt(2)
        factor = math.sqrt(2) - 1
        pad = self.parameters['outer_pad'] if 'outer_pad' in \
            self.parameters.keys() else False

        if pad is not False and not self.padding_alg:
            msg = 'This reconstruction algorithm cannot be padded.'
            cu.user_message(msg)
            return 0

        if isinstance(pad, bool):
            return factor if pad else 0
        factor = float(pad)
        if factor > MAX_OUTER_PAD:
            factor = MAX_OUTER_PAD
            msg = 'Maximum outer_pad value is 2.1, using this instead'
            cu.user_message(msg)
        # return the (possibly capped) factor, not the raw parameter value
        return factor

    def __set_padding_alg(self):
        """ Determine if this is an algorithm that allows sinogram padding. """
        pad_algs = self.get_padding_algorithms()
        alg = self.parameters['algorithm'] if 'algorithm' in \
            self.parameters.keys() else None
        self.padding_alg = alg in pad_algs

    def get_vol_shape(self):
        """ Return the output volume shape recorded in base_pre_process. """
        return self.br_vol_shape

    def set_centre_of_rotation(self, inData, outData, mData):
        """
        Set self.cor (one centre value per frame) and self.centre (the first
        value), and copy the centre values into the output metadata.

        Sources, in priority order: a 'centre_of_rotation' metadata entry, a
        dict parameter (a linear fit over detector_y), or a scalar parameter
        (0 meaning "the middle of the volume"). Nothing is done when the
        centre of rotation is passed as a dataset name (str).
        """
        # if cor has been passed as a dataset then do nothing
        if isinstance(self.parameters['centre_of_rotation'], str):
            return
        if 'centre_of_rotation' in mData.get_dictionary().keys():
            cor = self.__set_cor_from_meta_data(mData, inData)
        else:
            val = self.parameters['centre_of_rotation']
            if isinstance(val, dict):
                cor = self.__polyfit_cor(val, inData)
            else:
                sdirs = inData.get_slice_dimensions()
                cor = np.ones(np.prod([inData.get_shape()[i] for i in sdirs]))
                # if centre of rotation has not been set then fix it in the
                # centre
                val = val if val != 0 else \
                    (self.get_vol_shape()[self._get_detX_dim()]) / 2.0
                cor *= val
                # mData.set('centre_of_rotation', cor) see Github ticket
        self.cor = cor
        outData.meta_data.set("centre_of_rotation", copy.deepcopy(self.cor))
        self.centre = self.cor[0]

    def populate_metadata_to_output(self, inData, outData, mData):
        """
        Copy 'rotation_angle' and 'detector_x' into the output metadata.

        'detector_x' is taken from the input metadata if present, otherwise
        it is the size of the input's detector_x dimension.
        """
        # writing rotation angles into the metadata associated with the
        # output (reconstruction)
        self.angles = mData.get('rotation_angle')
        outData.meta_data.set("rotation_angle", copy.deepcopy(self.angles))
        if 'detector_x' in mData.get_dictionary().keys():
            detector_x_dim = mData.get('detector_x')
        else:
            xDim = inData.get_data_dimension_by_axis_label('detector_x')
            detector_x_dim = inData.get_shape()[xDim]
        outData.meta_data.set("detector_x", copy.deepcopy(detector_x_dim))

    def __set_cor_from_meta_data(self, mData, inData):
        """
        Read the centre of rotation from metadata, tiling it when there are
        more frames (product of the slice dimensions) than values.
        """
        cor = mData.get('centre_of_rotation')
        sdirs = inData.get_slice_dimensions()
        total_frames = np.prod([inData.get_shape()[i] for i in sdirs])
        if total_frames > len(cor):
            # NOTE(review): integer division drops the remainder when
            # total_frames is not a multiple of len(cor) - confirm intended.
            cor = np.tile(cor, total_frames // len(cor))
        return cor

    def __polyfit_cor(self, cor_dict, inData):
        """
        Fit a straight line through {detector_y: centre} pairs and evaluate
        it at every detector_y position.
        """
        if 'detector_y' in inData.meta_data.get_dictionary().keys():
            y = inData.meta_data.get('detector_y')
        else:
            yDim = inData.get_data_dimension_by_axis_label('detector_y')
            y = np.arange(inData.get_shape()[yDim])

        positions = [int(key) for key in cor_dict.keys()]
        z = np.polyfit(positions, list(cor_dict.values()), 1)
        p = np.poly1d(z)
        cor = p(y)
        return cor

    def set_function(self, pad_shape):
        """
        Build the per-frame sinogram and centre-of-rotation transforms.

        Parameters
        ----------
        pad_shape : tuple or False
            Shape of the input frames when outer padding is applied, False
            otherwise.

        Returns
        -------
        tuple
            (sino_func, cor_func): sino_func applies the (optional) log and
            padding to a sinogram; cor_func shifts the centre of rotation by
            self.sino_pad when padding is applied.
        """
        log = bool(self.parameters['log'])
        if not pad_shape:
            def cor_func(cor): return cor
            sino_func = self.__make_lambda(log=log)
        else:
            def cor_func(cor): return cor + self.sino_pad
            sino_func = self.__make_lambda(log=log, pad=pad_shape)
        return sino_func, cor_func

    def __make_lambda(self, log=True, pad=False):
        """
        Compile the sinogram transform ``lambda sino: <expr>``.

        <expr> is the 'log_func' parameter (or 'np.nan_to_num(sino)' when log
        is False). When padding, every occurrence of 'sino' in it is replaced
        by an edge-mode np.pad call.
        """
        log_func = 'np.nan_to_num(sino)' if not log else \
            self.parameters['log_func']
        if pad:
            pad_tuples, mode = self.__get_pad_values(pad)
            log_func = log_func.replace(
                    'sino', 'np.pad(sino, %s, "%s")' % (pad_tuples, mode))
        # SECURITY NOTE(review): 'log_func' comes from the plugin parameters
        # and is passed to eval(); only run process lists from trusted
        # sources.
        return eval("lambda sino: " + log_func)

    def __get_pad_values(self, pad_shape):
        """
        Return the np.pad width tuples (self.sino_pad on both sides of
        self.pad_dim, no padding elsewhere) and the padding mode ('edge').
        """
        mode = 'edge'
        pad_tuples = [(0, 0)] * (len(pad_shape) - 1)
        pad_tuples.insert(self.pad_dim, (self.sino_pad, self.sino_pad))
        pad_tuples = tuple(pad_tuples)
        return pad_tuples, mode

    def base_process_frames_before(self, data):
        """
        Prepare the current frame(s) for reconstruction.

        Selects the rotation angles, centres of rotation and optional initial
        volume for the current slice, then applies the sinogram function
        (log/outer padding) and the centring method to data[0].

        Parameters
        ----------
        data : list
            Input frames: data[0] is the sinogram, data[1] the initial volume
            when 'init_vol' is set and the last entry the centre of rotation
            when it is passed as a dataset.

        Returns
        -------
        list
            data, with data[0] replaced by the transformed sinogram.
        """
        sl = self.get_current_slice_list()[0]
        init = data[1] if self.init_vol else None
        # scan_dim is a dimension index, so 0 is a valid value
        angles = self.angles[:, sl[self.scan_dim]] \
            if self.scan_dim is not None else self.angles

        self.frame_angles = angles

        dim_sl = sl[self.main_dir]

        if self.cor_as_dataset:
            self.frame_cors = self.cor_func(data[len(data) - 1])
        else:
            frame_nos = \
                self.get_plugin_in_datasets()[0].get_current_frame_idx()
            a = self.cor[tuple([frame_nos])]
            self.frame_cors = self.cor_func(a)

        # for extra padded frames that make up the numbers
        if not self.frame_cors.shape:
            self.frame_cors = np.array([self.centre])

        len_data = len(np.arange(dim_sl.start, dim_sl.stop, dim_sl.step))

        # NOTE(review): this appends len(frame_cors) - len_data values (none
        # when negative); the sign looks inverted for padding frame_cors up
        # to len_data - confirm.
        missing = [self.centre] * (len(self.frame_cors) - len_data)
        self.frame_cors = np.append(self.frame_cors, missing)

        # remove NaNs in the initialised image (in-place assignment)
        if init is not None:
            init[np.isnan(init)] = 0.0
        self.frame_init_data = init

        data[0] = self.fix_sino(self.sino_func(data[0]), self.frame_cors[0])
        return data

    def base_process_frames_after(self, data):
        """
        Zero the values outside the 'force_zero' range [lower, upper]; a
        None bound is not applied.
        """
        lower_range, upper_range = self.range
        if lower_range is not None:
            data[data < lower_range] = 0
        if upper_range is not None:
            data[data > upper_range] = 0
        return data

    def get_padding_algorithms(self):
        """ A list of algorithms that allow the data to be padded. """
        return []

    def pad_sino(self, sino, cor):
        """  Pad the sinogram so the centre of rotation is at the centre. """
        detX = self._get_detX_dim()
        pad = self.get_centre_offset(sino, cor, detX)
        self.cor_shift = pad[0]
        pad_tuples = [(0, 0)] * (len(sino.shape) - 1)
        pad_tuples.insert(detX, tuple(pad))
        self.__set_pad_amount(max(pad))
        return np.pad(sino, tuple(pad_tuples), mode='edge')

    def _get_detX_dim(self):
        """ Index of the input dimension whose axis label contains 'x'. """
        pData = self.get_plugin_in_datasets()[0]
        return pData.get_data_dimension_by_axis_label('x', contains=True)

    def get_centre_offset(self, sino, cor, detX):
        """
        Return [low, high] pixel padding along detX: the centring pad from
        br_array_pad plus a share of self.sino_pad scaled by the width ratio.
        """
        centre_pad = self.br_array_pad(cor, sino.shape[detX])
        sino_width = sino.shape[detX]
        new_width = sino_width + max(centre_pad)
        sino_pad = \
            int(math.ceil(float(sino_width) / new_width * self.sino_pad) // 2)
        pad = np.array([sino_pad]*2) + centre_pad
        return pad

    def get_centre_shift(self, sino, cor):
        """
        Largest padding returned by get_centre_offset for this sinogram.

        NOTE(review): uses self.centre and ignores the 'cor' argument -
        confirm intended.
        """
        detX = self._get_detX_dim()
        return max(self.get_centre_offset(sino, self.centre, detX))

    def crop_sino(self, sino, cor):
        """  Crop the sinogram so the centre of rotation is at the centre. """
        detX = self._get_detX_dim()
        start, stop = self.br_array_pad(cor, sino.shape[detX])[::-1]
        self.cor_shift = -start
        sl = [slice(None)] * len(sino.shape)
        sl[detX] = slice(start, sino.shape[detX] - stop)
        sino = sino[tuple(sl)]
        self.set_mask(sino.shape)
        return sino

    def br_array_pad(self, ctr, nPixels):
        """
        Return [p_low, p_high]: the pixels needed on one side so that ctr
        becomes the middle of the row; the other side gets 0.
        """
        width = nPixels - 1.0
        alen = ctr
        blen = width - ctr
        # NOTE(review): the row midpoint index is width / 2; this is half a
        # pixel lower - confirm intended.
        mid = (width - 1.0) / 2.0
        shift = round(abs(blen - alen))
        p_low = 0 if (ctr > mid) else shift
        p_high = shift if (ctr > mid) else 0
        return np.array([int(p_low), int(p_high)])

    def keep_sino(self, sino, cor):
        """ No change to the sinogram """
        return sino

    def get_sino_centre_method(self):
        """
        Choose how sinograms are centred: keep_sino when 'centre_pad' is not
        a parameter, pad_sino when it is True and the algorithm supports
        padding, crop_sino otherwise.

        :raises Exception: if 'centre_pad' is not True or False.
        """
        centre_pad = self.keep_sino
        if 'centre_pad' in self.parameters.keys():
            cpad = self.parameters['centre_pad']
            if not (cpad is True or cpad is False):
                raise Exception('Unknown value for "centre_pad", please choose'
                                ' True or False.')
            centre_pad = self.pad_sino if cpad and self.padding_alg \
                else self.crop_sino
        return centre_pad

    def __set_pad_amount(self, pad_amount):
        self.base_pad_amount = pad_amount

    def get_pad_amount(self):
        """ The largest pad applied by the most recent pad_sino call. """
        return self.base_pad_amount

    def get_fov_fraction(self, sino, cor):
        """ Get the fraction of the original FOV that can be reconstructed due\
        to offset centre """
        pData = self.get_plugin_in_datasets()[0]
        detX = pData.get_data_dimension_by_axis_label('x', contains=True)
        original_length = sino.shape[detX]
        shift = self.get_centre_shift(sino, cor)
        return (original_length - shift) / float(original_length)

    def get_reconstruction_alg(self):
        return None

    def get_angles(self):
        """ Get the angles associated with the current sinogram(s).

        :returns: Angles of the current frames.
        :rtype: np.ndarray
        """
        return self.frame_angles

    def get_cors(self):
        """
        Get the centre of rotations associated with the current sinogram(s).

        :returns: Centre of rotation values for the current frames.
        :rtype: np.ndarray
        """
        return self.frame_cors + self.cor_shift

    def set_mask(self, shape):
        """ Hook called by crop_sino with the cropped shape; no-op here. """
        pass

    def get_initial_data(self):
        """
        Get the initial data (if it is exists) associated with the current \
        sinogram(s).

        :returns: The section of the initialisation data associated with the \
            current frames.
        :rtype: np.ndarray or None
        """
        return self.frame_init_data

    def get_frame_params(self):
        """ [centres of rotation, angles, volume shape, initial data]. """
        params = [self.get_cors(), self.get_angles(), self.get_vol_shape(),
                  self.get_initial_data()]
        return params

    def setup(self):
        """
        Create the output volume dataset and set the patterns and frame
        counts for every input and output plugin dataset.
        """
        in_dataset, out_dataset = self.get_datasets()
        # reduce the data as per data_subset parameter
        self.preview_flag = \
            self.set_preview(in_dataset[0], self.parameters['preview'])

        self._set_volume_dimensions(in_dataset[0])
        axis_labels = self._get_axis_labels(in_dataset[0])
        shape = self._get_shape(in_dataset[0])

        # output dataset
        out_dataset[0].create_dataset(axis_labels=axis_labels, shape=shape)
        out_dataset[0].add_volume_patterns(*self._get_volume_dimensions())

        # set information relating to the plugin data
        in_pData, out_pData = self.get_plugin_datasets()

        self.init_vol = 1 if 'init_vol' in self.parameters.keys() and \
            self.parameters['init_vol'] else 0
        self.cor_as_dataset = 1 if isinstance(
            self.parameters['centre_of_rotation'], str) else 0

        # nInput_datasets appends the optional init_vol and then the
        # centre_of_rotation datasets after the sinogram datasets
        n_sino = len(in_dataset) - self.init_vol - self.cor_as_dataset
        for i in range(n_sino):
            in_pData[i].plugin_data_setup('SINOGRAM', self.get_max_frames(),
                                          slice_axis=self.get_slice_axis())
        # index of the next auxiliary dataset; defined even when the loop
        # above does not run
        idx = n_sino

        # initial volume dataset
        if self.init_vol:
            # from savu.data.data_structures.data_types import Replicate
            # if self.rep_dim:
            #     in_dataset[idx].data = Replicate(
            #         in_dataset[idx], in_dataset[0].get_shape(self.rep_dim))
            in_pData[idx].plugin_data_setup('VOLUME_XZ',
                                            self.get_max_frames())
            idx += 1

        # cor dataset
        if self.cor_as_dataset:
            self.cor_as_dataset = True
            in_pData[idx].plugin_data_setup('METADATA', self.get_max_frames())

        # set pattern_name and nframes to process for all datasets
        out_pData[0].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
        # metadata output populator
        in_meta_data = self.get_in_meta_data()[0]
        self.populate_metadata_to_output(in_dataset[0], out_dataset[0],
                                         in_meta_data)

    def _get_axis_labels(self, in_dataset):
        """
        Get the new axis labels for the output dataset - this is now a volume.

        Parameters
        ----------
        in_dataset : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.

        Returns
        -------
        labels : dict
            The axis labels for the dataset that is output from the plugin.

        """
        volX, volY, volZ = self._get_volume_dimensions()
        labels = [str(volX) + '.voxel_x.voxels', str(volZ) + '.voxel_z.voxels']
        # volY is a dimension index (0 is valid) or None when there is no y
        if volY is not None:
            labels.append(str(volY) + '.voxel_y.voxels')
        labels = {in_dataset: labels}
        return labels

    def _set_volume_dimensions(self, data):
        """
        Map the input dimensions to the output volume dimensions

        Parameters
        ----------
        data : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.
        """
        data._finalise_patterns()
        self.volX = data.get_data_dimension_by_axis_label("rotation_angle")
        self.volZ = data.get_data_dimension_by_axis_label("x", contains=True)
        self.volY = data.get_data_dimension_by_axis_label(
            "y", contains=True, exists=True)

    def _get_volume_dimensions(self):
        """ (volX, volY, volZ) input-dimension indices of the volume axes. """
        return self.volX, self.volY, self.volZ

    def _get_shape(self, in_dataset):
        """
        Output volume shape: the input shape with the rotation-angle axis
        resized to the detector x width (or both x axes set to an integer
        'vol_shape'), optionally divided by 'resolution'.
        """
        shape = list(in_dataset.get_shape())
        volX, volY, volZ = self._get_volume_dimensions()

        if self.parameters['vol_shape'] in ('auto', 'fixed'):
            shape[volX] = shape[volZ]
        else:
            shape[volX] = self.parameters['vol_shape']
            shape[volZ] = self.parameters['vol_shape']

        if 'resolution' in self.parameters.keys():
            resolution = self.parameters['resolution']
            # keep the shape integral: '/' returns a float in Python 3
            shape[volX] = int(shape[volX] / resolution)
            shape[volZ] = int(shape[volZ] / resolution)
        return tuple(shape)

    def get_max_frames(self):
        """
        Number of data frames to pass to each instance of process_frames func

        Returns
        -------
        str or int
            "single", "multiple" or integer (only to be used if the number of
                                             frames MUST be fixed.)
        """
        return 'multiple'

    def get_slice_axis(self):
        """
        Fix the fastest changing slice dimension

        Returns
        -------
        str or None
            str should be the axis_label corresponding to the fastest changing
            dimension

        """
        return None

    def nInput_datasets(self):
        """
        Number of input datasets, appending the optional 'init_vol' and
        'centre_of_rotation' (when a str) dataset names to 'in_datasets'.

        An 'init_vol' of the form 'name.x.rep_dim' is split; 'name' becomes
        the dataset and the third part is stored in self.rep_dim.
        """
        nIn = 1
        if 'init_vol' in self.parameters.keys() and \
                self.parameters['init_vol']:
            parts = self.parameters['init_vol'].split('.')
            if len(parts) == 3:
                name, _, self.rep_dim = parts
                self.parameters['init_vol'] = name
            nIn += 1
            self.parameters['in_datasets'].append(self.parameters['init_vol'])
        if isinstance(self.parameters['centre_of_rotation'], str):
            self.parameters['in_datasets'].append(
                self.parameters['centre_of_rotation'])
            nIn += 1
        return nIn

    def nOutput_datasets(self):
        return self.nOut

    def reconstruct_pre_process(self):
        """
        Should be overridden to perform pre-processing in a child class
        """
        pass

    def executive_summary(self):
        """ Warn when the preview parameters were ignored. """
        summary = []
        if not self.preview_flag:
            summary.append(("WARNING: Ignoring preview parameters as a preview"
                            " has already been applied to the data."))
        if len(summary) > 0:
            return summary
        return ["Nothing to Report"]
561