Test Failed
Push — master ( 9ba79c...17f3e3 )
by Yousef
01:54 queued 19s
created

BaseAstraVectorRecon.setup()   B

Complexity

Conditions 6

Size

Total Lines 40
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 25
nop 1
dl 0
loc 40
rs 8.3466
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_astra_vector_recon
17
   :platform: Unix
18
   :synopsis: A base for Astra toolbox reconstruction algorithms using vector geometry
19
.. moduleauthor:: Mark Basham <[email protected]>
20
"""
21
22
import astra
23
import numpy as np
24
25
from savu.plugins.reconstructions.base_recon import BaseRecon
26
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
27
    check_if_end_plugin_in_iterate_group
28
29
class BaseAstraVectorRecon(BaseRecon):
30
    """
31
    A Plugin to perform Astra toolbox reconstruction using vector geometry
32
33
    :u*param n_iterations: Number of Iterations - only valid for iterative \
34
        algorithms. Default: 1.
35
    """
36
37
    def __init__(self, name='BaseAstraVectorRecon'):
        """
        Construct the plugin by delegating to the parent constructor.

        ``name`` is forwarded unchanged to the base class ``__init__``.
        """
        super(BaseAstraVectorRecon, self).__init__(name)
        # Switched to True by setup() when the 'res_norm' parameter is set
        # and the second output dataset gets created.
        self.res = False
40
41
    # total number of output datasets that are clones
    def nClone_datasets(self):
        """
        Return 1 if ``check_if_end_plugin_in_iterate_group(self.exp)`` is
        truthy, otherwise 0.
        """
        ends_iterate_group = check_if_end_plugin_in_iterate_group(self.exp)
        return 1 if ends_iterate_group else 0
47
48
    @enable_iterative_loop
    def setup(self):
        """
        Record the algorithm, run the parent setup and, when the 'res_norm'
        parameter is set, create and configure the second output dataset.

        Raises ValueError if 'res_norm' is set while nClone_datasets()
        returns 1, because that combination is not implemented.
        """
        self.alg = self.parameters['algorithm']
        # '3D' algorithms use the 'multiple' frame getter, all others the
        # 'single' one.
        if '3D' in self.alg:
            self.get_max_frames = self._get_multiple
        else:
            self.get_max_frames = self._get_single

        super(BaseAstraVectorRecon, self).setup()
        out_dataset = self.get_out_datasets()

        # Nothing more to configure unless res_norm output was requested.
        if not self.parameters['res_norm']:
            return

        if self.nClone_datasets() == 1:
            raise ValueError(
                "The res_norm output dataset has not yet been "
                "implemented for when AstraReconGpu is at the end of an "
                "iterative loop")

        self.res = True
        out_pData = self.get_plugin_out_datasets()
        in_data = self.get_in_datasets()[0]
        # NOTE(review): the local is called dim_detX but it is looked up
        # with the 'y' axis label - confirm this is intended.
        dim_detX = \
            in_data.get_data_dimension_by_axis_label('y', contains=True)

        # n_iterations may be a single value or a list of values; the
        # second axis of the dataset is sized by the largest one.
        n_its = self.parameters['n_iterations']
        if not isinstance(n_its, list):
            n_its = [n_its]
        max_its = max(n_its)
        self.len_res = max_its
        res_shape = (in_data.get_shape()[dim_detX], max_its)

        # NOTE: a 'SINOGRAM' pattern (slice_dims=(0,), core_dims=(1,)) was
        # previously sketched here but is disabled; only 'METADATA' is
        # registered on this dataset.
        res_dataset = out_dataset[1]
        res_dataset.create_dataset(
            axis_labels=['vol_y.voxel', 'iteration.number'],
            shape=res_shape)
        res_dataset.add_pattern(
            "METADATA", core_dims=(1,), slice_dims=(0,))
        out_pData[1].plugin_data_setup('METADATA', self.get_max_frames())
88
89
    def pre_process(self):
        """
        Cache the algorithm name and iteration count, then run the 2D or 3D
        setup and bind ``process_frames`` to the matching reconstruction
        method.
        """
        params = self.parameters
        self.alg = params['algorithm']
        self.iters = params['n_iterations']

        # Anything without '3D' in its name takes the 2D path.
        if '3D' not in self.alg:
            self.setup_2D()
            self.process_frames = self.astra_2D_vector_recon
            return

        self.setup_3D()
        self.process_frames = self.astra_3D_vector_recon
99
100 View Code Duplication
    def setup_2D(self):
        """
        Store the input data's 'x' and 'rot' dimension indices, its shape,
        number of dimensions and column count, then build the mask via
        set_mask().
        """
        in_pdata = self.get_plugin_in_datasets()[0]
        find_dim = in_pdata.get_data_dimension_by_axis_label
        self.dim_detX = find_dim('x', contains=True)
        self.dim_rot = find_dim('rot', contains=True)

        sino_shape = in_pdata.get_shape()
        self.sino_shape = sino_shape
        self.nDims = len(sino_shape)
        # Column count is the extent along the 'x' dimension.
        self.nCols = sino_shape[self.dim_detX]
        self.set_mask(sino_shape)
111
112
    def setup_3D(self):
        """
        Store the input data's 'x', 'y' and 'angle' dimension indices, its
        shape, number of dimensions and slice dimension.
        """
        in_pdata = self.get_plugin_in_datasets()[0]
        find_dim = in_pdata.get_data_dimension_by_axis_label
        self.sino_dim_detX = find_dim('x', contains=True)
        self.sino_dim_detY = find_dim('y', contains=True)
        self.det_rot = find_dim('angle', contains=True)

        self.sino_shape = in_pdata.get_shape()
        self.nDims = len(self.sino_shape)
        self.slice_dir = in_pdata.get_slice_dimension()
        # NOTE: unlike setup_2D(), no nCols value or mask is set up here;
        # an earlier circular-mask construction for this path is disabled.
135
136
    def set_mask(self, shape):
        """
        Build the circular mask and store it in ``self.manual_mask``.

        The mask is a square float array whose side is the first entry of
        the first plugin output dataset's shape. Points strictly inside the
        circle of diameter ``(side - 1) * ratio`` are 1.0; all other points
        take the outer value. ``self.manual_mask`` is set to False instead
        when both the 'outer_pad' parameter and ``self.padding_alg`` are
        truthy.

        The 'ratio' parameter is either a single number or a list/tuple
        ``(ratio, outer_value)``. When no outer value is given, or it is a
        string, the outer value is 1.0 if 'outer_pad' is True, else 0.0.

        ``shape`` is accepted for interface compatibility but is not used.
        """
        side = self.get_plugin_out_datasets()[0].get_shape()[0]
        coords = np.linspace(-side / 2.0, side / 2.0, side)
        x, y = np.meshgrid(coords, coords)

        # Fallback outer value, driven by the 'outer_pad' flag.
        default_outer = 1.0 if self.parameters['outer_pad'] is True else 0.0

        ratio = self.parameters['ratio']
        if isinstance(ratio, (list, tuple)):
            ratio_mask, outer_mask = ratio[0], ratio[1]
            if isinstance(outer_mask, str):
                outer_mask = default_outer
        else:
            ratio_mask, outer_mask = ratio, default_outer

        diameter = (side - 1) * ratio_mask

        # With outer padding on and a padding algorithm in use, no manual
        # mask is built.
        if self.parameters['outer_pad'] and self.padding_alg:
            self.manual_mask = False
            return

        # Builtin float replaces the np.float alias, which was removed in
        # NumPy 1.24; both give a float64 array.
        mask = np.array(
            (x**2 + y**2 < (diameter / 2.0)**2), dtype=float)
        mask[mask == 0] = outer_mask
        self.manual_mask = mask
165
166 View Code Duplication
    def set_config(self, rec_id, sino_id, proj_geom, vol_geom):
        """
        Assemble the astra configuration dictionary for ``self.alg``.

        ``rec_id`` and ``sino_id`` are stored as the reconstruction and
        projection data ids. For algorithms with 'FBP' in their name the
        filter type comes from the 'FBP_filter' parameter, or 'none' when
        it is absent. If a 'projector' parameter exists, a projector is
        created from it with ``proj_geom`` and ``vol_geom`` and its id is
        stored in the configuration. The dictionary is passed through
        ``self.set_options`` and that result is returned.
        """
        params = self.parameters
        cfg = astra.astra_dict(self.alg)
        cfg['ReconstructionDataId'] = rec_id
        cfg['ProjectionDataId'] = sino_id

        if 'FBP' in self.alg:
            if 'FBP_filter' in params.keys():
                cfg['FilterType'] = params['FBP_filter']
            else:
                cfg['FilterType'] = 'none'

        if 'projector' in params.keys():
            # NOTE(review): the projector id is only kept inside cfg;
            # callers presumably read it back from there for delete().
            cfg['ProjectorId'] = astra.create_projector(
                params['projector'], proj_geom, vol_geom)

        return self.set_options(cfg)
180
181
    def delete(self, alg_id, sino_id, rec_id, proj_id):
        """
        Delete the astra algorithm, then the sinogram and reconstruction
        2D data objects, and finally the projector when ``proj_id`` is
        truthy.
        """
        astra.algorithm.delete(alg_id)
        for data_id in (sino_id, rec_id):
            astra.data2d.delete(data_id)
        if not proj_id:
            return
        astra.projector.delete(proj_id)
187
188
    def _get_single(self):
189
        return 'single'
190
191
    def _get_multiple(self):
192
        return 'multiple'
193