Test Failed
Pull Request — master (#845)
by Daniil
03:47
created

savu.plugins.loaders.base_tomophantom_loader   A

Complexity

Total Complexity 24

Size/Duplication

Total Lines 194
Duplicated Lines 7.73 %

Importance

Changes 0
Metric Value
eloc 134
dl 15
loc 194
rs 10
c 0
b 0
f 0
wmc 24

8 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseTomophantomLoader.setup() 0 35 1
C BaseTomophantomLoader.__get_backing_file() 0 65 9
A BaseTomophantomLoader._get_n_entries() 0 2 1
A BaseTomophantomLoader.__get_start_slice_list() 15 15 5
A BaseTomophantomLoader.__convert_patterns() 0 15 3
A BaseTomophantomLoader.__parameter_checks() 0 4 2
A BaseTomophantomLoader._set_metadata() 0 8 2
A BaseTomophantomLoader.__init__() 0 2 1

How to fix   Duplicated Code   

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

Common duplication problems and their corresponding solutions are:

1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_tomophantom_loader
17
   :platform: Unix
18
   :synopsis: A loader that generates synthetic 3D projection full-field tomo data\
19
        as an hdf5 dataset of any size.
20
21
.. moduleauthor:: Daniil Kazantsev <[email protected]>
22
"""
23
24
import os
25
import h5py
26
import logging
27
import numpy as np
28
29
from savu.data.chunking import Chunking
30
from savu.plugins.utils import register_plugin
31
from savu.plugins.loaders.base_loader import BaseLoader
32
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
33
34
import tomophantom
35
from tomophantom import TomoP2D, TomoP3D
36
import os
37
import numpy as np
38
39
@register_plugin
40
class BaseTomophantomLoader(BaseLoader):
41
    def __init__(self, name='BaseTomophantomLoader'):
        """Initialise the loader under the given plugin ``name``."""
        super().__init__(name)
43
44
    def setup(self):
        """Create the synthetic projection dataset and the matching phantom.

        Two 'in_data' objects are created on the experiment, named
        'synth_proj_data' and 'phantom'.  Each is given axis labels and
        patterns from the plugin parameters and is then bound to the 'test'
        dataset of its hdf5 backing file.  The projection angles and the
        centre-of-rotation array are computed here and handed to
        ``_set_metadata`` for the projection data object only.

        :returns: the projection data object and the phantom data object.
        """
        experiment = self.exp

        def declare(name, labels_key):
            # Create a data object and describe its axes and patterns.
            obj = experiment.create_data_object('in_data', name)
            obj.set_axis_labels(*self.parameters[labels_key])
            self.__convert_patterns(obj, name)
            self.__parameter_checks(obj)
            return obj

        def attach(obj, name):
            # Bind the object to the 'test' dataset of its backing file.
            obj.backing_file = self.__get_backing_file(obj, name)
            obj.data = obj.backing_file['/']['test']
            # Need to do something to .data to keep the file open!
            obj.data.dtype

        projections = declare('synth_proj_data', 'axis_labels')

        # These attributes are read while the backing file is generated, so
        # they must be in place before attach() is called.
        self.tomo_model = self.parameters['tomo_model']
        # setting angles for parallel beam geometry
        n_angles = self.parameters['proj_data_dims'][0]
        self.angles = np.linspace(
            0.0, 180.0 - (1e-14), n_angles, dtype='float32')
        package_dir = os.path.dirname(tomophantom.__file__)
        self.path_library3D = os.path.join(package_dir, "Phantom3DLibrary.dat")

        attach(projections, 'synth_proj_data')

        # create a phantom file
        phantom = declare('phantom', 'axis_labels_phantom')
        attach(phantom, 'phantom')

        projections.set_shape(projections.data.shape)
        self.n_entries = projections.get_shape()[0]

        # One identical centre-of-rotation value per entry along dimension 1.
        dims = self.parameters['proj_data_dims']
        cor_val = 0.5 * (dims[2])
        self.cor = np.linspace(cor_val, cor_val, dims[1], dtype='float32')

        self._set_metadata(projections, self._get_n_entries())
        return projections, phantom
79
80
    def __get_backing_file(self, data_obj, file_name):
        """Return a read-only hdf5 file holding the generated dataset.

        If ``<out_path>/<file_name>.h5`` already exists it is opened
        read-only and returned as it is.  Otherwise the file is created with
        a chunked 'test' dataset, this process's share of frames is filled
        with TomoPhantom output, and the file is reopened read-only.

        :param data_obj: data object whose first pattern drives the chunking
            and the slicing of the dataset.  That pattern's dictionary gets
            'max_frames_transfer' set to 1 in place.
        :param file_name: 'synth_proj_data' generates projection data; any
            other value generates phantom data.  Only 'phantom' changes the
            dataset dimensions.
        :returns: the open hdf5 file.
        """
        fname = '%s/%s.h5' % (self.exp.get('out_path'), file_name)

        # A file that is already present is reused rather than regenerated.
        if os.path.exists(fname):
            return h5py.File(fname, 'r')

        self.hdf5 = Hdf5Utils(self.exp)

        dims = list(self.parameters['proj_data_dims'])
        if file_name == 'phantom':
            # The phantom takes the length of dimension 1 on every axis.
            dims[0] = dims[1]
            dims[2] = dims[1]
        proj_data_dims = tuple(dims)

        patterns = data_obj.get_data_patterns()
        p_name = list(patterns.keys())[0]
        p_dict = patterns[p_name]
        p_dict['max_frames_transfer'] = 1
        nnext = {p_name: p_dict}

        pattern_idx = {'current': nnext, 'next': nnext}
        chunking = Chunking(self.exp, pattern_idx)
        # NOTE(review): chunking is computed for np.int16 while the dataset
        # below is created without an explicit dtype - confirm these agree.
        chunks = chunking._calculate_chunking(proj_data_dims, np.int16)

        h5file = self.hdf5._open_backing_h5(fname, 'w')
        dset = h5file.create_dataset('test', proj_data_dims, chunks=chunks)

        # need an mpi barrier after creating the file before populating it
        self.exp._barrier()

        slice_dirs = p_dict['slice_dims']
        total_frames = np.prod([dset.shape[i] for i in slice_dirs])
        sl, n_local_frames = self.__get_start_slice_list(
            slice_dirs, dset.shape, total_frames)

        # The slice list starts at this process's first frame, so the data
        # has to be generated for that same frame and not for the bare loop
        # counter; otherwise every process with a non-zero start would write
        # frames 0..n into its own part of the dataset.
        first_frame = sl[slice_dirs[0]].start

        idx = 0
        for i in range(n_local_frames):
            gen_data = self.__generate_frame(
                file_name, proj_data_dims, first_frame + i)
            dset[tuple(sl)] = np.swapaxes(gen_data, 0, 1)
            if sl[slice_dirs[idx]].stop == dset.shape[slice_dirs[idx]]:
                # End of this slice dimension: move on to the next one.
                idx += 1
                if idx == len(slice_dirs):
                    break
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        self.exp._barrier()

        try:
            h5file.close()
        except IOError as exc:
            logging.debug('There was a problem trying to close the file in '
                          'base_tomophantom_loader: %s', exc)

        return self.hdf5._open_backing_h5(fname, 'r')

    def __generate_frame(self, file_name, dims, frame):
        """Generate the single frame ``frame`` for the requested dataset."""
        sub_index = (frame, frame + 1)
        if file_name == 'synth_proj_data':
            # generate projection data
            return TomoP3D.ModelSinoSub(
                self.tomo_model, dims[1], dims[2], dims[1], sub_index,
                -self.angles, self.path_library3D)
        # generate phantom data
        return TomoP3D.ModelSub(
            self.tomo_model, dims[1], sub_index, self.path_library3D)
145
146 View Code Duplication
    def __get_start_slice_list(self, slice_dirs, shape, n_frames):
        """Return the starting slice list and frame count for this process.

        The ``n_frames`` frame indices are divided between the processes
        with ``np.array_split``; the slice list is then stepped forward once
        for every frame that precedes this process's share.

        :param slice_dirs: dimensions that are stepped through one frame at
            a time.
        :param shape: shape of the dataset being sliced.
        :param n_frames: total number of frames over all processes.
        :returns: ``(sl, n)`` where ``sl`` is a list holding one slice per
            dimension of ``shape`` and ``n`` is the number of frames given
            to this process.
        """
        n_processes = len(self.exp.get('processes'))
        rank = self.exp.get('process')
        frames = np.array_split(np.arange(n_frames), n_processes)[rank]
        # A process that was given no frames has nothing to skip.
        n_skipped = frames[0] if len(frames) else 0

        sl = [slice(0, 1) if i in slice_dirs else slice(None)
              for i in range(len(shape))]

        idx = 0
        for _ in range(n_skipped):
            # Compare the slice's stop with the dimension length.  Comparing
            # the slice object itself with an int is always False, so the
            # move to the next slice dimension could never happen.
            if sl[slice_dirs[idx]].stop == shape[slice_dirs[idx]]:
                idx += 1
                if idx == len(slice_dirs):
                    break
            current = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(current.start + 1, current.stop + 1)

        return sl, len(frames)
161
162
    def __convert_patterns(self, data_obj, object_type):
        """Add the patterns named in the plugin parameters to ``data_obj``.

        The pattern strings come from the 'patterns' parameter when
        ``object_type`` is 'synth_proj_data' and from 'patterns_tomo'
        otherwise.  Each string is split on '.': the first field is the
        pattern name and the remaining fields describe dimensions, e.g.
        ``'NAME.0c.1s.2c'``.  A field that splits into exactly two parts on
        'c' contributes a core dimension, and likewise on 's' for a slice
        dimension.

        :param data_obj: data object that receives the patterns.
        :param object_type: selects which parameter holds the pattern list.
        """
        if object_type == 'synth_proj_data':
            key = 'patterns'
        else:
            key = 'patterns_tomo'

        def tagged(fields, tag):
            # Integer prefix of every field that splits in two on ``tag``.
            pieces = [field.split(tag) for field in fields]
            return tuple(int(parts[0]) for parts in pieces if len(parts) == 2)

        for entry in self.parameters[key]:
            name, *fields = entry.split('.')
            data_obj.add_pattern(
                name,
                core_dims=tagged(fields, 'c'),
                slice_dims=tagged(fields, 's'))
177
178
    def _set_metadata(self, data_obj, n_entries):
        """Store the rotation angles and centre of rotation on ``data_obj``.

        :param data_obj: data object whose ``meta_data`` is updated.
        :param n_entries: data dimension length that has to equal the number
            of angles in ``self.angles``.
        :raises Exception: if the two lengths differ.
        """
        n_angles = len(self.angles)
        if n_entries != n_angles:
            # Interpolate the values here: Exception does not apply
            # %-formatting to its arguments the way logging calls do.
            raise Exception(
                "The number of angles %s does not match the data "
                "dimension length %s" % (n_angles, n_entries))
        data_obj.meta_data.set("rotation_angle", self.angles)
        data_obj.meta_data.set("centre_of_rotation", self.cor)
186
187
    def __parameter_checks(self, data_obj):
        """Check that the 'proj_data_dims' parameter has been given a value.

        :param data_obj: not used by this check.
        :raises Exception: if 'proj_data_dims' is empty or otherwise falsy.
        """
        requested_dims = self.parameters['proj_data_dims']
        if requested_dims:
            return
        raise Exception(
                'Please specifiy the dimensions of the dataset to create.')
191
192
    def _get_n_entries(self):
        """Return the value held in ``self.n_entries``."""
        return self.n_entries
194