Test Failed
Push — master ( 9ba79c...17f3e3 )
by Yousef
01:54 queued 19s
created

AstraReconGpu.astra_setup()   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: astra_recon_gpu
17
   :platform: Unix
18
   :synopsis: Wrapper around the ASTRA toolbox for GPU reconstruction using vector geometry
19
.. moduleauthor:: Mark Basham <[email protected]>
20
21
"""
22
import astra
23
import numpy as np
24
25
from savu.plugins.reconstructions.astra_recons.base_astra_vector_recon \
26
    import BaseAstraVectorRecon
27
from savu.plugins.driver.gpu_plugin import GpuPlugin
28
from savu.plugins.utils import register_plugin
29
from savu.core.iterate_plugin_group_utils import \
30
    check_if_end_plugin_in_iterate_group
31
32
33
# Savu plugin running ASTRA GPU reconstructions with parallel-beam vector
# geometries (2D 'parallel_vec' and 3D 'parallel3d_vec'). For iterative
# algorithms the residual norm can optionally be emitted as a second output.
@register_plugin
class AstraReconGpu(BaseAstraVectorRecon, GpuPlugin):

    def __init__(self):
        super(AstraReconGpu, self).__init__("AstraReconGpu")
        self.GPU_index = None
        # Set to True by nOutput_datasets when a 'res_norm' output is added.
        self.res = False
        # Incremented once per call to astra_3D_vector_recon.
        self.start = 0

    def set_options(self, cfg):
        """Store the user-selected GPU index in the ASTRA algorithm config.

        :param cfg: ASTRA algorithm configuration mapping; modified in place.
        :returns: the same ``cfg`` with ``cfg['option']['GPUindex']`` set from
            the ``GPU_index`` plugin parameter.
        """
        cfg.setdefault('option', {})['GPUindex'] = self.parameters['GPU_index']
        return cfg

    def nOutput_datasets(self):
        """Return the total number of output datasets (1 or 2).

        A second output is used either for the residual norm (``res_norm``
        enabled with a non-FBP algorithm) or when this plugin ends an
        iterative plugin group; requesting both is not supported.

        Side effect: when the residual norm is requested, sets ``self.res``
        and appends ``'res_norm'`` to ``self.parameters['out_datasets']``.

        :raises ValueError: if ``res_norm`` is requested while this plugin is
            the last one of an iterative group.
        """
        want_res_norm = self.parameters['res_norm'] is True and \
            'FBP' not in self.parameters['algorithm']
        ends_iterate_group = check_if_end_plugin_in_iterate_group(self.exp)

        if want_res_norm and ends_iterate_group:
            err_str = "The res_norm output dataset has not yet been " \
                "implemented for when AstraReconGpu is at the end of an " \
                "iterative loop"
            raise ValueError(err_str)
        if want_res_norm:
            self.res = True
            self.parameters['out_datasets'].append('res_norm')
            return 2
        return 2 if ends_iterate_group else 1

    def astra_setup(self):
        """Check that the requested algorithm is a supported ASTRA GPU one.

        :raises Exception: if ``self.parameters['algorithm']`` is not one of
            the supported CUDA algorithm names.
        """
        supported = ("FBP_CUDA", "SIRT_CUDA", "SART_CUDA", "CGLS_CUDA",
                     "FP_CUDA", "BP_CUDA", "BP3D_CUDA", "FBP3D_CUDA",
                     "SIRT3D_CUDA", "CGLS3D_CUDA")
        if self.parameters['algorithm'] not in supported:
            raise Exception("Unknown Astra GPU algorithm.")

    def astra_2D_vector_recon(self, data):
        """Reconstruct one 2D slice with an ASTRA 2D GPU algorithm.

        :param data: sequence whose first element is the sinogram of the frame.
        :returns: the reconstructed slice or, when ``self.res`` is set,
            ``[recon, res]`` where ``res[j]`` is the residual norm reported
            by ASTRA after iteration ``j + 1``.
        """
        sino = data[0]
        cor, angles, vol_shape, init = self.get_frame_params()
        res = np.zeros(self.len_res) if self.res else None
        angles_rad = np.deg2rad(angles)

        vol_geom = astra.create_vol_geom(vol_shape)

        det_width = sino.shape[self.dim_detX]
        # Distance of the centre of rotation from the detector midpoint.
        cor_astra_scalar = 0.5 * det_width - cor
        vectors = self.vec_geom_init2D(angles_rad, 1.0, cor_astra_scalar - 0.5)
        try:
            # vector geometry (astra > 1.9v)
            proj_geom = astra.create_proj_geom('parallel_vec', det_width, vectors)
        except Exception:
            print('Warning: using scalar geometry since the Astra version <1.9 does not support the vector one for 2D')
            # ASTRA scalar geometries expect angles in radians. NOTE: the
            # centre-of-rotation offset is not applied on this fallback path.
            proj_geom = astra.create_proj_geom('parallel', 1.0, det_width, angles_rad)

        sino = np.transpose(sino, (self.dim_rot, self.dim_detX))
        sino_id = astra.data2d.create('-sino', proj_geom, sino)
        if init is not None:
            rec_id = astra.data2d.create('-vol', vol_geom, init)
        else:
            rec_id = astra.data2d.create('-vol', vol_geom)

        cfg = self.set_config(rec_id, sino_id, proj_geom, vol_geom)
        alg_id = astra.algorithm.create(cfg)
        try:
            if self.res:
                # Run one iteration at a time so the residual can be recorded.
                for j in range(self.iters):
                    astra.algorithm.run(alg_id, 1)
                    res[j] = astra.algorithm.get_res_norm(alg_id)
            else:
                astra.algorithm.run(alg_id, self.iters)

            recon = astra.data2d.get(rec_id)
            if self.manual_mask is not False:
                recon = self.manual_mask * recon
        finally:
            # Always release the ASTRA objects, even if the run fails.
            self.delete(alg_id, sino_id, rec_id, False)

        return [recon, res] if self.res else recon

    def astra_3D_vector_recon(self, data):
        """Reconstruct a block of slices with an ASTRA 3D GPU algorithm.

        ``FBP3D_CUDA`` is emulated by sinc-filtering the projections
        (``filtersinc3d``) and running ``BP3D_CUDA``.

        :param data: sequence whose first element is the 3D projection block.
        :returns: the reconstructed block (axes reordered with
            ``(2, 0, 1)``) or, when ``self.res`` is set, ``[recon, res]`` where
            ``res[:, j]`` holds the residual norm after iteration ``j + 1``.
        """
        proj_data3d = data[0]
        cor, angles, vol_shape, init = self.get_frame_params()
        projection_shifts2d = self.get_frame_shifts()
        half_det_width = 0.5 * proj_data3d.shape[self.sino_dim_detX]
        # Only a scalar CoR is supported at the moment: per-slice values are averaged.
        cor_astra_scalar = half_det_width - np.mean(cor)

        res = None
        if self.res:
            # Rows: slices, columns: iterations.
            # NOTE(review): uses self.vol_shape rather than the local
            # vol_shape from get_frame_params -- confirm this is intended.
            res = np.zeros((self.vol_shape[self.slice_dir], self.iters))

        vol_geom = astra.create_vol_geom(vol_shape[0], vol_shape[2], vol_shape[1])

        vectors3d = self.vec_geom_init3D(np.deg2rad(angles + 90.0), 1.0, 1.0,
                                         cor_astra_scalar - 0.5,
                                         projection_shifts2d)
        proj_geom = astra.create_proj_geom('parallel3d_vec',
                                           proj_data3d.shape[self.sino_dim_detY],
                                           proj_data3d.shape[self.sino_dim_detX],
                                           vectors3d)

        is_fbp3d = self.parameters['algorithm'] == "FBP3D_CUDA"
        proj_data3d = np.swapaxes(proj_data3d, 0, 1)
        if is_fbp3d:
            # Pre-filter the projections; a plain back-projection follows.
            proj_data3d = self.filtersinc3d(proj_data3d)

        proj_id = astra.data3d.create("-sino", proj_geom, proj_data3d)
        if init is not None:
            rec_id = astra.data3d.create('-vol', vol_geom, init)
        else:
            rec_id = astra.data3d.create('-vol', vol_geom)

        cfg = self.set_config(rec_id, proj_id, proj_geom, vol_geom)
        if is_fbp3d:
            cfg['type'] = 'BP3D_CUDA'

        alg_id = astra.algorithm.create(cfg)
        try:
            if self.res:
                for j in range(self.iters):
                    astra.algorithm.run(alg_id, 1)
                    # j indexes iterations, i.e. the second axis of res.
                    res[:, j] = astra.algorithm.get_res_norm(alg_id)
            else:
                astra.algorithm.run(alg_id, self.iters)
            # NOTE: self.manual_mask is not applied to 3D reconstructions.
            recon = np.transpose(astra.data3d.get(rec_id), (2, 0, 1))
        finally:
            # Always release the ASTRA objects, even if the run fails.
            self.delete(alg_id, proj_id, rec_id, False)

        self.start += 1
        return [recon, res] if self.res else recon

    def rotation_matrix2D(self, theta):
        """Return the standard 2x2 rotation matrix for angle ``theta`` (radians)."""
        c, s = np.cos(theta), np.sin(theta)
        return np.array([[c, -s],
                         [s, c]])

    def rotation_matrix3D(self, theta):
        """Return the 3x3 rotation matrix about the z axis for ``theta`` (radians)."""
        c, s = np.cos(theta), np.sin(theta)
        return np.array([[c, -s, 0.0],
                         [s, c, 0.0],
                         [0.0, 0.0, 1.0]])

    def vec_geom_init2D(self, angles_rad, DetectorSpacingX, CenterRotOffset):
        """Build the per-angle rows of an ASTRA ``parallel_vec`` geometry.

        :param angles_rad: 1D array of projection angles in radians.
        :param DetectorSpacingX: detector pixel spacing.
        :param CenterRotOffset: detector-centre offset along the detector.
        :returns: ``(n_angles, 6)`` array; columns 0:2 are the ray direction,
            2:4 the detector centre and 4:6 the vector between adjacent
            detector pixels, each rotated by the projection angle.
        """
        s0 = [0.0, -1.0]  # ray direction at angle 0
        d0 = [CenterRotOffset, 0.0]  # detector centre at angle 0
        u0 = [DetectorSpacingX, 0.0]  # pixel-to-pixel vector at angle 0

        vectors = np.zeros([angles_rad.size, 6])
        for i, theta in enumerate(angles_rad):
            rot = self.rotation_matrix2D(theta)
            vectors[i, 0:2] = np.dot(rot, s0)
            vectors[i, 2:4] = np.dot(rot, d0)
            vectors[i, 4:6] = np.dot(rot, u0)
        return vectors

    def vec_geom_init3D(self, angles_rad, DetectorSpacingX, DetectorSpacingY, CenterRotOffset, projection_shifts2d):
        """Build the per-angle rows of an ASTRA ``parallel3d_vec`` geometry.

        :param angles_rad: 1D array of projection angles in radians.
        :param DetectorSpacingX: horizontal detector pixel spacing.
        :param DetectorSpacingY: vertical detector pixel spacing.
        :param CenterRotOffset: horizontal detector-centre offset.
        :param projection_shifts2d: ``(n_angles, >=2)`` array of per-projection
            (horizontal, vertical) shifts applied to the detector centre.
        :returns: ``(n_angles, 12)`` array; columns 0:3 ray direction, 3:6
            detector centre, 6:9 horizontal and 9:12 vertical pixel vectors,
            each rotated about z by the projection angle.
        """
        s0 = [0.0, -1.0, 0.0]  # ray direction at angle 0
        u0 = [DetectorSpacingX, 0.0, 0.0]  # horizontal pixel vector
        v0 = [0.0, 0.0, DetectorSpacingY]  # vertical pixel vector

        vectors = np.zeros([angles_rad.size, 12])
        for i, theta in enumerate(angles_rad):
            d0 = [CenterRotOffset - projection_shifts2d[i, 0], 0.0,
                  -projection_shifts2d[i, 1] - 0.5]  # shifted detector centre
            rot = self.rotation_matrix3D(theta)
            vectors[i, 0:3] = np.dot(rot, s0)
            vectors[i, 3:6] = np.dot(rot, d0)
            vectors[i, 6:9] = np.dot(rot, u0)
            vectors[i, 9:12] = np.dot(rot, v0)
        return vectors

    def filtersinc3d(self, projection3d):
        """Apply a sinc-type filter along the horizontal detector axis.

        Adapted from Matlab code by Waqas Akram.

        :param projection3d: 3D array laid out as
            ``[DetectorVert, Projections, DetectorHoriz]``.
        :returns: float32 array of the same shape, filtered along the last
            axis and scaled by ``1 / number of projections``.
        """
        import scipy.fftpack

        # "a" varies the filter magnitude response: when a << 1 the response
        # approximates |w|; as "a" increases the response rolls off at high
        # frequencies.
        a = 1.1
        n_rows, n_proj, n_det = np.shape(projection3d)
        w = np.linspace(-np.pi, np.pi - (2 * np.pi) / n_det, n_det, dtype='float32')

        rn1 = np.abs(2.0 / a * np.sin(a * w / 2.0))
        rn2 = np.sin(a * w / 2.0)
        rd_c = np.zeros([1, n_det])
        rd_c[0, :] = (a * w) / 2.0
        # NOTE(review): np.dot(rn2, pinv(rd_c)) yields a least-squares scalar
        # (like Matlab's matrix divide ``rn2/rd``), not the element-wise ratio
        # rn2 ./ rd. Kept unchanged to preserve existing results; confirm intent.
        r = rn1 * (np.dot(rn2, np.linalg.pinv(rd_c))) ** 2
        f = scipy.fftpack.fftshift(r)
        multiplier = 1.0 / n_proj

        filtered = np.empty(np.shape(projection3d), dtype=np.float32)
        for j in range(n_rows):
            # Filter every projection of detector row j in one batched FFT.
            spectrum = scipy.fftpack.fft(projection3d[j], axis=-1)
            filtered[j] = multiplier * np.real(scipy.fftpack.ifft(spectrum * f, axis=-1))
        return filtered