Test Failed
Push to master (26b3f2...a95b01) by Daniil
01:45 (queued 20s)

TomobarRecon3d.setup()   rating: F

Complexity
    Conditions: 16

Size
    Total Lines: 46
    Code Lines: 41

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
cc        16
eloc      41
nop       1
dl        0
loc       46
rs        2.4
c         0
b         0
f         0

How to fix: Complexity

Complex methods like savu.plugins.reconstructions.tomobar.tomobar_recon_3D.TomobarRecon3d.setup() often do several different things at once. To break such a method down, identify a cohesive piece of work inside it; a common approach is to look for groups of statements that operate on the same data or serve the same purpose (here, for example, the chain of regularisation-method checks that scales the per-slice memory estimate).

Once you have identified the statements that belong together, you can apply the Extract Method refactoring. If the extracted logic is shared with other plugins, moving it into a helper function or a base class is also a candidate.
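As a concrete illustration, here is a minimal sketch of that Extract Method idea applied to setup(). The helper name regularisation_memory_factor and the lookup table REGULARISER_MEMORY_FACTOR are hypothetical, not part of the Savu code base; the multiplier values are copied from the existing if-branches in setup().

# Hypothetical sketch: a table-driven replacement for the nine regularisation
# checks in setup(); not part of the Savu code base.
REGULARISER_MEMORY_FACTOR = {
    'ROF_TV': 8, 'FGP_TV': 12, 'SB_TV': 10, 'PD_TV': 8, 'LLT_ROF': 12,
    'TGV': 15, 'NDF': 5, 'Diff4th': 5, 'NLTV': 8,
}


def regularisation_memory_factor(regularisation_method):
    """Return the per-slice GPU memory multiplier for a regulariser name."""
    for name, factor in REGULARISER_MEMORY_FACTOR.items():
        if name in regularisation_method:
            return factor
    return 1  # unrecognised regulariser: assume no extra memory overhead


# setup() would then replace its nine if-statements with a single call:
#     if self.parameters['reconstruction_method'] == 'FISTA3D':
#         slice_size_mbbytes *= regularisation_memory_factor(
#             self.parameters['regularisation_method'])

Table-driven dispatch like this keeps the memory model in one place and removes most of the regularisation-related conditions from setup(), which is where the bulk of the complexity score comes from.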

# Copyright 2019 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: tomobar_recon_3D
   :platform: Unix
   :synopsis: A wrapper around TOmographic MOdel-BAsed Reconstruction (ToMoBAR) software \
   for direct and advanced iterative image reconstruction using _3D_ capabilities of regularisation. \
   This plugin will divide 3D projection data into overlapping subsets using padding.

.. moduleauthor:: Daniil Kazantsev <[email protected]>
"""

from savu.plugins.reconstructions.base_recon import BaseRecon
from savu.plugins.driver.gpu_plugin import GpuPlugin

import numpy as np
from tomobar.methodsIR import RecToolsIR
from tomobar.methodsDIR import RecToolsDIR
from savu.plugins.utils import register_plugin
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
    check_if_end_plugin_in_iterate_group, setup_extra_plugin_data_padding


@register_plugin
class TomobarRecon3d(BaseRecon, GpuPlugin):

    def __init__(self):
        super(TomobarRecon3d, self).__init__("TomobarRecon3d")
        self.Vert_det = None
        self.pad = None

    # Review note (code duplication): this code seems to be duplicated
    # elsewhere in the project.
    @setup_extra_plugin_data_padding
    def set_filter_padding(self, in_pData, out_pData):
        self.pad = self.parameters['padding']
        in_data = self.get_in_datasets()[0]
        det_y = in_data.get_data_dimension_by_axis_label('detector_y')
        pad_det_y = '%s.%s' % (det_y, self.pad)
        pad_dict = {'pad_directions': [pad_det_y], 'pad_mode': 'edge'}
        in_pData[0].padding = pad_dict
        out_pData[0].padding = pad_dict
        if len(self.get_in_datasets()) > 1:
            in_pData[1].padding = pad_dict

    @enable_iterative_loop
    def setup(self):
        in_dataset = self.get_in_datasets()[0]
        procs = self.exp.meta_data.get("processes")
        procs = len([i for i in procs if 'GPU' in i])  # the total number of GPU processes
        dim = in_dataset.get_data_dimension_by_axis_label('detector_y')
        nSlices = int(np.ceil(in_dataset.get_shape()[dim] / float(procs)))
        # calculate the number of slices that would fit into the GPU memory
        gpu_available_mb = self.get_gpu_memory()[0]/procs  # free GPU memory of the first device, if there are several
        det_x_dim = in_dataset.get_shape()[in_dataset.get_data_dimension_by_axis_label('detector_x')]
        rot_angles_dim = in_dataset.get_shape()[in_dataset.get_data_dimension_by_axis_label('rotation_angle')]
        slice_size_mbbytes = int(np.ceil(((det_x_dim * det_x_dim) * 1024 * 4) / (1024 ** 3)))

        if self.parameters['reconstruction_method'] == 'FISTA3D':
            # calculate the GPU memory required based on 3D regularisation restrictions (to avoid CUDA errors)
            if 'ROF_TV' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 8
            if 'FGP_TV' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 12
            if 'SB_TV' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 10
            if 'PD_TV' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 8
            if 'LLT_ROF' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 12
            if 'TGV' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 15
            if 'NDF' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 5
            if 'Diff4th' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 5
            if 'NLTV' in self.parameters['regularisation_method']:
                slice_size_mbbytes *= 8
        if self.parameters['reconstruction_method'] == 'SIRT3D' or self.parameters['reconstruction_method'] == 'CGLS3D':
            slice_size_mbbytes *= 3

        slices_fit_total = int(gpu_available_mb / slice_size_mbbytes) - 2*self.parameters['padding']
        if nSlices > slices_fit_total:
            nSlices = slices_fit_total
        if nSlices < self.parameters['padding']:
            print("The padding value is larger than the number of slices in the 3D slab")
        self._set_max_frames(nSlices)
        # get experimental metadata of projection_shifts
        if 'projection_shifts' in list(self.exp.meta_data.dict.keys()):
            self.projection_shifts = self.exp.meta_data.dict['projection_shifts']
        super(TomobarRecon3d, self).setup()

    def pre_process(self):
        in_pData = self.get_plugin_in_datasets()[0]
        self.det_dimX_ind = in_pData.get_data_dimension_by_axis_label('detector_x')
        try:
            self.det_dimY_ind = in_pData.get_data_dimension_by_axis_label('detector_y')
        except ValueError:
            raise ValueError('<<<The dimension of the given projection data is 2D, while 3D is required! >>>')
        # get the value for the padded vertical detector
        self.Vert_det = in_pData.get_shape()[self.det_dimY_ind] + 2 * self.pad

        # extract the given parameters into dictionaries suitable for ToMoBAR input
        self._data_ = {'OS_number': self.parameters['algorithm_ordersubsets'],
                       'huber_threshold': self.parameters['data_Huber_thresh'],
                       'ringGH_lambda': self.parameters['data_full_ring_GH'],
                       'ringGH_accelerate': self.parameters['data_full_ring_accelerator_GH']}

        self._algorithm_ = {'iterations': self.parameters['algorithm_iterations'],
                            'nonnegativity': self.parameters['algorithm_nonnegativity'],
                            'mask_diameter': self.parameters['algorithm_mask'],
                            'verbose': self.parameters['algorithm_verbose']}

        self._regularisation_ = {'method': self.parameters['regularisation_method'],
                                 'regul_param': self.parameters['regularisation_parameter'],
                                 'iterations': self.parameters['regularisation_iterations'],
                                 'device_regulariser': self.parameters['regularisation_device'],
                                 'edge_threhsold': self.parameters['regularisation_edge_thresh'],
                                 'time_marching_step': self.parameters['regularisation_timestep'],
                                 'regul_param2': self.parameters['regularisation_parameter2'],
                                 'PD_LipschitzConstant': self.parameters['regularisation_PD_lip'],
                                 'NDF_penalty': self.parameters['regularisation_NDF_penalty'],
                                 'methodTV': self.parameters['regularisation_methodTV']}

    def process_frames(self, data):
        cor, angles, self.vol_shape, init = self.get_frame_params()
        self.anglesRAD = np.deg2rad(angles.astype(np.float32))
        projdata3D = data[0].astype(np.float32)
        dim_tuple = np.shape(projdata3D)
        self.Horiz_det = dim_tuple[self.det_dimX_ind]
        half_det_width = 0.5 * self.Horiz_det
        projdata3D[projdata3D > 10 ** 15] = 0.0
        projdata3D = np.require(np.swapaxes(projdata3D, 0, 1), requirements='CA')
        self._data_.update({'projection_norm_data': projdata3D})

        # set up the CoR and the offset
        cor_astra = half_det_width - np.mean(cor)
        CenterOffset_scalar = cor_astra.item() - 0.5
        CenterOffset = np.zeros(np.shape(self.projection_shifts))
        CenterOffset[:, 0] = CenterOffset_scalar
        CenterOffset[:, 1] = -0.5  # TODO: maybe needs to be tweaked?

        # check if Projection2dAlignment is in the process list and, if so,
        # fetch the value of the "registration" parameter (in order to decide
        # whether projection shifts need to be taken into account or not)
        registration = False
        for plugin_dict in self.exp.meta_data.plugin_list.plugin_list:
            if plugin_dict['name'] == 'Projection2dAlignment':
                registration = plugin_dict['data']['registration']
                break

        if np.sum(self.projection_shifts) != 0.0 and not registration:
            # modify the offset to take the shifts into account
            CenterOffset[:, 0] -= self.projection_shifts[:, 0]
            CenterOffset[:, 1] -= self.projection_shifts[:, 1]

        # set parameters and instantiate a ToMoBAR class object for iterative reconstruction
        RectoolsIter = RecToolsIR(DetectorsDimH=self.Horiz_det,  # detector dimension (horizontal)
                                  DetectorsDimV=self.Vert_det,  # detector dimension (vertical), 3D case only
                                  CenterRotOffset=CenterOffset,  # the center of rotation combined with the shift offsets
                                  AnglesVec=-self.anglesRAD,  # the vector of angles in radians
                                  ObjSize=self.vol_shape[0],  # a scalar defining the reconstructed object dimensions
                                  datafidelity=self.parameters['data_fidelity'],  # data fidelity: choose LS, PWLS or SWLS
                                  device_projector=self.parameters['GPU_index'])

        # set parameters and instantiate a ToMoBAR class object for direct reconstruction
        RectoolsDIR = RecToolsDIR(DetectorsDimH=self.Horiz_det,  # detector dimension (horizontal)
                                  DetectorsDimV=self.Vert_det,  # detector dimension (vertical), 3D case only
                                  CenterRotOffset=CenterOffset,  # the center of rotation combined with the shift offsets
                                  AnglesVec=-self.anglesRAD,  # the vector of angles in radians
                                  ObjSize=self.vol_shape[0],  # a scalar defining the reconstructed object dimensions
                                  device_projector=self.parameters['GPU_index'])

        if self.parameters['reconstruction_method'] == 'FBP3D':
            recon = RectoolsDIR.FBP(projdata3D)  # perform FBP3D

        if self.parameters['reconstruction_method'] == 'CGLS3D':
            # run the CGLS 3D reconstruction algorithm
            self._algorithm_.update({'lipschitz_const': None})
            recon = RectoolsIter.CGLS(self._data_, self._algorithm_)

        if self.parameters['reconstruction_method'] == 'SIRT3D':
            # run the SIRT 3D reconstruction algorithm
            self._algorithm_.update({'lipschitz_const': None})
            recon = RectoolsIter.SIRT(self._data_, self._algorithm_)

        if self.parameters['reconstruction_method'] == 'FISTA3D':
            if self.parameters['regularisation_method'] == 'PD_TV':
                self._regularisation_.update({'device_regulariser': self.parameters['GPU_index']})
            # if the PWLS or SWLS model is selected then the raw data is also required (2 inputs)
            if (self.parameters['data_fidelity'] == 'PWLS') or (self.parameters['data_fidelity'] == 'SWLS'):
                rawdata3D = data[1].astype(np.float32)
                rawdata3D[rawdata3D > 10 ** 15] = 0.0
                rawdata3D = np.swapaxes(rawdata3D, 0, 1) / np.max(np.float32(rawdata3D))
                self._data_.update({'projection_raw_data': rawdata3D})
                self._data_.update({'beta_SWLS': self.parameters['data_beta_SWLS'] * np.ones(self.Horiz_det)})
            # run the FISTA reconstruction algorithm
            recon = RectoolsIter.FISTA(self._data_, self._algorithm_, self._regularisation_)
        return np.require(np.swapaxes(recon, 0, 1), requirements='CA')
        # Review note: `recon` is not defined if none of the reconstruction_method
        # branches above (FBP3D, CGLS3D, SIRT3D, FISTA3D) matches, so the final
        # return would raise a NameError. Are you sure this can never be the case?
        # (A possible guard is sketched after the listing.)

    def nInput_datasets(self):
        return max(len(self.parameters['in_datasets']), 1)

    # total number of output datasets
    def nOutput_datasets(self):
        if check_if_end_plugin_in_iterate_group(self.exp):
            return 2
        else:
            return 1

    # total number of output datasets that are clones
    def nClone_datasets(self):
        if check_if_end_plugin_in_iterate_group(self.exp):
            return 1
        else:
            return 0

    def _set_max_frames(self, frames):
        self._max_frames = frames
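On the undefined-recon warning flagged in process_frames() above, a minimal sketch of one possible guard follows. The helper check_reconstruction_method and its placement are hypothetical, not part of the Savu code base; the idea is simply to fail early with a clear message instead of hitting a NameError at the final return when an unsupported reconstruction_method is configured.

# Hypothetical guard for process_frames(): reject unsupported methods up front
# so that `recon` is always assigned before the final return.
SUPPORTED_METHODS = ('FBP3D', 'CGLS3D', 'SIRT3D', 'FISTA3D')


def check_reconstruction_method(method):
    """Raise a clear error if the configured method matches no branch."""
    if method not in SUPPORTED_METHODS:
        raise ValueError(
            "Unknown reconstruction_method '%s'; expected one of: %s"
            % (method, ', '.join(SUPPORTED_METHODS)))


# called at the top of process_frames(), for example:
#     check_reconstruction_method(self.parameters['reconstruction_method'])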