Test Failed
Pull Request — master (#875)
by Daniil
04:09
created

  A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Copyright 2019 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: tomobar_recon_3D
17
   :platform: Unix
18
   :synopsis: A wrapper around TOmographic MOdel-BAsed Reconstruction (ToMoBAR) software \
19
   for advanced iterative image reconstruction using _3D_ capabilities of regularisation. \
20
   This plugin will divide 3D projection data into overlapping subsets using padding.
21
22
.. moduleauthor:: Daniil Kazantsev <[email protected]>
23
"""
24
25
from savu.plugins.reconstructions.base_recon import BaseRecon
26
from savu.plugins.driver.gpu_plugin import GpuPlugin
27
28
import numpy as np
29
from tomobar.methodsIR import RecToolsIR
30
from savu.plugins.utils import register_plugin
31
32
33
@register_plugin
class TomobarRecon3d(BaseRecon, GpuPlugin):
    # ToMoBAR FISTA reconstruction applied to stacks of detector_y slices so
    # that 3D regularisers can be used; every stack is padded along detector_y
    # (see set_filter_padding) so neighbouring slices overlap.

    def __init__(self):
        super(TomobarRecon3d, self).__init__("TomobarRecon3d")

    def set_filter_padding(self, in_pData, out_pData):
        """Pad the first input and output plugin datasets along detector_y.

        The pad width comes from the 'padding' parameter and is stored in
        ``self.pad`` (pre_process uses it to size the vertical detector).
        The same padding spec, in 'edge' pad mode, is applied to both datasets.
        """
        self.pad = self.parameters['padding']
        in_data = self.get_in_datasets()[0]
        det_y = in_data.get_data_dimension_by_axis_label('detector_y')
        pad_det_y = '%s.%s' % (det_y, self.pad)
        pad_dict = {'pad_directions': [pad_det_y], 'pad_mode': 'edge'}
        in_pData[0].padding = pad_dict
        out_pData[0].padding = pad_dict

    def setup(self):
        """Choose how many detector_y slices form one frame, read optional
        'projection_shifts' metadata, then run the BaseRecon setup.
        """
        in_dataset = self.get_in_datasets()[0]
        # Share detector_y evenly between the GPU processes; max(..., 1)
        # avoids a ZeroDivisionError when no 'GPU' process is listed.
        procs = self.exp.meta_data.get("processes")
        n_gpu_procs = max(len([p for p in procs if 'GPU' in p]), 1)
        det_y_dim = in_dataset.get_data_dimension_by_axis_label('detector_y')
        n_slices = int(np.ceil(in_dataset.get_shape()[det_y_dim] /
                               float(n_gpu_procs)))
        # Never request more slices than the GPU-memory estimate allows, but
        # always at least one: an estimate below one slice would otherwise
        # set the max frame count to 0.
        n_slices = min(n_slices, self._slices_fitting_gpu(in_dataset))
        self._set_max_frames(max(n_slices, 1))
        # Offset subtracted from the centre of rotation in process_frames().
        # Default to no shift so the attribute always exists, even when the
        # experiment metadata does not provide 'projection_shifts'.
        meta_dict = self.exp.meta_data.dict
        if 'projection_shifts' in meta_dict:
            self.projection_shifts = meta_dict['projection_shifts']
        else:
            self.projection_shifts = 0.0
        super(TomobarRecon3d, self).setup()

    def _slices_fitting_gpu(self, in_dataset):
        """Estimate how many detector_y slices fit in free GPU memory.

        One slice is detector_x * rotation_angle float32 values (4 bytes),
        rounded up to whole MB, then scaled by an empirical factor for the
        chosen regulariser so its 3D working buffers do not trigger CUDA
        memory errors.

        :returns: int number of slices (may be 0 if not even one fits).
        """
        # Empirical per-regulariser multipliers; matched by substring against
        # 'regularisation_method', and every matching factor is applied.
        reg_memory_factors = (
            ('ROF_TV', 4.5),
            ('FGP_TV', 8.5),
            ('SB_TV', 6.5),
            ('PD_TV', 6.5),
            ('LLT_ROF', 8.5),
            ('TGV', 11.5),
            ('NDF', 3.5),
            ('Diff4th', 3.5),
            ('NLTV', 4.5),
        )
        shape = in_dataset.get_shape()
        det_x = shape[in_dataset.get_data_dimension_by_axis_label('detector_x')]
        n_angles = shape[
            in_dataset.get_data_dimension_by_axis_label('rotation_angle')]
        # Free memory of the first GPU if there are several; assumed to be
        # in MB (matches the original variable naming) -- TODO confirm.
        gpu_available_mb = self.get_gpu_memory()[0]
        slice_mb = int(np.ceil(det_x * n_angles * 4 / float(1024 ** 2)))
        method = self.parameters['regularisation_method']
        for name, factor in reg_memory_factors:
            if name in method:
                slice_mb *= factor
        return int(gpu_available_mb / slice_mb)

    def pre_process(self):
        """Record detector geometry and build the ToMoBAR input dictionaries
        (data, algorithm and regularisation) from the plugin parameters.
        """
        in_pData = self.get_plugin_in_datasets()[0]
        self.det_dimX_ind = in_pData.get_data_dimension_by_axis_label('detector_x')
        self.det_dimY_ind = in_pData.get_data_dimension_by_axis_label('detector_y')
        # ! padding the vertical detector ! -- each frame carries self.pad
        # extra rows on both sides of detector_y.
        self.Vert_det = in_pData.get_shape()[self.det_dimY_ind] + 2 * self.pad

        # extract given parameters into dictionaries suitable for ToMoBAR input
        self._data_ = {'OS_number': self.parameters['algorithm_ordersubsets'],
                       'huber_threshold': self.parameters['data_Huber_thresh'],
                       'ringGH_lambda': self.parameters['data_full_ring_GH'],
                       'ringGH_accelerate': self.parameters['data_full_ring_accelerator_GH']}

        self._algorithm_ = {'iterations': self.parameters['algorithm_iterations'],
                            'nonnegativity': self.parameters['algorithm_nonnegativity'],
                            'mask_diameter': self.parameters['algorithm_mask'],
                            'verbose': self.parameters['algorithm_verbose']}

        # NOTE: 'edge_threhsold' is spelled exactly as ToMoBAR expects it.
        self._regularisation_ = {'method': self.parameters['regularisation_method'],
                                 'regul_param': self.parameters['regularisation_parameter'],
                                 'iterations': self.parameters['regularisation_iterations'],
                                 'device_regulariser': self.parameters['regularisation_device'],
                                 'edge_threhsold': self.parameters['regularisation_edge_thresh'],
                                 'time_marching_step': self.parameters['regularisation_timestep'],
                                 'regul_param2': self.parameters['regularisation_parameter2'],
                                 'PD_LipschitzConstant': self.parameters['regularisation_PD_lip'],
                                 'NDF_penalty': self.parameters['regularisation_NDF_penalty'],
                                 'methodTV': self.parameters['regularisation_methodTV']}

    def process_frames(self, data):
        """Reconstruct one padded stack of projections with ToMoBAR FISTA.

        :param data: list of frame arrays; data[0] is the (normalised)
            projection stack, data[1] the raw data, read only for the PWLS
            and SWLS data-fidelity models.
        :returns: the reconstructed volume with axes 0 and 1 swapped back
            to the input ordering.
        """
        cor, angles, self.vol_shape, _init = self.get_frame_params()
        self.anglesRAD = np.deg2rad(angles.astype(np.float32))
        projdata3D = data[0].astype(np.float32)  # copy: safe to edit in place
        self.Horiz_det = projdata3D.shape[self.det_dimX_ind]
        cor_astra = 0.5 * self.Horiz_det - np.mean(cor)
        # zero out absurdly large values (presumably inf/overflow artefacts)
        projdata3D[projdata3D > 10 ** 15] = 0.0
        projdata3D = np.swapaxes(projdata3D, 0, 1)
        self._data_.update({'projection_norm_data': projdata3D})

        # if one selects PWLS or SWLS models then raw data is also required (2 inputs)
        if self.parameters['data_fidelity'] in ('PWLS', 'SWLS'):
            rawdata3D = data[1].astype(np.float32)
            rawdata3D[rawdata3D > 10 ** 15] = 0.0
            raw_max = np.max(rawdata3D)
            rawdata3D = np.swapaxes(rawdata3D, 0, 1)
            # Normalise by the maximum; skip when it is 0 so all-zero raw
            # data stays zero instead of becoming NaN (0/0).
            if raw_max != 0:
                rawdata3D = rawdata3D / raw_max
            self._data_.update({'projection_raw_data': rawdata3D})
            self._data_.update({'beta_SWLS': self.parameters['data_beta_SWLS'] * np.ones(self.Horiz_det)})

        # set parameters and initiate a TomoBar class object
        self.Rectools = RecToolsIR(DetectorsDimH=self.Horiz_det,  # detector dimension (horizontal)
                                   DetectorsDimV=self.Vert_det,  # detector dimension (vertical, padded)
                                   CenterRotOffset=(cor_astra.item() - 0.5) - self.projection_shifts,  # centre of rotation offset
                                   AnglesVec=-self.anglesRAD,  # the vector of angles in radians
                                   ObjSize=self.vol_shape[0],  # reconstructed object size
                                   datafidelity=self.parameters['data_fidelity'],  # LS, PWLS or SWLS
                                   device_projector='gpu')

        # Run FISTA reconstruction algorithm here
        recon = self.Rectools.FISTA(self._data_, self._algorithm_, self._regularisation_)
        return np.swapaxes(recon, 0, 1)

    def nInput_datasets(self):
        """Number of input datasets: as many as configured, at least one."""
        return max(len(self.parameters['in_datasets']), 1)

    def nOutput_datasets(self):
        """A single reconstructed volume is produced."""
        return 1

    def _set_max_frames(self, frames):
        """Store the number of detector_y slices to process per frame."""
        self._max_frames = frames