Test Failed
Pull Request — master (#888)
by Daniil
03:57
created

BaseAstraRecon.setup()   B

Complexity

Conditions 7

Size

Total Lines 36
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 28
nop 1
dl 0
loc 36
rs 7.808
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_astra_recon
17
   :platform: Unix
18
   :synopsis: A base for all Astra toolbox reconstruction algorithms
19
.. moduleauthor:: Mark Basham <[email protected]>
20
"""
21
22
import astra
23
import numpy as np
24
25
from savu.plugins.reconstructions.base_recon import BaseRecon
26
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
27
    check_if_end_plugin_in_iterate_group
28
29
class BaseAstraRecon(BaseRecon):
30
31
    def __init__(self, name='BaseAstraRecon'):
        """Create the plugin base, registering it under ``name``.

        :param name: plugin name handed to the parent constructor.
        """
        super(BaseAstraRecon, self).__init__(name)
        # Residual-norm output stays off until setup() switches it on.
        self.res = False
34
35
    # total number of output datasets
36
    def nOutput_datasets(self):
        """Return the total number of output datasets.

        :returns: 2 when ``check_if_end_plugin_in_iterate_group`` is truthy
            for this plugin's experiment, otherwise 1.
        """
        at_loop_end = check_if_end_plugin_in_iterate_group(self.exp)
        return 2 if at_loop_end else 1
41
42
    # total number of output datasets that are clones
43
    def nClone_datasets(self):
        """Return how many of the output datasets are clones.

        :returns: 1 when ``check_if_end_plugin_in_iterate_group`` is truthy
            for this plugin's experiment, otherwise 0.
        """
        at_loop_end = check_if_end_plugin_in_iterate_group(self.exp)
        return 1 if at_loop_end else 0
48
49
    @enable_iterative_loop
    def setup(self):
        """Run the parent setup and, if needed, add the res_norm dataset.

        Stores the chosen algorithm name, selects the max-frames getter
        ('multiple' for algorithms whose name contains '3D', 'single'
        otherwise) and delegates to the parent ``setup``.  When there are
        exactly two output datasets and none of them is a clone, the second
        one is created as the residual-norm dataset and ``self.res`` is set.

        :raises ValueError: if there are three output datasets and one of
            them is a clone (res_norm is not implemented for that case).
        """
        self.alg = self.parameters['algorithm']
        if '3D' in self.alg:
            self.get_max_frames = self._get_multiple
        else:
            self.get_max_frames = self._get_single

        super(BaseAstraRecon, self).setup()
        out_dataset = self.get_out_datasets()
        n_out = len(out_dataset)

        if n_out == 3 and self.nClone_datasets() == 1:
            raise ValueError(
                "The res_norm output dataset has not yet been "
                "implemented for when AstraReconCpu is at the end of an "
                "iterative loop")

        # Anything other than "two outputs, no clones" needs no extra work.
        if not (n_out == 2 and self.nClone_datasets() == 0):
            return

        self.res = True
        plugin_out = self.get_plugin_out_datasets()
        source = self.get_in_datasets()[0]
        axis = source.get_data_dimension_by_axis_label('y', contains=True)

        # n_iterations may be a single value or a list of values; size the
        # residual axis for the largest one.
        iterations = self.parameters['n_iterations']
        if not isinstance(iterations, list):
            iterations = [iterations]
        self.len_res = max(iterations)
        res_shape = (source.get_shape()[axis], self.len_res)

        pattern_name = 'SINOGRAM'
        res_data = out_dataset[1]
        res_data.create_dataset(
            axis_labels=['vol_y.voxel', 'iteration.number'],
            shape=res_shape)
        res_data.add_pattern(pattern_name, slice_dims=(0,), core_dims=(1,))
        plugin_out[1].plugin_data_setup(pattern_name, self.get_max_frames())
85
86
    def pre_process(self):
        """Cache run parameters and select the 2D or 3D code path.

        Stores the algorithm name and iteration count on the instance, runs
        the matching setup routine and binds ``process_frames`` to the
        matching reconstruction method.  An algorithm name containing '3D'
        selects the 3D path; everything else is treated as 2D.
        """
        params = self.parameters
        self.alg = params['algorithm']
        self.iters = params['n_iterations']

        if '3D' not in self.alg:
            self.setup_2D()
            self.process_frames = self.astra_2D_recon
            return

        self.setup_3D()
        self.process_frames = self.astra_3D_recon
96
97 View Code Duplication
    def setup_2D(self):
        """Record the sinogram layout of the first plugin input dataset.

        Looks up the dimensions whose axis labels contain 'x' and 'rot',
        stores the dataset shape, its number of dimensions and its extent
        along the 'x' dimension, then builds the mask via ``set_mask``.
        """
        in_pdata = self.get_plugin_in_datasets()[0]
        find_dim = in_pdata.get_data_dimension_by_axis_label
        self.dim_detX = find_dim('x', contains=True)
        self.dim_rot = find_dim('rot', contains=True)

        shape = in_pdata.get_shape()
        self.sino_shape = shape
        self.nDims = len(shape)
        self.nCols = shape[self.dim_detX]
        self.set_mask(shape)
108
109
    def set_mask(self, shape):
        """Build the circular mask applied to 2D reconstructions.

        The mask is square, with a side equal to the first dimension of the
        first plugin output dataset.  Pixels strictly inside the circle are
        1.0; pixels outside take the "outer" value.  The 'ratio' parameter is
        either a single number (circle diameter as a fraction of the side,
        outer value NaN) or a list/tuple ``(ratio, outer_value)``; a string
        outer value is treated as NaN.

        When both the 'outer_pad' parameter and ``self.padding_alg`` are
        truthy no mask is built and ``self.manual_mask`` is set to ``False``.

        :param shape: not used by this method; kept so existing callers
            continue to work.
        """
        size = self.get_plugin_out_datasets()[0].get_shape()[0]
        coords = np.linspace(-size / 2.0, size / 2.0, size)
        x, y = np.meshgrid(coords, coords)

        ratio = self.parameters['ratio']
        if isinstance(ratio, (list, tuple)):
            ratio_mask, outer_mask = ratio[0], ratio[1]
            if isinstance(outer_mask, str):
                outer_mask = np.nan
        else:
            ratio_mask, outer_mask = ratio, np.nan
        radius = (size - 1) * ratio_mask / 2.0

        if self.parameters['outer_pad'] and self.padding_alg:
            self.manual_mask = False
            return

        # Builtin float replaces the np.float alias, which NumPy removed in
        # release 1.24 (it was only ever an alias for the builtin).
        mask = np.array(x**2 + y**2 < radius**2, dtype=float)
        mask[mask == 0] = outer_mask
        self.manual_mask = mask
132
133
    def astra_2D_recon(self, data):
        """Reconstruct one frame with the configured 2D ASTRA algorithm.

        :param data: sequence whose first element is the sinogram array.
        :returns: the reconstruction (multiplied by ``self.manual_mask`` when
            a mask is set); when ``self.res`` is true, a two-element list
            ``[reconstruction, residual_norms]`` instead.
        """
        sino = data[0]
        _cor, angles, vol_shape, init = self.get_frame_params()
        # Frame angles are converted from degrees to radians for ASTRA.
        angles = np.deg2rad(angles)
        # Always bound, so the return statement never sees an undefined name.
        res = np.zeros(self.len_res) if self.res else None

        # Volume and projection geometries.
        vol_geom = astra.create_vol_geom(vol_shape)
        det_width = sino.shape[self.dim_detX]
        proj_geom = astra.create_proj_geom('parallel', 1.0, det_width, angles)
        sino = np.transpose(sino, (self.dim_rot, self.dim_detX))

        # ASTRA data objects: the sinogram and the (optionally seeded) volume.
        sino_id = astra.data2d.create("-sino", proj_geom, sino)
        if init is not None:
            rec_id = astra.data2d.create('-vol', vol_geom, init)
        else:
            rec_id = astra.data2d.create('-vol', vol_geom)

        cfg = self.set_config(rec_id, sino_id, proj_geom, vol_geom)
        alg_id = astra.algorithm.create(cfg)

        if self.res:
            # Step one iteration at a time to record each residual norm.
            for j in range(self.iters):
                astra.algorithm.run(alg_id, 1)
                res[j] = astra.algorithm.get_res_norm(alg_id)
        else:
            astra.algorithm.run(alg_id, self.iters)

        recon = astra.data2d.get(rec_id)
        if self.manual_mask is not False:
            recon = self.manual_mask * recon

        # set_config may have created a projector and stored its id in the
        # config; hand it to delete() so it is released with the other
        # ASTRA objects instead of leaking on every frame.
        self.delete(alg_id, sino_id, rec_id, cfg.get('ProjectorId', False))
        return [recon, res] if self.res else recon
0 ignored issues
show
introduced by
The variable res does not seem to be defined in case self.res on line 137 is False. Are you sure this can never be the case?
Loading history...
178
179 View Code Duplication
    def set_config(self, rec_id, sino_id, proj_geom, vol_geom):
        """Assemble the ASTRA algorithm configuration for one frame.

        :param rec_id: ASTRA id of the reconstruction volume.
        :param sino_id: ASTRA id of the sinogram data.
        :param proj_geom: projection geometry, used if a projector is made.
        :param vol_geom: volume geometry, used if a projector is made.
        :returns: the configuration after it has been passed through
            ``self.set_options``.
        """
        params = self.parameters
        cfg = astra.astra_dict(self.alg)
        cfg['ReconstructionDataId'] = rec_id
        cfg['ProjectionDataId'] = sino_id

        if 'FBP' in self.alg:
            # Fall back to 'none' when no FBP filter has been configured.
            if 'FBP_filter' in params:
                cfg['FilterType'] = params['FBP_filter']
            else:
                cfg['FilterType'] = 'none'

        if 'projector' in params:
            cfg['ProjectorId'] = astra.create_projector(
                params['projector'], proj_geom, vol_geom)

        return self.set_options(cfg)
193
194
    def delete(self, alg_id, sino_id, rec_id, proj_id):
        """Release the ASTRA objects created for one reconstruction.

        :param alg_id: id of the algorithm object.
        :param sino_id: id of the sinogram data object.
        :param rec_id: id of the reconstruction data object.
        :param proj_id: id of the projector; a falsy value means there is
            no projector to delete.
        """
        astra.algorithm.delete(alg_id)
        for data_id in (sino_id, rec_id):
            astra.data2d.delete(data_id)
        if not proj_id:
            return
        astra.projector.delete(proj_id)
200
201
    def _get_single(self):
202
        return 'single'
203
204
    def _get_multiple(self):
205
        return 'multiple'
206