Test Failed
Pull Request — master (#700)
by Nicola
03:21
created

BaseRecon.get_frame_params()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 1     |
| eloc   | 4     |
| nop    | 1     |
| dl     | 0     |
| loc    | 4     |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_recon
17
   :platform: Unix
18
   :synopsis: A base class for all reconstruction methods
19
20
.. moduleauthor:: Mark Basham <[email protected]>
21
22
"""
23
import math
24
import copy
25
import numpy as np
26
# Process-wide NumPy setting: silence divide-by-zero and invalid-value
# floating point warnings (e.g. from log/division of sinogram data).
np.seterr(invalid='ignore', divide='ignore')
27
28
import savu.core.utils as cu
29
from savu.plugins.plugin import Plugin
30
31
# Upper limit applied to a numeric 'outer_pad' parameter (a fraction of the
# detector-x width); larger values are clamped to this.
MAX_OUTER_PAD: float = 2.1
32
33
34
class BaseRecon(Plugin):
    """
    Base class for reconstruction plugins.

    Handles output-volume set-up, centre-of-rotation bookkeeping, optional
    sinogram padding/cropping and the per-frame parameters (centres, angles,
    volume shape, initial data) that concrete reconstructions consume.
    """

    def __init__(self, name='BaseRecon'):
        super(BaseRecon, self).__init__(name)
        self.nOut = 1
        self.nIn = 1
        # Index of the 'scan' dimension; only set when the rotation angles
        # are not 1D (see base_pre_process).
        self.scan_dim = None
        # Set by nInput_datasets when 'init_vol' has the form "a.b.c".
        self.rep_dim = None
        self.br_vol_shape = None
        # Per-call values exposed through the get_* accessors below.
        self.frame_angles = None
        self.frame_cors = None
        self.frame_init_data = None
        self.centre = None
        self.base_pad_amount = None
        self.padding_alg = False
        # Offset added to the frame centres by get_cors (set by pad/crop).
        self.cor_shift = 0
        self.init_vol = False
        self.cor_as_dataset = False

    def base_pre_process(self):
        """
        Cache state shared by every call to process_frames.

        Stores the padding dimension, the output volume shape, the centre of
        rotation, the rotation angles, the outer sinogram pad width, the
        sinogram/centre transform functions, the 'force_zero' range and the
        sinogram centring method.
        """
        in_data, out_data = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        self.pad_dim = \
            in_pData[0].get_data_dimension_by_axis_label('x', contains=True)
        in_meta_data = self.get_in_meta_data()[0]

        self.exp.log(self.name + " End")
        self.br_vol_shape = out_pData[0].get_shape()
        self.set_centre_of_rotation(in_data[0], out_data[0], in_meta_data)

        self.main_dir = in_data[0].get_data_patterns()['SINOGRAM']['main_dir']
        self.angles = in_meta_data.get('rotation_angle')
        # Multi-dimensional angle arrays are indexed along the 'scan'
        # dimension in base_process_frames_before.
        if len(self.angles.shape) != 1:
            self.scan_dim = in_data[0].get_data_dimension_by_axis_label('scan')
        self.slice_dirs = out_data[0].get_slice_dimensions()

        shape = in_pData[0].get_shape()
        factor = self.__get_outer_pad()
        # Pixels added to each side of the detector-x dimension.
        self.sino_pad = int(math.ceil(factor * shape[self.pad_dim]))

        self.sino_func, self.cor_func = self.set_function(shape)

        self.range = self.parameters['force_zero']
        self.fix_sino = self.get_sino_centre_method()

    def __get_outer_pad(self):
        """
        Return the outer padding factor (fraction of the detector-x width).

        A boolean 'outer_pad' gives sqrt(2) - 1 (True) or 0 (False); a
        numeric value is used directly, clamped to MAX_OUTER_PAD.
        """
        pad = self.parameters['outer_pad'] if 'outer_pad' in self.parameters \
            else False
        # length of diagonal of square is side*sqrt(2)
        factor = math.sqrt(2) - 1
        if isinstance(pad, bool):
            return factor if pad is True else 0

        factor = float(pad)
        if factor > MAX_OUTER_PAD:
            factor = MAX_OUTER_PAD
            msg = 'Maximum outer_pad value is 2.1, using this instead'
            cu.user_message(msg)
        return factor

    def get_vol_shape(self):
        """ Return the output volume shape cached by base_pre_process. """
        return self.br_vol_shape

    def set_centre_of_rotation(self, inData, outData, mData):
        """
        Set ``self.cor`` (one value per frame) and ``self.centre``.

        Sources, in order: a dataset (nothing is done here), the
        'centre_of_rotation' entry of the input metadata, a dict parameter
        (linear fit against detector_y positions), or a single value, where
        0 means half the output volume's size along detector-x. The result
        is copied into the output metadata.
        """
        # if cor has been passed as a dataset then do nothing
        if isinstance(self.parameters['centre_of_rotation'], str):
            return
        if 'centre_of_rotation' in mData.get_dictionary():
            cor = self.__set_cor_from_meta_data(mData, inData)
        else:
            val = self.parameters['centre_of_rotation']
            if isinstance(val, dict):
                cor = self.__polyfit_cor(val, inData)
            else:
                sdirs = inData.get_slice_dimensions()
                cor = np.ones(np.prod([inData.get_shape()[i] for i in sdirs]))
                # if centre of rotation has not been set then fix it in the
                # centre
                val = val if val != 0 else \
                    (self.get_vol_shape()[self._get_detX_dim()]) / 2.0
                cor *= val
                # mData.set('centre_of_rotation', cor) see Github ticket
        self.cor = cor
        outData.meta_data.set("centre_of_rotation", copy.deepcopy(self.cor))
        self.centre = self.cor[0]

    def populate_metadata_to_output(self, inData, outData, mData):
        """
        Copy the rotation angles and the detector-x length into the output
        (reconstruction) metadata. Also caches the angles in ``self.angles``.
        """
        self.angles = mData.get('rotation_angle')
        outData.meta_data.set("rotation_angle", copy.deepcopy(self.angles))

        xDim = inData.get_data_dimension_by_axis_label('x', contains=True)
        det_length = inData.get_shape()[xDim]
        outData.meta_data.set("detector_x_length", copy.deepcopy(det_length))

    def __set_cor_from_meta_data(self, mData, inData):
        """
        Return the metadata centre of rotation, tiled to cover every frame
        when it is shorter than the total number of frames.
        """
        cor = mData.get('centre_of_rotation')
        sdirs = inData.get_slice_dimensions()
        total_frames = np.prod([inData.get_shape()[i] for i in sdirs])
        if total_frames > len(cor):
            cor = np.tile(cor, total_frames // len(cor))
        return cor

    def __polyfit_cor(self, cor_dict, inData):
        """
        Fit a straight line through ``{row: cor}`` pairs and evaluate it at
        every detector_y position (metadata values if present, otherwise
        0..n-1).
        """
        if 'detector_y' in inData.meta_data.get_dictionary():
            y = inData.meta_data.get('detector_y')
        else:
            yDim = inData.get_data_dimension_by_axis_label('detector_y')
            y = np.arange(inData.get_shape()[yDim])

        rows = [int(key) for key in cor_dict]
        z = np.polyfit(rows, list(cor_dict.values()), 1)
        return np.poly1d(z)(y)

    def set_function(self, pad_shape):
        """
        Return ``(sino_func, cor_func)`` for the current parameters.

        With 'centre_pad' set, the sinogram function pads the detector-x
        dimension by ``self.sino_pad`` and the centre function shifts centres
        by the same amount; otherwise centres are returned unchanged.
        """
        centre_pad = self.parameters['centre_pad'] if 'centre_pad' in \
            self.parameters else False
        use_log = bool(self.parameters['log'])
        if centre_pad:
            def cor_func(cor): return cor + self.sino_pad
            sino_func = self.__make_lambda(log=use_log, pad=pad_shape)
        else:
            def cor_func(cor): return cor
            sino_func = self.__make_lambda(log=use_log)
        return sino_func, cor_func

    def __make_lambda(self, log=True, pad=False):
        """
        Build the per-frame sinogram transform ``lambda sino: ...``.

        :param log: if true use the 'log_func' parameter expression,
            otherwise ``np.nan_to_num(sino)``.
        :param pad: a shape tuple; when truthy, every ``sino`` in the
            expression is wrapped in an edge-mode ``np.pad``.
        """
        log_func = 'np.nan_to_num(sino)' if not log else self.parameters['log_func']
        if pad:
            pad_tuples, mode = self.__get_pad_values(pad)
            log_func = log_func.replace(
                    'sino', 'np.pad(sino, %s, "%s")' % (pad_tuples, mode))
        # SECURITY NOTE(review): 'log_func' comes from the plugin parameters
        # and is passed to eval; only run process lists from trusted sources.
        return eval("lambda sino: " + log_func)

    def __get_pad_values(self, pad_shape):
        """
        Return the ``np.pad`` width tuple (``self.sino_pad`` on both sides of
        ``self.pad_dim``, zero elsewhere) and the pad mode.
        """
        mode = 'edge'
        pad_tuples = [(0, 0)] * (len(pad_shape) - 1)
        pad_tuples.insert(self.pad_dim, (self.sino_pad, self.sino_pad))
        pad_tuples = tuple(pad_tuples)
        return pad_tuples, mode

    def base_process_frames_before(self, data):
        """
        Prepare the current frames for reconstruction.

        Records the angles, centres of rotation and optional initial volume
        for the current frames, then transforms and centres ``data[0]``.

        :param data: input frames for this call; ``data[0]`` is the sinogram
            data, ``data[1]`` the initial volume when ``init_vol`` is set, and
            the last entry the centre data when it was passed as a dataset.
        :returns: ``data`` with ``data[0]`` replaced.
        """
        sl = self.get_current_slice_list()[0]
        init = data[1] if self.init_vol else None
        # scan_dim may legitimately be dimension 0, so test against None.
        if self.scan_dim is not None:
            angles = self.angles[:, sl[self.scan_dim]]
        else:
            angles = self.angles
        self.frame_angles = angles

        dim_sl = sl[self.main_dir]

        if self.cor_as_dataset:
            self.frame_cors = self.cor_func(data[-1])
        else:
            frame_nos = \
                self.get_plugin_in_datasets()[0].get_current_frame_idx()
            self.frame_cors = self.cor_func(self.cor[tuple([frame_nos])])

        # A 0-d (scalar) result is replaced by the default centre.
        if not self.frame_cors.shape:
            self.frame_cors = np.array([self.centre])

        # Extra padded frames that make up the frame count have no centre of
        # their own: append the default centre for each frame in the current
        # slice that lacks one.
        len_data = len(np.arange(dim_sl.start, dim_sl.stop, dim_sl.step))
        missing = [self.centre] * (len_data - len(self.frame_cors))
        self.frame_cors = np.append(self.frame_cors, missing)

        # Replace NaNs in the initialised image (in place).
        if init is not None:
            init[np.isnan(init)] = 0.0
        self.frame_init_data = init

        data[0] = self.fix_sino(self.sino_func(data[0]), self.frame_cors[0])
        return data

    def base_process_frames_after(self, data):
        """
        Zero values outside the 'force_zero' ``(lower, upper)`` range; a
        bound of None is ignored. Modifies ``data`` in place and returns it.
        """
        lower_range, upper_range = self.range
        if lower_range is not None:
            data[data < lower_range] = 0
        if upper_range is not None:
            data[data > upper_range] = 0
        return data

    def pad_sino(self, sino, cor):
        """
        Pad the sinogram so the centre of rotation is at the centre.

        Records the low-side pad in ``self.cor_shift`` and the larger pad via
        get_pad_amount.
        """
        detX = self._get_detX_dim()
        pad = self.get_centre_offset(sino, cor, detX)
        self.cor_shift = pad[0]
        pad_tuples = [(0, 0)] * (len(sino.shape) - 1)
        pad_tuples.insert(detX, tuple(pad))
        self.__set_pad_amount(max(pad))
        return np.pad(sino, tuple(pad_tuples), mode='edge')

    def _get_detX_dim(self):
        """ Return the index of the detector-x ('x') input dimension. """
        pData = self.get_plugin_in_datasets()[0]
        return pData.get_data_dimension_by_axis_label('x', contains=True)

    def get_centre_offset(self, sino, cor, detX):
        """
        Return the ``[low, high]`` padding along ``detX``: the centring pad
        from br_array_pad plus half of ``self.sino_pad`` scaled by the ratio
        of the original to the centre-padded width.
        """
        centre_pad = self.br_array_pad(cor, sino.shape[detX])
        sino_width = sino.shape[detX]
        new_width = sino_width + max(centre_pad)
        sino_pad = int(math.ceil(float(sino_width) / new_width * self.sino_pad) // 2)
        pad = np.array([sino_pad]*2) + centre_pad
        return pad

    def get_centre_shift(self, sino, cor):
        """ Return the larger side of the centring offset. """
        # NOTE(review): ``cor`` is unused; the offset is computed from
        # self.centre — confirm this is intended.
        detX = self._get_detX_dim()
        return max(self.get_centre_offset(sino, self.centre, detX))

    def crop_sino(self, sino, cor):
        """  Crop the sinogram so the centre of rotation is at the centre. """
        detX = self._get_detX_dim()
        start, stop = self.br_array_pad(cor, sino.shape[detX])[::-1]
        self.cor_shift = -start
        sl = [slice(None)] * len(sino.shape)
        sl[detX] = slice(start, sino.shape[detX] - stop)
        sino = sino[tuple(sl)]
        self.set_mask(sino.shape)
        return sino

    def br_array_pad(self, ctr, nPixels):
        """
        Return ``[low, high]`` integer padding that centres ``ctr``.

        The amount is the rounded difference between the distances from
        ``ctr`` to the first and last pixel; it goes on the high side when
        ``ctr > mid`` and on the low side otherwise.
        """
        width = nPixels - 1.0
        alen = ctr
        blen = width - ctr
        # NOTE(review): width is already nPixels - 1, so mid is
        # (nPixels - 2) / 2 rather than (nPixels - 1) / 2 — confirm.
        mid = (width - 1.0) / 2.0
        shift = round(abs(blen - alen))
        p_low = 0 if (ctr > mid) else shift
        p_high = shift if (ctr > mid) else 0
        return np.array([int(p_low), int(p_high)])

    def keep_sino(self, sino, cor):
        """ No change to the sinogram """
        return sino

    def get_sino_centre_method(self):
        """
        Return the method used to centre each sinogram: keep_sino, or
        pad_sino when the 'centre_pad' parameter is True.

        :raises Exception: if 'centre_pad' is present and neither True nor
            False.
        """
        centre_pad = self.keep_sino
        if 'centre_pad' in self.parameters.keys() and \
                self.parameters['centre_pad'] is not False:
            cpad = self.parameters['centre_pad']
            if not (cpad is True or cpad is False):
                raise Exception('Unknown value for "centre_pad", please choose'
                                ' True or False.')
            # NOTE(review): cpad can only be True here, so crop_sino is
            # never selected — confirm whether that is intended.
            centre_pad = self.pad_sino if cpad else self.crop_sino
        return centre_pad

    def __set_pad_amount(self, pad_amount):
        self.base_pad_amount = pad_amount

    def get_pad_amount(self):
        """ Return the pad amount recorded by pad_sino (None if unset). """
        return self.base_pad_amount

    def get_fov_fraction(self, sino, cor):
        """ Get the fraction of the original FOV that can be reconstructed due\
        to offset centre """
        original_length = sino.shape[self._get_detX_dim()]
        shift = self.get_centre_shift(sino, cor)
        return (original_length - shift) / float(original_length)

    def get_reconstruction_alg(self):
        """ Hook for subclasses; returns None here. """
        return None

    def get_angles(self):
        """ Get the angles associated with the current sinogram(s).

        :returns: Angles of the current frames.
        :rtype: np.ndarray
        """
        return self.frame_angles

    def get_cors(self):
        """
        Get the centre of rotations associated with the current sinogram(s).

        :returns: Centre of rotation values for the current frames, shifted
            by any padding/cropping offset.
        :rtype: np.ndarray
        """
        return self.frame_cors + self.cor_shift

    def set_mask(self, shape):
        """ Hook called by crop_sino with the cropped shape; no-op here. """
        pass

    def get_initial_data(self):
        """
        Get the initial data (if it is exists) associated with the current \
        sinogram(s).

        :returns: The section of the initialisation data associated with the \
            current frames.
        :rtype: np.ndarray or None
        """
        return self.frame_init_data

    def get_frame_params(self):
        """
        Return ``[cors, angles, vol_shape, initial_data]`` for the current
        frames.
        """
        return [self.get_cors(), self.get_angles(), self.get_vol_shape(),
                self.get_initial_data()]

    def setup(self):
        """
        Create the output volume dataset and set up the data patterns of the
        sinogram, optional initial-volume and optional centre-of-rotation
        input datasets, then copy metadata to the output.
        """
        in_dataset, out_dataset = self.get_datasets()
        # reduce the data as per data_subset parameter
        self.preview_flag = \
            self.set_preview(in_dataset[0], self.parameters['preview'])

        self._set_volume_dimensions(in_dataset[0])
        axis_labels = self._get_axis_labels(in_dataset[0])
        shape = self._get_shape(in_dataset[0])

        # output dataset
        out_dataset[0].create_dataset(axis_labels=axis_labels, shape=shape)
        out_dataset[0].add_volume_patterns(*self._get_volume_dimensions())

        # set information relating to the plugin data
        in_pData, out_pData = self.get_plugin_datasets()

        self.init_vol = 1 if 'init_vol' in self.parameters.keys() and \
            self.parameters['init_vol'] else 0
        self.cor_as_dataset = 1 if isinstance(
            self.parameters['centre_of_rotation'], str) else 0

        # Input order (see nInput_datasets): sinogram dataset(s), then the
        # initial volume, then the centre-of-rotation dataset.
        n_sino = len(in_dataset) - self.init_vol - self.cor_as_dataset
        for i in range(n_sino):
            in_pData[i].plugin_data_setup('SINOGRAM', self.get_max_frames(),
                                          slice_axis=self.get_slice_axis())
        # Index of the next auxiliary dataset; defined even when the loop
        # above does not run.
        idx = n_sino

        # initial volume dataset
        if self.init_vol:
            in_pData[idx].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
            idx += 1

        # cor dataset
        if self.cor_as_dataset:
            self.cor_as_dataset = True
            in_pData[idx].plugin_data_setup('METADATA', self.get_max_frames())

        # set pattern_name and nframes to process for all datasets
        out_pData[0].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
        # metadata output populator
        in_meta_data = self.get_in_meta_data()[0]
        self.populate_metadata_to_output(in_dataset[0], out_dataset[0], in_meta_data)

    def _get_axis_labels(self, in_dataset):
        """
        Get the new axis labels for the output dataset - this is now a volume.

        Parameters
        ----------
        in_dataset : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.

        Returns
        -------
        labels : dict
            The axis labels for the dataset that is output from the plugin.

        """
        volX, volY, volZ = self._get_volume_dimensions()
        labels = [str(volX) + '.voxel_x.voxels', str(volZ) + '.voxel_z.voxels']
        if volY:
            labels.append(str(volY) + '.voxel_y.voxels')
        return {in_dataset: labels}

    def _set_volume_dimensions(self, data):
        """
        Map the input dimensions to the output volume dimensions

        Parameters
        ----------
        data : :class:`savu.data.data_structures.data.Data`
            The input dataset to the plugin.
        """
        data._finalise_patterns()
        self.volX = data.get_data_dimension_by_axis_label("rotation_angle")
        self.volZ = data.get_data_dimension_by_axis_label("x", contains=True)
        self.volY = data.get_data_dimension_by_axis_label(
            "y", contains=True, exists=True)

    def _get_volume_dimensions(self):
        """ Return the ``(volX, volY, volZ)`` dimension indices. """
        return self.volX, self.volY, self.volZ

    def _get_shape(self, in_dataset):
        """
        Return the output volume shape: the input shape with the angle and
        detector-x dimensions replaced by the 'vol_shape' size (or the
        detector-x length for 'auto'/'fixed'), divided by 'resolution' if
        that parameter is present.
        """
        shape = list(in_dataset.get_shape())
        volX, volY, volZ = self._get_volume_dimensions()

        if self.parameters['vol_shape'] in ('auto', 'fixed'):
            shape[volX] = shape[volZ]
        else:
            shape[volX] = self.parameters['vol_shape']
            shape[volZ] = self.parameters['vol_shape']

        if 'resolution' in self.parameters.keys():
            shape[volX] = int(shape[volX] // self.parameters['resolution'])
            shape[volZ] = int(shape[volZ] // self.parameters['resolution'])
        return tuple(shape)

    def get_max_frames(self):
        """
        Number of data frames to pass to each instance of process_frames func

        Returns
        -------
        str or int
            "single", "multiple" or integer (only to be used if the number of
                                             frames MUST be fixed.)
        """
        return 'multiple'

    def get_slice_axis(self):
        """
        Fix the fastest changing slice dimension

        Returns
        -------
        str or None
            str should be the axis_label corresponding to the fastest changing
            dimension

        """
        return None

    def nInput_datasets(self):
        """
        Return the number of input datasets.

        Appends the optional initial-volume and centre-of-rotation dataset
        names to the 'in_datasets' parameter. An 'init_vol' of the form
        "name.x.rep" keeps ``name`` as the dataset name and stores the last
        part in ``self.rep_dim``.
        """
        nIn = 1
        if 'init_vol' in self.parameters.keys() and \
                self.parameters['init_vol']:
            parts = self.parameters['init_vol'].split('.')
            if len(parts) == 3:
                # unpack the split parts, not the raw string
                name, _, self.rep_dim = parts
                self.parameters['init_vol'] = name
            nIn += 1
            self.parameters['in_datasets'].append(self.parameters['init_vol'])
        if isinstance(self.parameters['centre_of_rotation'], str):
            self.parameters['in_datasets'].append(
                self.parameters['centre_of_rotation'])
            nIn += 1
        return nIn

    def nOutput_datasets(self):
        """ Return the number of output datasets. """
        return self.nOut

    def reconstruct_pre_process(self):
        """
        Should be overridden to perform pre-processing in a child class
        """
        pass

    def executive_summary(self):
        """ Return summary messages, or ["Nothing to Report"]. """
        summary = []
        if not self.preview_flag:
            summary.append(("WARNING: Ignoring preview parameters as a preview"
                            " has already been applied to the data."))
        if summary:
            return summary
        return ["Nothing to Report"]