Test Failed
Pull Request — master (#913)
by Daniil
04:13
created

TomobarRecon3d.setup()   F

Complexity

Conditions 15

Size

Total Lines 44
Code Lines 39

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 15
eloc 39
nop 1
dl 0
loc 44
rs 2.9998
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex methods like savu.plugins.reconstructions.tomobar.tomobar_recon_3D.TomobarRecon3d.setup() often do a lot of different things. To break such a method down, we need to identify a cohesive component within its class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright 2019 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: tomobar_recon_3D
17
   :platform: Unix
18
   :synopsis: A wrapper around TOmographic MOdel-BAsed Reconstruction (ToMoBAR) software \
19
   for direct and advanced iterative image reconstruction using _3D_ capabilities of regularisation. \
20
   This plugin will divide 3D projection data into overlapping subsets using padding.
21
22
.. moduleauthor:: Daniil Kazantsev <[email protected]>
23
"""
24
25
from savu.plugins.reconstructions.base_recon import BaseRecon
26
from savu.plugins.driver.gpu_plugin import GpuPlugin
27
28
import numpy as np
29
from tomobar.methodsIR import RecToolsIR
30
from tomobar.methodsDIR import RecToolsDIR
31
from savu.plugins.utils import register_plugin
32
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
33
    check_if_end_plugin_in_iterate_group, setup_extra_plugin_data_padding
34
35
36
@register_plugin
37
class TomobarRecon3d(BaseRecon, GpuPlugin):
38
39
    def __init__(self):
        """Register the plugin under the name "TomobarRecon3d".

        The padded vertical detector size (``Vert_det``) and the padding
        width (``pad``) are left unset (``None``) at construction time.
        """
        super(TomobarRecon3d, self).__init__("TomobarRecon3d")
        # Placeholders only; no value is known at construction time.
        self.Vert_det = self.pad = None
43
44 View Code Duplication
    @setup_extra_plugin_data_padding
    def set_filter_padding(self, in_pData, out_pData):
        """Attach 'edge' padding along 'detector_y' to the plugin datasets.

        The padding width is read from the 'padding' parameter and stored on
        ``self.pad``. The same padding dictionary is assigned to the first
        input and first output plugin dataset, and to the second input
        plugin dataset when more than one input dataset is present.

        :param in_pData: sequence of input plugin-data objects.
        :param out_pData: sequence of output plugin-data objects.
        """
        self.pad = self.parameters['padding']
        first_dataset = self.get_in_datasets()[0]
        vertical_axis = first_dataset.get_data_dimension_by_axis_label(
            'detector_y')
        # Direction string has the form "<axis index>.<pad width>".
        direction = '%s.%s' % (vertical_axis, self.pad)
        padding_spec = {'pad_directions': [direction], 'pad_mode': 'edge'}

        for plugin_data in (in_pData[0], out_pData[0]):
            plugin_data.padding = padding_spec
        if len(self.get_in_datasets()) > 1:
            in_pData[1].padding = padding_spec
55
56
    @enable_iterative_loop
    def setup(self):
        """Choose the number of 'detector_y' slices processed per frame block.

        The slice count starts as the 'detector_y' extent divided between the
        processes whose name contains 'GPU', and is then capped by an
        estimate of how many slices fit into the memory reported by
        ``get_gpu_memory()`` for the first device (shared between those
        processes), minus the padding on both sides. The result is passed to
        ``_set_max_frames`` and is never smaller than one.

        If the experiment metadata contains 'projection_shifts', it is stored
        on ``self.projection_shifts`` before the parent ``setup`` runs.
        """
        in_dataset = self.get_in_datasets()[0]
        data_shape = in_dataset.get_shape()

        processes = self.exp.meta_data.get("processes")
        # Total number of GPU processes; clamped to one so that a process
        # list without any 'GPU' entry cannot cause a division by zero.
        gpu_procs = max(1, len([name for name in processes if 'GPU' in name]))

        det_y_axis = in_dataset.get_data_dimension_by_axis_label('detector_y')
        n_slices = int(np.ceil(data_shape[det_y_axis] / float(gpu_procs)))

        # Calculate the amount of slices that would fit the GPU memory.
        # Only the first device is queried, even if several are available.
        gpu_available_mb = self.get_gpu_memory()[0] / gpu_procs
        det_x_axis = in_dataset.get_data_dimension_by_axis_label('detector_x')
        slice_size_mb = self._estimate_slice_size_mb(data_shape[det_x_axis])

        slices_fit_total = (int(gpu_available_mb / slice_size_mb)
                            - 2 * self.parameters['padding'])
        # Cap by the memory estimate, but always request at least one slice:
        # a zero or negative frame count is never a usable block size.
        n_slices = max(1, min(n_slices, slices_fit_total))
        self._set_max_frames(n_slices)

        # Get experimental metadata of projection_shifts (when present).
        if 'projection_shifts' in self.exp.meta_data.dict.keys():
            self.projection_shifts = \
                self.exp.meta_data.dict['projection_shifts']
        super(TomobarRecon3d, self).setup()

    def _estimate_slice_size_mb(self, det_x_dim):
        """Estimate the GPU memory cost of one slice for the chosen method.

        The base estimate is derived from the squared horizontal detector
        size and then multiplied by method-specific factors. For 'FISTA3D'
        the factor depends on the regularisation method (3D regularisers
        need extra memory; the factors avoid CUDA out-of-memory errors).
        'SIRT3D' and 'CGLS3D' use a fixed factor of 3.

        :param det_x_dim: extent of the 'detector_x' dimension.
        :returns: integer memory estimate per slice (treated as megabytes by
            the caller - NOTE(review): units follow the original variable
            naming and are not verifiable from this file).
        """
        slice_size_mb = int(np.ceil(
            ((det_x_dim * det_x_dim) * 1024 * 4) / (1024 ** 3)))

        recon_method = self.parameters['reconstruction_method']
        if recon_method == 'FISTA3D':
            regulariser = self.parameters['regularisation_method']
            # (name, factor) pairs; every name contained in the configured
            # regularisation method contributes its factor, in this order.
            memory_factors = (
                ('ROF_TV', 8),
                ('FGP_TV', 12),
                ('SB_TV', 10),
                ('PD_TV', 8),
                ('LLT_ROF', 12),
                ('TGV', 15),
                ('NDF', 5),
                ('Diff4th', 5),
                ('NLTV', 8),
            )
            for name, factor in memory_factors:
                if name in regulariser:
                    slice_size_mb *= factor
        if recon_method in ('SIRT3D', 'CGLS3D'):
            slice_size_mb *= 3
        return slice_size_mb
100
101
    def pre_process(self):
        """Cache detector axis indices and build the ToMoBAR input dicts.

        Stores the 'detector_x' and 'detector_y' axis indices of the first
        plugin input dataset, computes the padded vertical detector size
        (``Vert_det``), and copies the plugin parameters into the three
        dictionaries ``_data_``, ``_algorithm_`` and ``_regularisation_``.

        :raises ValueError: if the 'detector_y' axis lookup raises
            ``ValueError`` (the data is then reported as 2D instead of 3D).
        """
        plugin_data = self.get_plugin_in_datasets()[0]
        axis_of = plugin_data.get_data_dimension_by_axis_label
        self.det_dimX_ind = axis_of('detector_x')
        try:
            self.det_dimY_ind = axis_of('detector_y')
        except ValueError:
            raise ValueError('<<<The dimension of the given projection data is 2D, while 3D is required! >>>')
        # Vertical detector size including the padding added on both sides.
        self.Vert_det = \
            plugin_data.get_shape()[self.det_dimY_ind] + 2 * self.pad

        # Each table maps a ToMoBAR dictionary key to the plugin parameter
        # that supplies its value.
        params = self.parameters
        data_keys = (
            ('OS_number', 'algorithm_ordersubsets'),
            ('huber_threshold', 'data_Huber_thresh'),
            ('ringGH_lambda', 'data_full_ring_GH'),
            ('ringGH_accelerate', 'data_full_ring_accelerator_GH'),
        )
        algorithm_keys = (
            ('iterations', 'algorithm_iterations'),
            ('nonnegativity', 'algorithm_nonnegativity'),
            ('mask_diameter', 'algorithm_mask'),
            ('verbose', 'algorithm_verbose'),
        )
        regularisation_keys = (
            ('method', 'regularisation_method'),
            ('regul_param', 'regularisation_parameter'),
            ('iterations', 'regularisation_iterations'),
            ('device_regulariser', 'regularisation_device'),
            ('edge_threhsold', 'regularisation_edge_thresh'),
            ('time_marching_step', 'regularisation_timestep'),
            ('regul_param2', 'regularisation_parameter2'),
            ('PD_LipschitzConstant', 'regularisation_PD_lip'),
            ('NDF_penalty', 'regularisation_NDF_penalty'),
            ('methodTV', 'regularisation_methodTV'),
        )
        self._data_ = {key: params[name] for key, name in data_keys}
        self._algorithm_ = {key: params[name] for key, name in algorithm_keys}
        self._regularisation_ = {key: params[name]
                                 for key, name in regularisation_keys}
132
133
    def process_frames(self, data):
        """Reconstruct one block of 3D projection data with ToMoBAR.

        :param data: list of input arrays. ``data[0]`` holds the projection
            data; ``data[1]`` (raw projection data) is read only for
            'FISTA3D' with the 'PWLS' or 'SWLS' data fidelity.
        :returns: the reconstruction with its first two axes swapped.
        :raises ValueError: if 'reconstruction_method' is not one of
            'FBP3D', 'CGLS3D', 'SIRT3D' or 'FISTA3D'.
        """
        method = self.parameters['reconstruction_method']
        # Fail early with a clear message; previously an unsupported method
        # fell through every branch and hit an unbound ``recon`` variable.
        if method not in ('FBP3D', 'CGLS3D', 'SIRT3D', 'FISTA3D'):
            raise ValueError(
                "Unsupported reconstruction_method %r; expected one of "
                "'FBP3D', 'CGLS3D', 'SIRT3D', 'FISTA3D'" % (method,))

        cor, angles, self.vol_shape, _ = self.get_frame_params()
        self.anglesRAD = np.deg2rad(angles.astype(np.float32))
        projdata3D = data[0].astype(np.float32)
        self.Horiz_det = np.shape(projdata3D)[self.det_dimX_ind]
        half_det_width = 0.5 * self.Horiz_det
        # Zero out extreme values before reconstruction.
        projdata3D[projdata3D > 10 ** 15] = 0.0
        projdata3D = np.swapaxes(projdata3D, 0, 1)
        self._data_.update({'projection_norm_data': projdata3D})

        # Dealing with projection shifts and the CoR.
        cor_astra = half_det_width - np.mean(cor)
        base_offset = cor_astra.item() - 0.5
        CenterOffset = base_offset
        # ``projection_shifts`` is only assigned in ``setup`` when the
        # metadata provides it, so it may be missing on this instance.
        shifts = getattr(self, 'projection_shifts', None)
        if shifts is not None and np.sum(shifts) != 0.0:
            CenterOffset = np.zeros(np.shape(shifts))
            CenterOffset[:, 0] = base_offset - shifts[:, 0]
            CenterOffset[:, 1] = -shifts[:, 1] - 0.5

        # Geometry arguments shared by the direct and iterative classes.
        geometry = {
            'DetectorsDimH': self.Horiz_det,  # horizontal detector dimension
            'DetectorsDimV': self.Vert_det,  # vertical detector dimension (3D case only)
            'CenterRotOffset': CenterOffset,  # centre of rotation combined with the shift offsets
            'AnglesVec': -self.anglesRAD,  # the vector of angles in radians
            'ObjSize': self.vol_shape[0],  # scalar defining the reconstructed object dimensions
            'device_projector': self.parameters['GPU_index'],
        }

        if method == 'FBP3D':
            # Direct reconstruction: only the direct-methods object is built.
            recon = RecToolsDIR(**geometry).FBP(projdata3D)
        else:
            # Iterative reconstruction; data fidelity is LS, PWLS or SWLS.
            rectools_iter = RecToolsIR(
                datafidelity=self.parameters['data_fidelity'], **geometry)
            if method == 'CGLS3D':
                self._algorithm_.update({'lipschitz_const': None})
                recon = rectools_iter.CGLS(self._data_, self._algorithm_)
            elif method == 'SIRT3D':
                self._algorithm_.update({'lipschitz_const': None})
                recon = rectools_iter.SIRT(self._data_, self._algorithm_)
            else:  # 'FISTA3D'
                if self.parameters['regularisation_method'] == 'PD_TV':
                    self._regularisation_.update(
                        {'device_regulariser': self.parameters['GPU_index']})
                # If one selects the PWLS or SWLS model then the raw data is
                # also required (two inputs).
                if self.parameters['data_fidelity'] in ('PWLS', 'SWLS'):
                    rawdata3D = data[1].astype(np.float32)
                    rawdata3D[rawdata3D > 10 ** 15] = 0.0
                    rawdata3D = (np.swapaxes(rawdata3D, 0, 1)
                                 / np.max(np.float32(rawdata3D)))
                    self._data_.update({'projection_raw_data': rawdata3D})
                    self._data_.update(
                        {'beta_SWLS': self.parameters['data_beta_SWLS']
                         * np.ones(self.Horiz_det)})
                recon = rectools_iter.FISTA(
                    self._data_, self._algorithm_, self._regularisation_)

        return np.swapaxes(recon, 0, 1)
197
198
    def nInput_datasets(self):
        """Return the number of configured input datasets, at least one."""
        configured = len(self.parameters['in_datasets'])
        return configured if configured > 1 else 1
200
201
    # total number of output datasets
202
    def nOutput_datasets(self):
        """Return 2 when this is the end plugin of an iterate group, else 1."""
        is_end_plugin = check_if_end_plugin_in_iterate_group(self.exp)
        return 2 if is_end_plugin else 1
207
208
    # total number of output datasets that are clones
209
    def nClone_datasets(self):
        """Return 1 when this is the end plugin of an iterate group, else 0."""
        is_end_plugin = check_if_end_plugin_in_iterate_group(self.exp)
        return 1 if is_end_plugin else 0
214
215
    def _set_max_frames(self, frames):
        """Store ``frames`` on the instance as ``_max_frames``.

        :param frames: value computed in ``setup`` as the number of slices
            to process per block.
        """
        self._max_frames = frames
217