Test Failed
Push — master ( e9a2f9...ebc117 )
by Daniil
05:01 queued 14s
created

TomobarRecon3d.pre_process()   A

Complexity

Conditions 1

Size

Total Lines 28
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 23
nop 1
dl 0
loc 28
rs 9.328
c 0
b 0
f 0
1
# Copyright 2019 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: tomobar_recon_3D
17
   :platform: Unix
18
   :synopsis: A wrapper around TOmographic MOdel-BAsed Reconstruction (ToMoBAR) software \
19
   for direct and advanced iterative image reconstruction using _3D_ capabilities of regularisation. \
20
   This plugin will divide 3D projection data into overlapping subsets using padding.
21
22
.. moduleauthor:: Daniil Kazantsev <[email protected]>
23
"""
24
25
from savu.plugins.reconstructions.base_recon import BaseRecon
26
from savu.plugins.driver.gpu_plugin import GpuPlugin
27
28
import numpy as np
29
from tomobar.methodsIR import RecToolsIR
30
from tomobar.methodsDIR import RecToolsDIR
31
from savu.plugins.utils import register_plugin
32
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
33
    check_if_end_plugin_in_iterate_group, setup_extra_plugin_data_padding
34
35
36
@register_plugin
class TomobarRecon3d(BaseRecon, GpuPlugin):
    """ToMoBAR 3D reconstruction (FBP3D, FISTA3D or CGLS3D) on the GPU.

    Projection data are processed in chunks of ``detector_y`` slices.
    Each chunk is padded along ``detector_y`` (see ``set_filter_padding``)
    so that the 3D reconstruction/regularisation sees neighbouring slices.
    """

    # Per-slice GPU-memory multipliers used by ``setup`` for FISTA3D.
    # Each name is matched as a substring of 'regularisation_method', and
    # every matching entry is applied (cumulatively), exactly as the
    # original chain of independent ``if`` tests did.
    _REGULARISER_MEMORY_FACTORS = (
        ('ROF_TV', 8),
        ('FGP_TV', 12),
        ('SB_TV', 10),
        ('PD_TV', 8),
        ('LLT_ROF', 12),
        ('TGV', 15),
        ('NDF', 5),
        ('Diff4th', 5),
        ('NLTV', 8),
    )

    # Values accepted for the 'reconstruction_method' parameter.
    _RECON_METHODS = ('FISTA3D', 'FBP3D', 'CGLS3D')

    def __init__(self):
        super(TomobarRecon3d, self).__init__("TomobarRecon3d")
        # padded vertical detector size; set in pre_process()
        self.Vert_det = None
        # number of padding slices per side along detector_y;
        # set in set_filter_padding()
        self.pad = None

    @setup_extra_plugin_data_padding
    def set_filter_padding(self, in_pData, out_pData):
        """Pad every processing chunk along ``detector_y`` in 'edge' mode.

        The padding width is ``self.parameters['padding']``. The same
        padding is applied to the first input and the first output. It is
        also applied to the second input when one exists (the raw data used
        by the PWLS/SWLS fidelities).

        :param in_pData: plugin input datasets.
        :param out_pData: plugin output datasets.
        """
        self.pad = self.parameters['padding']
        in_data = self.get_in_datasets()[0]
        det_y = in_data.get_data_dimension_by_axis_label('detector_y')
        pad_det_y = '%s.%s' % (det_y, self.pad)
        pad_dict = {'pad_directions': [pad_det_y], 'pad_mode': 'edge'}
        in_pData[0].padding = pad_dict
        out_pData[0].padding = pad_dict
        if len(self.get_in_datasets()) > 1:
            in_pData[1].padding = pad_dict

    @enable_iterative_loop
    def setup(self):
        """Choose the number of slices per chunk and read projection shifts.

        The chunk size starts as an even split of ``detector_y`` across the
        GPU processes. It is then reduced if the estimated per-slice memory
        footprint (plus padding) does not fit into this process's share of
        the first GPU's memory.
        """
        in_dataset = self.get_in_datasets()[0]
        shape = in_dataset.get_shape()

        # only GPU processes share the detector_y slices
        procs = sum(1 for proc in self.exp.meta_data.get("processes")
                    if 'GPU' in proc)
        dim_y = in_dataset.get_data_dimension_by_axis_label('detector_y')
        nSlices = int(np.ceil(shape[dim_y] / float(procs)))

        # free memory of the first GPU device (if many), split between
        # the GPU processes
        gpu_available_mb = self.get_gpu_memory()[0] / procs
        det_x_dim = shape[
            in_dataset.get_data_dimension_by_axis_label('detector_x')]
        # one det_x * det_x float32 (4-byte) slice, in megabytes
        slice_size_mbbytes = int(
            np.ceil((det_x_dim * det_x_dim * 4) / float(1024 ** 2)))

        if self.parameters['reconstruction_method'] == 'FISTA3D':
            # 3D regularisers need extra GPU memory per slice; inflate the
            # estimate to avoid CUDA errors
            regulariser = self.parameters['regularisation_method']
            for name, factor in self._REGULARISER_MEMORY_FACTORS:
                if name in regulariser:
                    slice_size_mbbytes *= factor

        # leave room for the padding slices added on both sides of a chunk
        slices_fit_total = (int(gpu_available_mb / slice_size_mbbytes)
                            - 2 * self.parameters['padding'])
        # Clamp to at least one frame: a non-positive estimate (large
        # padding / little free memory) must not be passed on as a frame
        # count.
        nSlices = max(1, min(nSlices, slices_fit_total))
        self._set_max_frames(nSlices)

        # optional experimental metadata: projection shifts
        meta = self.exp.meta_data.dict
        if 'projection_shifts' in meta:
            self.projection_shifts = meta['projection_shifts']
        super(TomobarRecon3d, self).setup()

    def pre_process(self):
        """Cache detector geometry and build the ToMoBAR parameter dicts.

        This sets ``det_dimX_ind``/``det_dimY_ind`` (axis indices in the
        plugin input) and ``Vert_det`` (the chunk height plus ``self.pad``
        slices on each side). It also builds the ``_data_``,
        ``_algorithm_`` and ``_regularisation_`` dictionaries passed to
        ToMoBAR.
        """
        in_pData = self.get_plugin_in_datasets()[0]
        self.det_dimX_ind = in_pData.get_data_dimension_by_axis_label(
            'detector_x')
        self.det_dimY_ind = in_pData.get_data_dimension_by_axis_label(
            'detector_y')
        # vertical detector size including padding on both sides
        self.Vert_det = in_pData.get_shape()[self.det_dimY_ind] + 2 * self.pad

        # extract given parameters into dictionaries suitable for ToMoBAR
        self._data_ = {
            'OS_number': self.parameters['algorithm_ordersubsets'],
            'huber_threshold': self.parameters['data_Huber_thresh'],
            'ringGH_lambda': self.parameters['data_full_ring_GH'],
            'ringGH_accelerate':
                self.parameters['data_full_ring_accelerator_GH']}

        self._algorithm_ = {
            'iterations': self.parameters['algorithm_iterations'],
            'nonnegativity': self.parameters['algorithm_nonnegativity'],
            'mask_diameter': self.parameters['algorithm_mask'],
            'verbose': self.parameters['algorithm_verbose']}

        # NOTE(review): the 'edge_threhsold' key spelling is kept as-is;
        # presumably it must match the key ToMoBAR reads - confirm against
        # the installed tomobar before "correcting" it.
        self._regularisation_ = {
            'method': self.parameters['regularisation_method'],
            'regul_param': self.parameters['regularisation_parameter'],
            'iterations': self.parameters['regularisation_iterations'],
            'device_regulariser': self.parameters['regularisation_device'],
            'edge_threhsold': self.parameters['regularisation_edge_thresh'],
            'time_marching_step': self.parameters['regularisation_timestep'],
            'regul_param2': self.parameters['regularisation_parameter2'],
            'PD_LipschitzConstant': self.parameters['regularisation_PD_lip'],
            'NDF_penalty': self.parameters['regularisation_NDF_penalty'],
            'methodTV': self.parameters['regularisation_methodTV']}

    def _get_center_offset(self, cor):
        """Return the ToMoBAR ``CenterRotOffset`` for the current chunk.

        Without projection shifts (missing, or summing to zero) this is the
        scalar ``0.5 * Horiz_det - mean(cor) - 0.5``. Otherwise it is an
        array shaped like ``self.projection_shifts``:
        column 0 is that scalar minus ``shifts[:, 0]``, and column 1 is
        ``-shifts[:, 1] - 0.5``.

        :param cor: centre(s) of rotation for the chunk.
        """
        cor_astra = 0.5 * self.Horiz_det - np.mean(cor)
        offset = cor_astra.item() - 0.5
        # projection_shifts is only assigned in setup() when present in the
        # experiment metadata; a missing attribute means "no shifts"
        shifts = getattr(self, 'projection_shifts', None)
        if shifts is None or np.sum(shifts) == 0.0:
            return offset
        center_offset = np.zeros(np.shape(shifts))
        center_offset[:, 0] = offset - shifts[:, 0]
        center_offset[:, 1] = -shifts[:, 1] - 0.5
        return center_offset

    def _create_iterative_tools(self, center_offset):
        """Build the ToMoBAR ``RecToolsIR`` object used by FISTA3D/CGLS3D."""
        return RecToolsIR(
            DetectorsDimH=self.Horiz_det,    # horizontal detector size
            DetectorsDimV=self.Vert_det,     # padded vertical size (3D)
            CenterRotOffset=center_offset,   # CoR combined with shifts
            AnglesVec=-self.anglesRAD,       # angles in radians
            ObjSize=self.vol_shape[0],       # reconstructed object size
            datafidelity=self.parameters['data_fidelity'],  # LS/PWLS/SWLS
            device_projector=self.parameters['GPU_index'])

    def process_frames(self, data):
        """Reconstruct one padded chunk of projection data with ToMoBAR.

        :param data: input frames. ``data[0]`` holds the projections; for
            FISTA3D with the PWLS or SWLS fidelity, ``data[1]`` holds the
            raw data.
        :returns: the reconstructed chunk, with its first two axes swapped
            back.
        :raises ValueError: if 'reconstruction_method' is not one of
            FISTA3D, FBP3D or CGLS3D.
        """
        cor, angles, self.vol_shape, _init = self.get_frame_params()
        self.anglesRAD = np.deg2rad(angles.astype(np.float32))
        projdata3D = data[0].astype(np.float32)
        self.Horiz_det = np.shape(projdata3D)[self.det_dimX_ind]
        # zero out implausibly large values before reconstruction
        projdata3D[projdata3D > 10 ** 15] = 0.0
        # swap the first two axes into the order handed to ToMoBAR
        projdata3D = np.swapaxes(projdata3D, 0, 1)
        self._data_.update({'projection_norm_data': projdata3D})

        center_offset = self._get_center_offset(cor)
        method = self.parameters['reconstruction_method']

        if method == 'FISTA3D':
            # PWLS / SWLS fidelities additionally need the raw data
            if self.parameters['data_fidelity'] in ('PWLS', 'SWLS'):
                rawdata3D = data[1].astype(np.float32)
                rawdata3D[rawdata3D > 10 ** 15] = 0.0
                # normalise by the maximum (axis order does not affect it)
                rawdata3D = np.swapaxes(rawdata3D, 0, 1) / np.max(rawdata3D)
                self._data_.update({'projection_raw_data': rawdata3D})
                self._data_.update({'beta_SWLS': self.parameters['data_beta_SWLS']
                                    * np.ones(self.Horiz_det)})
            recon = self._create_iterative_tools(center_offset).FISTA(
                self._data_, self._algorithm_, self._regularisation_)
        elif method == 'FBP3D':
            rectools_dir = RecToolsDIR(
                DetectorsDimH=self.Horiz_det,    # horizontal detector size
                DetectorsDimV=self.Vert_det,     # padded vertical size (3D)
                CenterRotOffset=center_offset,   # CoR combined with shifts
                AnglesVec=-self.anglesRAD,       # angles in radians
                ObjSize=self.vol_shape[0],       # reconstructed object size
                device_projector=self.parameters['GPU_index'])
            recon = rectools_dir.FBP(projdata3D)
        elif method == 'CGLS3D':
            recon = self._create_iterative_tools(center_offset).CGLS(
                self._data_, self._algorithm_)
        else:
            # previously this fell through to an UnboundLocalError on recon
            raise ValueError(
                "Unknown reconstruction_method '%s'; expected one of: %s"
                % (method, ', '.join(self._RECON_METHODS)))

        return np.swapaxes(recon, 0, 1)

    def nInput_datasets(self):
        """Number of input datasets: those listed in 'in_datasets', min 1."""
        return max(len(self.parameters['in_datasets']), 1)

    def nOutput_datasets(self):
        """Total number of output datasets.

        Returns 2 when this plugin ends an iterative plugin group,
        otherwise 1.
        """
        return 2 if check_if_end_plugin_in_iterate_group(self.exp) else 1

    def nClone_datasets(self):
        """Number of output datasets that are clones.

        Returns 1 when this plugin ends an iterative plugin group,
        otherwise 0.
        """
        return 1 if check_if_end_plugin_in_iterate_group(self.exp) else 0

    def _set_max_frames(self, frames):
        """Store the maximum number of frames (slices) per chunk."""
        self._max_frames = frames
213