Test Failed
Pull Request — master (#700)
by Daniil
03:41
created

BaseRecon.crop_sino()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 9
nop 3
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_recon
17
   :platform: Unix
18
   :synopsis: A base class for all reconstruction methods
19
20
.. moduleauthor:: Mark Basham <[email protected]>
21
22
"""
23
import math
24
import copy
25
import numpy as np
26
# Silence NumPy floating-point warnings for division by zero and invalid
# operations (presumably raised by log functions such as the default
# 'np.nan_to_num(-np.log(sino))' on zero/negative sinogram values).
# NOTE(review): this alters NumPy's global error state as a side effect of
# importing this module, not just for the code below - consider scoping it
# with np.errstate(...) around the specific computations instead.
np.seterr(divide='ignore', invalid='ignore')
27
28
import savu.core.utils as cu
29
from savu.plugins.plugin import Plugin
30
31
# Largest value accepted for the ``outer_pad`` parameter ("float <= 2.1" in
# the BaseRecon docstring); the factor multiplies the sinogram width to give
# the number of edge-padding pixels added to each side.
MAX_OUTER_PAD: float = 2.1
32
33
34
class BaseRecon(Plugin):
    """
    A base class for reconstruction plugins

    :u*param centre_of_rotation: Centre of rotation to use for the \
    reconstruction. Default: 0.0.

    :u*param init_vol: Dataset to use as volume initialiser \
    (doesn't currently work with preview). Default: None.

    :param centre_pad: Pad the sinogram to centre it in order to fill the \
    reconstructed volume ROI for asthetic purposes.\
    NB: Only available for selected algorithms and will be ignored otherwise. \
    WARNING: This will significantly increase the size of the data and the \
    time to compute the reconstruction). Default: False.

    :param outer_pad: Pad the sinogram width to fill the reconstructed volume \
    for asthetic purposes.\
    Choose from True (defaults to sqrt(2)), False or float <= 2.1. \
    NB: Only available for selected algorithms and will be ignored otherwise.\
    WARNING: This will increase the size of the data and the \
    time to compute the reconstruction). Default: False.

    :u*param log: Take the log of the data before reconstruction \
    (True or False). Default: True.

    :u*param preview: A slice list of required frames. Default: [].

    :param force_zero: Set any values in the reconstructed image outside of \
    this range to zero. Default: [None, None].

    :param ratio: Ratio between the diameter of a circle mask and the width of\
    a reconstructed image. If passed as a list or tuple, the second value is \
    assigned to the outer mask area, e.g [0.95, 0.0]. Default: 0.95.

    :param log_func: Override the default log \
        function. Default: 'np.nan_to_num(-np.log(sino))'.

    :param vol_shape: Override the size of the reconstruction volume with an \
    integer value. Default: 'fixed'.
    """

    def __init__(self, name='BaseRecon'):
        super(BaseRecon, self).__init__(name)
        self.nOut = 1
        self.nIn = 1
        self.scan_dim = None
        self.rep_dim = None
        self.br_vol_shape = None
        self.frame_angles = None
        self.frame_cors = None
        self.frame_init_data = None
        self.centre = None
        # Per-frame centre-of-rotation values.  Stays None when the centre of
        # rotation is supplied as a dataset (it is then read per frame).
        self.cor = None
        self.base_pad_amount = None
        self.padding_alg = False
        self.cor_shift = 0
        self.init_vol = False
        self.cor_as_dataset = False

    def base_pre_process(self):
        """
        Prepare the state shared by all frames of the reconstruction.

        Finds the padding (x) dimension, decides whether the algorithm may
        pad, sets the centre of rotation, reads the rotation angles, copies
        metadata to the output dataset and builds the sinogram
        pre-processing (log/pad) and centring functions.
        """
        in_data, out_data = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        self.pad_dim = \
            in_pData[0].get_data_dimension_by_axis_label('x', contains=True)
        in_meta_data = self.get_in_meta_data()[0]
        self.__set_padding_alg()

        # NOTE(review): " End" is logged at the start of pre-processing;
        # confirm the message text is intended.
        self.exp.log(self.name + " End")
        self.br_vol_shape = out_pData[0].get_shape()
        self.set_centre_of_rotation(in_data[0], in_meta_data)

        self.main_dir = in_data[0].get_data_patterns()['SINOGRAM']['main_dir']
        self.angles = in_meta_data.get('rotation_angle')
        if len(self.angles.shape) != 1:
            # angles differ per scan position: remember the scan dimension
            self.scan_dim = in_data[0].get_data_dimension_by_axis_label('scan')
        self.slice_dirs = out_data[0].get_slice_dimensions()

        # metadata output populator
        self.populate_metadata_to_output(in_data[0], out_data[0], in_meta_data)

        shape = in_pData[0].get_shape()
        factor = self.__get_outer_pad()
        # pixels of edge padding added to EACH side of the x dimension
        self.sino_pad = int(math.ceil(factor * shape[self.pad_dim]))

        self.sino_func, self.cor_func = \
            self.set_function(shape if self.padding_alg else False)

        self.range = self.parameters['force_zero']
        self.fix_sino = self.get_sino_centre_method()

    def __get_outer_pad(self):
        """
        Return the outer padding factor, as a multiple of the sinogram width.

        ``outer_pad`` may be True (factor sqrt(2) - 1, so the padded width
        reaches the diagonal of the square slice), False (no padding) or a
        number, which is clamped to MAX_OUTER_PAD.  Returns 0 when the
        algorithm does not support padding.
        """
        # length of diagonal of square is side*sqrt(2)
        diagonal_factor = math.sqrt(2) - 1
        pad = self.parameters['outer_pad'] if 'outer_pad' in \
            self.parameters.keys() else False

        if pad is not False and not self.padding_alg:
            msg = 'This reconstruction algorithm cannot be padded.'
            cu.user_message(msg)
            return 0

        if isinstance(pad, bool):
            return diagonal_factor if pad else 0
        factor = float(pad)
        if factor > MAX_OUTER_PAD:
            factor = MAX_OUTER_PAD
            msg = 'Maximum outer_pad value is 2.1, using this instead'
            cu.user_message(msg)
        # return the (possibly clamped) factor, not the raw parameter
        return factor

    def __set_padding_alg(self):
        """ Determine if this is an algorithm that allows sinogram padding. """
        pad_algs = self.get_padding_algorithms()
        alg = self.parameters['algorithm'] if 'algorithm' in \
            self.parameters.keys() else None
        self.padding_alg = alg in pad_algs

    def get_vol_shape(self):
        """ Return the shape of the output (reconstructed) dataset. """
        return self.br_vol_shape

    def set_centre_of_rotation(self, inData, mData):
        """
        Set ``self.cor`` (one value per frame) and ``self.centre`` (the first).

        Priority: a 'centre_of_rotation' entry already in the input metadata;
        otherwise the 'centre_of_rotation' parameter, either a dict of
        {detector_y: cor} pairs fitted with a straight line, or a scalar
        (0 meaning half the x-size of the output volume).  A string parameter
        names a dataset that is read per frame, so nothing is set here.
        """
        # if cor has been passed as a dataset then do nothing
        if isinstance(self.parameters['centre_of_rotation'], str):
            return
        if 'centre_of_rotation' in mData.get_dictionary().keys():
            cor = self.__set_cor_from_meta_data(mData, inData)
        else:
            val = self.parameters['centre_of_rotation']
            if isinstance(val, dict):
                cor = self.__polyfit_cor(val, inData)
            else:
                sdirs = inData.get_slice_dimensions()
                cor = np.ones(np.prod([inData.get_shape()[i] for i in sdirs]))
                # if centre of rotation has not been set then fix it in the
                # centre
                if val == 0:
                    val = self.get_vol_shape()[self._get_detX_dim()] / 2.0
                cor *= val
                # mData.set('centre_of_rotation', cor) see Github ticket
        self.cor = cor
        self.centre = self.cor[0]

    def populate_metadata_to_output(self, inData, outData, mData):
        """
        Copy rotation angles, detector_x and centre of rotation to the output.

        detector_x is taken from the metadata if present, otherwise from the
        size of the input's 'detector_x' dimension.
        """
        # writing  rotation angles into the metadata associated with the output (reconstruction)
        outData.meta_data.set("rotation_angle", copy.deepcopy(self.angles))
        if 'detector_x' in mData.get_dictionary().keys():
            detector_x_dim = mData.get('detector_x')
        else:
            xDim = inData.get_data_dimension_by_axis_label('detector_x')
            detector_x_dim = inData.get_shape()[xDim]
        outData.meta_data.set("detector_x", copy.deepcopy(detector_x_dim))
        # no per-frame cor is known when it is supplied as a dataset
        cor = getattr(self, 'cor', None)
        if cor is not None:
            outData.meta_data.set("centre_of_rotation", copy.deepcopy(cor))

    def __set_cor_from_meta_data(self, mData, inData):
        """
        Return the metadata centre of rotation, tiled to cover all frames.
        """
        cor = mData.get('centre_of_rotation')
        sdirs = inData.get_slice_dimensions()
        total_frames = np.prod([inData.get_shape()[i] for i in sdirs])
        if total_frames > len(cor):
            cor = np.tile(cor, total_frames // len(cor))
        return cor

    def __polyfit_cor(self, cor_dict, inData):
        """
        Fit a line through {detector_y: cor} pairs and evaluate it at every
        detector_y position (metadata values if present, else 0..n-1).
        """
        if 'detector_y' in inData.meta_data.get_dictionary().keys():
            y = inData.meta_data.get('detector_y')
        else:
            yDim = inData.get_data_dimension_by_axis_label('detector_y')
            y = np.arange(inData.get_shape()[yDim])

        z = np.polyfit([int(k) for k in cor_dict.keys()],
                       list(cor_dict.values()), 1)
        return np.poly1d(z)(y)

    def set_function(self, pad_shape):
        """
        Build the per-frame sinogram and centre-of-rotation functions.

        :param pad_shape: The input frame shape if the sinogram should be
            edge-padded by ``self.sino_pad`` along the x dimension, else False.
        :returns: (sino_func, cor_func).  sino_func applies the log function
            (when the 'log' parameter is truthy) and any padding; cor_func
            shifts a centre of rotation by the padding amount.
        """
        log = self.parameters['log']
        if not pad_shape:
            def cor_func(cor):
                return cor
            sino_func = self.__make_lambda(log=log)
        else:
            def cor_func(cor):
                return cor + self.sino_pad
            sino_func = self.__make_lambda(log=log, pad=pad_shape)
        return sino_func, cor_func

    def __make_lambda(self, log=True, pad=False):
        """
        Compile ``lambda sino: <expr>`` from the 'log_func' parameter (or
        'np.nan_to_num(sino)' when ``log`` is falsy).  When ``pad`` is a frame
        shape, every occurrence of ``sino`` in the expression is replaced by
        an edge-padded version of the sinogram.
        """
        log_func = self.parameters['log_func'] if log else \
            'np.nan_to_num(sino)'
        if pad:
            pad_tuples, mode = self.__get_pad_values(pad)
            log_func = log_func.replace(
                    'sino', 'np.pad(sino, %s, "%s")' % (pad_tuples, mode))
        # SECURITY NOTE: 'log_func' is a user-supplied parameter evaluated
        # with eval(); only run process lists from trusted sources.
        return eval("lambda sino: " + log_func)

    def __get_pad_values(self, pad_shape):
        """
        Return (pad_tuples, mode) for np.pad: ``self.sino_pad`` pixels on both
        sides of the padding dimension, none elsewhere, 'edge' mode.
        """
        mode = 'edge'
        pad_tuples = [(0, 0)] * (len(pad_shape) - 1)
        pad_tuples.insert(self.pad_dim, (self.sino_pad, self.sino_pad))
        return tuple(pad_tuples), mode

    def base_process_frames_before(self, data):
        """
        Reconstruct a single sinogram with the provided centre of rotation

        Selects the angles and centres of rotation of the current frames,
        stores the NaN-cleaned initial volume (if any) and applies the
        log/padding and centring functions to the sinogram.

        :param data: Input frames: data[0] is the sinogram, data[1] the
            initial volume when ``init_vol`` is set, and the last entry the
            centre-of-rotation data when it is supplied as a dataset.
        :returns: ``data`` with data[0] replaced by the processed sinogram.
        """
        sl = self.get_current_slice_list()[0]
        init = data[1] if self.init_vol else None
        # dimension 0 is a valid scan dimension, so compare against None
        if self.scan_dim is not None:
            self.frame_angles = self.angles[:, sl[self.scan_dim]]
        else:
            self.frame_angles = self.angles

        dim_sl = sl[self.main_dir]

        if self.cor_as_dataset:
            self.frame_cors = self.cor_func(data[-1])
        else:
            frame_nos = \
                self.get_plugin_in_datasets()[0].get_current_frame_idx()
            self.frame_cors = self.cor_func(self.cor[tuple([frame_nos])])

        # for extra padded frames that make up the numbers
        # NOTE(review): a scalar cor is replaced by self.centre, not kept.
        if not self.frame_cors.shape:
            self.frame_cors = np.array([self.centre])

        len_data = len(np.arange(dim_sl.start, dim_sl.stop, dim_sl.step))

        # NOTE(review): this appends only when there are MORE cors than
        # frames in the slice; confirm the sign of the difference.
        missing = [self.centre] * (len(self.frame_cors) - len_data)
        self.frame_cors = np.append(self.frame_cors, missing)

        # fix to remove NaNs in the initialised image (in place)
        if init is not None:
            init[np.isnan(init)] = 0.0
        self.frame_init_data = init

        data[0] = self.fix_sino(self.sino_func(data[0]), self.frame_cors[0])
        return data

    def base_process_frames_after(self, data):
        """ Zero values outside the 'force_zero' [lower, upper] range. """
        lower_range, upper_range = self.range
        if lower_range is not None:
            data[data < lower_range] = 0
        if upper_range is not None:
            data[data > upper_range] = 0
        return data

    def get_padding_algorithms(self):
        """ A list of algorithms that allow the data to be padded. """
        return []

    def pad_sino(self, sino, cor):
        """  Pad the sinogram so the centre of rotation is at the centre. """
        detX = self._get_detX_dim()
        pad = self.get_centre_offset(sino, cor, detX)
        # the low-side padding shifts every centre of rotation
        self.cor_shift = pad[0]
        pad_tuples = [(0, 0)] * (len(sino.shape) - 1)
        pad_tuples.insert(detX, tuple(pad))
        self.__set_pad_amount(max(pad))
        return np.pad(sino, tuple(pad_tuples), mode='edge')

    def _get_detX_dim(self):
        """ Return the input dimension whose axis label contains 'x'. """
        pData = self.get_plugin_in_datasets()[0]
        return pData.get_data_dimension_by_axis_label('x', contains=True)

    def get_centre_offset(self, sino, cor, detX):
        """
        Return [low, high] padding: the centring padding from br_array_pad
        plus half the outer padding, scaled by the width ratio.
        """
        centre_pad = self.br_array_pad(cor, sino.shape[detX])
        sino_width = sino.shape[detX]
        new_width = sino_width + max(centre_pad)
        sino_pad = int(math.ceil(float(sino_width) / new_width * self.sino_pad) // 2)
        return np.array([sino_pad]*2) + centre_pad

    def get_centre_shift(self, sino, cor):
        """ Return the larger side of the padding for ``self.centre``. """
        # NOTE(review): ``cor`` is ignored; self.centre is used instead.
        detX = self._get_detX_dim()
        return max(self.get_centre_offset(sino, self.centre, detX))

    def crop_sino(self, sino, cor):
        """  Crop the sinogram so the centre of rotation is at the centre. """
        detX = self._get_detX_dim()
        # padding [low, high] reversed: remove 'high' pixels from the start
        # and 'low' pixels from the end
        start, stop = self.br_array_pad(cor, sino.shape[detX])[::-1]
        self.cor_shift = -start
        sl = [slice(None)] * len(sino.shape)
        sl[detX] = slice(start, sino.shape[detX] - stop)
        sino = sino[tuple(sl)]
        self.set_mask(sino.shape)
        return sino

    def br_array_pad(self, ctr, nPixels):
        """
        Return np.array([p_low, p_high]): pixels to add at the low or high end
        of an axis of ``nPixels`` so that ``ctr`` lies at its centre.  At most
        one of the two entries is non-zero.
        """
        width = nPixels - 1.0
        # distance from ctr to the first (alen) and last (blen) pixel
        alen = ctr
        blen = width - ctr
        # centre of the pixel indices 0..nPixels-1
        mid = width / 2.0
        shift = int(round(abs(blen - alen)))
        if ctr > mid:
            return np.array([0, shift])
        return np.array([shift, 0])

    def keep_sino(self, sino, cor):
        """ No change to the sinogram """
        return sino

    def get_sino_centre_method(self):
        """
        Choose how to centre the sinogram on the centre of rotation.

        :returns: keep_sino if there is no 'centre_pad' parameter, pad_sino if
            it is True and the algorithm supports padding, else crop_sino.
        :raises Exception: if 'centre_pad' is not exactly True or False.
        """
        if 'centre_pad' not in self.parameters.keys():
            return self.keep_sino
        cpad = self.parameters['centre_pad']
        if not (cpad is True or cpad is False):
            raise Exception('Unknown value for "centre_pad", please choose'
                            ' True or False.')
        return self.pad_sino if cpad and self.padding_alg else self.crop_sino

    def __set_pad_amount(self, pad_amount):
        self.base_pad_amount = pad_amount

    def get_pad_amount(self):
        """ Return the largest side padding applied by pad_sino (or None). """
        return self.base_pad_amount

    def get_fov_fraction(self, sino, cor):
        """ Get the fraction of the original FOV that can be reconstructed due\
        to offset centre """
        detX = self._get_detX_dim()
        original_length = sino.shape[detX]
        shift = self.get_centre_shift(sino, cor)
        return (original_length - shift) / float(original_length)

    def get_reconstruction_alg(self):
        return None

    def get_angles(self):
        """ Get the angles associated with the current sinogram(s).

        :returns: Angles of the current frames.
        :rtype: np.ndarray
        """
        return self.frame_angles

    def get_cors(self):
        """
        Get the centre of rotations associated with the current sinogram(s).

        :returns: Centre of rotation values for the current frames, shifted
            by any padding/cropping applied to the sinogram.
        :rtype: np.ndarray
        """
        return self.frame_cors + self.cor_shift

    def set_mask(self, shape):
        """ Hook called with the cropped sinogram shape; no-op here. """
        pass

    def get_initial_data(self):
        """
        Get the initial data (if it is exists) associated with the current \
        sinogram(s).

        :returns: The section of the initialisation data associated with the \
            current frames.
        :rtype: np.ndarray or None
        """
        return self.frame_init_data

    def get_frame_params(self):
        """ Return [cors, angles, volume shape, initial data] for the frames. """
        return [self.get_cors(), self.get_angles(), self.get_vol_shape(),
                self.get_initial_data()]

    def setup(self):
        """
        Create the output volume and set up patterns for all datasets.

        Input datasets are ordered: sinogram dataset(s), then the optional
        initial volume, then the optional centre-of-rotation dataset (the
        order in which nInput_datasets appends them).
        """
        in_dataset, out_dataset = self.get_datasets()
        # reduce the data as per data_subset parameter
        self.preview_flag = \
            self.set_preview(in_dataset[0], self.parameters['preview'])

        self._set_volume_dimensions(in_dataset[0])
        axis_labels = self._get_axis_labels(in_dataset[0])
        shape = self._get_shape(in_dataset[0])

        # output dataset
        out_dataset[0].create_dataset(axis_labels=axis_labels, shape=shape)
        out_dataset[0].add_volume_patterns(*self._get_volume_dimensions())

        # set information relating to the plugin data
        in_pData, out_pData = self.get_plugin_datasets()

        self.init_vol = 1 if 'init_vol' in self.parameters.keys() and \
            self.parameters['init_vol'] else 0
        self.cor_as_dataset = 1 if isinstance(
            self.parameters['centre_of_rotation'], str) else 0

        n_sino = len(in_dataset) - self.init_vol - self.cor_as_dataset
        for i in range(n_sino):
            in_pData[i].plugin_data_setup('SINOGRAM', self.get_max_frames(),
                                          slice_axis=self.get_slice_axis())
        # index of the first extra (non-sinogram) dataset; defined even when
        # there are no sinogram datasets to set up
        idx = n_sino

        # initial volume dataset
        if self.init_vol:
            in_pData[idx].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
            idx += 1

        # cor dataset
        if self.cor_as_dataset:
            self.cor_as_dataset = True
            in_pData[idx].plugin_data_setup('METADATA', self.get_max_frames())

        # set pattern_name and nframes to process for all datasets
        out_pData[0].plugin_data_setup('VOLUME_XZ', self.get_max_frames())

    def _get_axis_labels(self, in_dataset):
        """
        Get the new axis labels for the output dataset - this is now a volume.

        Parameters
        ----------
        in_dataset : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.

        Returns
        -------
        labels : dict
            The axis labels for the dataset that is output from the plugin.

        """
        volX, volY, volZ = self._get_volume_dimensions()
        labels = [str(volX) + '.voxel_x.voxels', str(volZ) + '.voxel_z.voxels']
        if volY:
            labels.append(str(volY) + '.voxel_y.voxels')
        return {in_dataset: labels}

    def _set_volume_dimensions(self, data):
        """
        Map the input dimensions to the output volume dimensions

        Parameters
        ----------
        in_dataset : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.
        """
        data._finalise_patterns()
        self.volX = data.get_data_dimension_by_axis_label("rotation_angle")
        self.volZ = data.get_data_dimension_by_axis_label("x", contains=True)
        self.volY = data.get_data_dimension_by_axis_label(
            "y", contains=True, exists=True)

    def _get_volume_dimensions(self):
        """ Return the (volX, volY, volZ) input dimension indices. """
        return self.volX, self.volY, self.volZ

    def _get_shape(self, in_dataset):
        """
        Return the output volume shape.

        The rotation-angle dimension is replaced by the x size ('auto' or
        'fixed' vol_shape) or both by the integer vol_shape, then divided
        by the optional 'resolution' parameter.  All sizes stay integers.
        """
        shape = list(in_dataset.get_shape())
        volX, volY, volZ = self._get_volume_dimensions()

        if self.parameters['vol_shape'] in ('auto', 'fixed'):
            shape[volX] = shape[volZ]
        else:
            shape[volX] = self.parameters['vol_shape']
            shape[volZ] = self.parameters['vol_shape']

        if 'resolution' in self.parameters.keys():
            resolution = self.parameters['resolution']
            # true division gives floats in Python 3; dataset sizes must be int
            shape[volX] = int(shape[volX] / resolution)
            shape[volZ] = int(shape[volZ] / resolution)
        return tuple(shape)

    def get_max_frames(self):
        """
        Number of data frames to pass to each instance of process_frames func

        Returns
        -------
        str or int
            "single", "multiple" or integer (only to be used if the number of
            frames MUST be fixed.)
        """
        return 'multiple'

    def get_slice_axis(self):
        """
        Fix the fastest changing slice dimension

        Returns
        -------
        str or None
            str should be the axis_label corresponding to the fastest changing
            dimension

        """
        return None

    def nInput_datasets(self):
        """
        Return the number of input datasets, appending the initial volume and
        the centre-of-rotation dataset names (when used) to 'in_datasets'.
        """
        nIn = 1
        if 'init_vol' in self.parameters.keys() and \
                self.parameters['init_vol']:
            parts = self.parameters['init_vol'].split('.')
            if len(parts) == 3:
                # 'name.<unused>.rep_dim' form: keep the name, store rep_dim
                name, _, self.rep_dim = parts
                self.parameters['init_vol'] = name
            nIn += 1
            self.parameters['in_datasets'].append(self.parameters['init_vol'])
        if isinstance(self.parameters['centre_of_rotation'], str):
            self.parameters['in_datasets'].append(
                self.parameters['centre_of_rotation'])
            nIn += 1
        return nIn

    def nOutput_datasets(self):
        return self.nOut

    def reconstruct_pre_process(self):
        """
        Should be overridden to perform pre-processing in a child class
        """
        pass

    def executive_summary(self):
        """ Return warnings gathered during setup, or "Nothing to Report". """
        if not self.preview_flag:
            return [("WARNING: Ignoring preview parameters as a preview"
                     " has already been applied to the data.")]
        return ["Nothing to Report"]
560