Test Failed
Push — master (1e6f38...800e06) by Daniil
Duration: 01:18 (queued 17s)

savu.plugins.loaders.base_tomophantom_loader — rating: F

Complexity
Total Complexity: 66

Size/Duplication
Total Lines: 417
Duplicated Lines: 17.75%

Importance
Changes: 0
Metric                              Value
eloc (effective lines of code)      283
dl (duplicated lines)               74
loc (lines of code)                 417
rs                                  3.12
c                                   0
b                                   0
f                                   0
wmc (weighted methods per class)    66

19 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseTomophantomLoader.__init__() 0 2 1
A BaseTomophantomLoader.setup() 0 39 1
A BaseTomophantomLoader.post_process() 0 17 2
B BaseTomophantomLoader.__output_data() 19 19 6
A BaseTomophantomLoader._link_datafile_to_nexus_file() 12 12 2
A BaseTomophantomLoader._output_metadata_dict() 0 10 4
C BaseTomophantomLoader.__get_backing_file() 0 81 10
A BaseTomophantomLoader._get_n_entries() 0 2 1
A BaseTomophantomLoader.__create_dataset() 0 5 2
B BaseTomophantomLoader._populate_nexus_file() 0 38 6
A BaseTomophantomLoader.__add_nxs_data() 18 18 5
A BaseTomophantomLoader.__get_start_slice_list() 15 15 5
B BaseTomophantomLoader.__output_axis_labels() 0 31 7
A BaseTomophantomLoader.__add_nxs_entry() 0 9 3
A BaseTomophantomLoader.__convert_patterns() 0 15 3
A BaseTomophantomLoader.__parameter_checks() 0 4 2
A BaseTomophantomLoader._set_metadata() 0 8 2
A BaseTomophantomLoader._link_nexus_file() 0 15 2
A BaseTomophantomLoader.__output_data_patterns() 10 10 2

How to fix

Duplicated Code

Duplicated code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.

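In this class, the duplication flagged in the table above includes `_populate_nexus_file()` and `_link_datafile_to_nexus_file()`, which both open the session nexus file with identical mpio arguments. A minimal sketch of one way to factor that out, assuming an MPI-enabled h5py build (which the original code already requires); the helper name `open_nxs_mpio` is invented for illustration:

import h5py
from mpi4py import MPI

def open_nxs_mpio(filename, mode='a'):
    # open the nexus file collectively across all MPI ranks, exactly as
    # both duplicated call sites currently do inline
    return h5py.File(filename, mode, driver="mpio", comm=MPI.COMM_WORLD)

Both methods would then reduce to `with open_nxs_mpio(self.exp.meta_data.get('nxs_filename')) as nxs_file:`, leaving a single place to change the driver or communicator later.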

Complexity

Tip: Before tackling complexity, make sure to eliminate any duplication first; this alone can often reduce class sizes significantly.

Complex classes like savu.plugins.loaders.base_tomophantom_loader often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields and methods that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
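Here, the `__output_*`, `_link_*` and `__add_nxs_*` methods form exactly such a component: together they own all nexus-file writing. A minimal sketch of the Extract Class step under that assumption; the class name `NexusWriter` is invented for illustration, and the `...` bodies stand for the existing method bodies moved across unchanged:

class NexusWriter:
    """Cohesive component: everything that writes or links the nexus file."""

    def __init__(self, exp):
        self.exp = exp  # the shared experiment object the methods already use

    def populate_nexus_file(self, data):
        ...  # body moved verbatim from BaseTomophantomLoader

    def link_datafile_to_nexus_file(self, data):
        ...  # body moved verbatim, along with its private helpers

The loader then delegates instead of writing the file itself: `post_process()` would construct a `NexusWriter(self.exp)` and call `populate_nexus_file()` / `link_datafile_to_nexus_file()` on it, shrinking the loader down to setup and data generation.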

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_tomophantom_loader
   :platform: Unix
   :synopsis: A loader that generates synthetic 3D projection full-field tomo
        data as an hdf5 dataset of any size.

.. moduleauthor:: Daniil Kazantsev <[email protected]>
"""

import os
import h5py
import logging
import numpy as np
from mpi4py import MPI

from savu.data.chunking import Chunking
from savu.plugins.utils import register_plugin
from savu.plugins.loaders.base_loader import BaseLoader
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils

from savu.data.plugin_list import PluginList

import tomophantom
from tomophantom import TomoP2D, TomoP3D
@register_plugin
class BaseTomophantomLoader(BaseLoader):
    def __init__(self, name='BaseTomophantomLoader'):
        super(BaseTomophantomLoader, self).__init__(name)

    def setup(self):
        exp = self.exp
        data_obj = exp.create_data_object('in_data', 'synth_proj_data')

        data_obj.set_axis_labels(*self.parameters['axis_labels'])
        self.__convert_patterns(data_obj, 'synth_proj_data')
        self.__parameter_checks(data_obj)

        self.tomo_model = self.parameters['tomo_model']
        # setting angles for parallel-beam geometry
        self.angles = np.linspace(0.0, 180.0 - (1e-14),
                                  self.parameters['proj_data_dims'][0],
                                  dtype='float32')
        path = os.path.dirname(tomophantom.__file__)
        self.path_library3D = os.path.join(path, "Phantom3DLibrary.dat")

        data_obj.backing_file = self.__get_backing_file(data_obj,
                                                        'synth_proj_data')
        data_obj.data = \
            data_obj.backing_file['/']['entry1']['tomo_entry']['data']['data']
        #data_obj.data.dtype # Need to do something to .data to keep the file open!

        # create a phantom file
        data_obj2 = exp.create_data_object('in_data', 'phantom')
        data_obj2.set_axis_labels(*['voxel_x.voxel', 'voxel_y.voxel',
                                    'voxel_z.voxel'])
        self.__convert_patterns(data_obj2, 'phantom')
        self.__parameter_checks(data_obj2)

        data_obj2.backing_file = self.__get_backing_file(data_obj2, 'phantom')
        data_obj2.data = data_obj2.backing_file['/']['phantom']['data']
        #data_obj2.data.dtype # Need to do something to .data to keep the file open!
        data_obj.set_shape(data_obj.data.shape)
        group_name = '1-TomoPhantomLoader-phantom'

        self.n_entries = data_obj.get_shape()[0]
        # fix the centre of rotation at the middle of the detector row
        cor_val = 0.5 * (self.parameters['proj_data_dims'][2])
        self.cor = np.linspace(cor_val, cor_val,
                               self.parameters['proj_data_dims'][1],
                               dtype='float32')
        self._set_metadata(data_obj, self._get_n_entries())

        return data_obj, data_obj2
    def __get_backing_file(self, data_obj, file_name):
        fname = '%s/%s.h5' % (self.exp.get('out_path'), file_name)

        if os.path.exists(fname):
            return h5py.File(fname, 'r')

        self.hdf5 = Hdf5Utils(self.exp)

        dims_temp = self.parameters['proj_data_dims'].copy()
        proj_data_dims = tuple(dims_temp)
        if file_name == 'phantom':
            # the phantom volume is a cube with the detector-row length as side
            dims_temp[0] = dims_temp[1]
            dims_temp[2] = dims_temp[1]
            proj_data_dims = tuple(dims_temp)

        patterns = data_obj.get_data_patterns()
        p_name = list(patterns.keys())[0]
        p_dict = patterns[p_name]
        p_dict['max_frames_transfer'] = 1
        nnext = {p_name: p_dict}

        pattern_idx = {'current': nnext, 'next': nnext}
        chunking = Chunking(self.exp, pattern_idx)
        chunks = chunking._calculate_chunking(proj_data_dims, np.int16)

        h5file = self.hdf5._open_backing_h5(fname, 'w')

        if file_name == 'phantom':
            group = h5file.create_group('/phantom', track_order=None)
        else:
            group = h5file.create_group('/entry1/tomo_entry/data',
                                        track_order=None)

        data_obj.dtype = np.dtype('<f4')
        dset = self.hdf5.create_dataset_nofill(group, "data", proj_data_dims,
                                               data_obj.dtype, chunks=chunks)

        # need an mpi barrier after creating the file before populating it
        self.exp._barrier()

        slice_dirs = list(nnext.values())[0]['slice_dims']
        nDims = len(dset.shape)
        total_frames = np.prod([dset.shape[i] for i in slice_dirs])
        sub_size = \
            [1 if i in slice_dirs else dset.shape[i] for i in range(nDims)]

        # calculate the first slice for this process
        idx = 0
        sl, total_frames = \
            self.__get_start_slice_list(slice_dirs, dset.shape, total_frames)
        for i in range(total_frames):
            if sl[slice_dirs[idx]].stop == dset.shape[slice_dirs[idx]]:
                idx += 1
                if idx == len(slice_dirs):
                    break
            tmp = sl[slice_dirs[idx]]
            if file_name == 'synth_proj_data':
                # generate a subset of the projection data
                gen_data = TomoP3D.ModelSinoSub(
                    self.tomo_model, proj_data_dims[1], proj_data_dims[2],
                    proj_data_dims[1], (tmp.start, tmp.start + 1),
                    -self.angles, self.path_library3D)
            else:
                # generate a subset of the phantom data
                gen_data = TomoP3D.ModelSub(
                    self.tomo_model, proj_data_dims[1],
                    (tmp.start, tmp.start + 1), self.path_library3D)
            dset[tuple(sl)] = np.swapaxes(gen_data, 0, 1)
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        self.exp._barrier()

        try:
            #nxsfile = NXdata(h5file)
            #nxsfile.save(file_name + ".nxs")
            h5file.close()
        except IOError:
            logging.debug('There was a problem trying to close the file in '
                          'base_tomophantom_loader')

        return self.hdf5._open_backing_h5(fname, 'r')
    def __get_start_slice_list(self, slice_dirs, shape, n_frames):
        n_processes = len(self.exp.get('processes'))
        rank = self.exp.get('process')
        # split the frames evenly across processes and take this rank's share
        frames = np.array_split(np.arange(n_frames), n_processes)[rank]
        f_range = list(range(0, frames[0])) if len(frames) else []
        sl = [slice(0, 1) if i in slice_dirs else slice(None)
              for i in range(len(shape))]
        idx = 0
        # advance the slice list to this process's first frame
        for i in f_range:
            if sl[slice_dirs[idx]] == shape[slice_dirs[idx]] - 1:
                idx += 1
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        return sl, len(frames)
    def __convert_patterns(self, data_obj, object_type):
        if object_type == 'synth_proj_data':
            pattern_list = self.parameters['patterns']
        else:
            pattern_list = self.parameters['patterns_tomo']
        # each pattern string has the form 'NAME.<dim>c.<dim>s...', where 'c'
        # marks a core dimension and 's' a slice dimension
        for p in pattern_list:
            p_split = p.split('.')
            name = p_split[0]
            dims = p_split[1:]
            core_dims = tuple([int(i[0]) for i in [d.split('c') for d in dims]
                              if len(i) == 2])
            slice_dims = tuple([int(i[0]) for i in [d.split('s') for d in dims]
                               if len(i) == 2])
            data_obj.add_pattern(
                    name, core_dims=core_dims, slice_dims=slice_dims)
    def _set_metadata(self, data_obj, n_entries):
        n_angles = len(self.angles)
        data_angles = n_entries
        if data_angles != n_angles:
            raise Exception("The number of angles %s does not match the data "
                            "dimension length %s" % (n_angles, data_angles))
        data_obj.meta_data.set(['rotation_angle'], self.angles)
        data_obj.meta_data.set(['centre_of_rotation'], self.cor)

    def __parameter_checks(self, data_obj):
        if not self.parameters['proj_data_dims']:
            raise Exception(
                    'Please specify the dimensions of the dataset to create.')

    def _get_n_entries(self):
        return self.n_entries
    def post_process(self, data_obj, data_obj2):
        filename = self.exp.meta_data.get('nxs_filename')
        fsplit = filename.split('/')
        plugin_number = len(self.exp.meta_data.plugin_list.plugin_list)
        if plugin_number == 1:
            fsplit[-1] = 'synthetic_data.nxs'
        else:
            fsplit[-1] = 'synthetic_data_processed.nxs'
        filename = '/'.join(fsplit)
        self.exp.meta_data.set('nxs_filename', filename)

        plugin_list = PluginList()
        #plugin_list._save_plugin_list(filename)
        self.exp._finalise_setup(plugin_list)
        self._link_nexus_file(data_obj2, 'phantom', plugin_list)
        self._link_nexus_file(data_obj, 'synth_proj_data', plugin_list)
    def _link_nexus_file(self, data_obj, name, plugin_list):
        """Link the phantom and synthetic projection data h5 files to a
        single nexus file containing both."""
        if name == 'phantom':
            data_obj.exp.meta_data.set(['group_name', 'phantom'], 'phantom')
            data_obj.exp.meta_data.set(['link_type', 'phantom'], 'final_result')
            data_obj.meta_data.set(["meta_data", "PLACEHOLDER", "VOLUME_XZ"], [10])
        else:
            data_obj.exp.meta_data.set(['group_name', 'synth_proj_data'],
                                       'entry1/tomo_entry/data')
            data_obj.exp.meta_data.set(['link_type', 'synth_proj_data'], 'entry1')

        self._populate_nexus_file(data_obj)
        self._link_datafile_to_nexus_file(data_obj)
    def _populate_nexus_file(self, data):
        """Write the data patterns, metadata dictionary and axis labels for
        this dataset into the nexus file."""
        filename = self.exp.meta_data.get('nxs_filename')
        name = data.data_info.get('name')
        with h5py.File(filename, 'a', driver="mpio",
                       comm=MPI.COMM_WORLD) as nxs_file:
            group_name = self.exp.meta_data.get(['group_name', name])
            link_type = self.exp.meta_data.get(['link_type', name])

            if name == 'phantom':
                if 'entry' not in list(nxs_file.keys()):
                    nxs_entry = nxs_file.create_group('entry')
                else:
                    nxs_entry = nxs_file['entry']
                if link_type == 'final_result':
                    group_name = 'final_result_' + data.get_name()
                else:
                    link = nxs_entry.require_group(link_type.encode("ascii"))
                    link.attrs['NX_class'] = 'NXcollection'
                    nxs_entry = link

                # delete the group if it already exists
                if group_name in nxs_entry:
                    del nxs_entry[group_name]

                plugin_entry = nxs_entry.require_group(group_name)
            else:
                plugin_entry = nxs_file.create_group(f'/{group_name}')

            self.__output_data_patterns(data, plugin_entry)
            self._output_metadata_dict(plugin_entry,
                                       data.meta_data.get_dictionary())
            self.__output_axis_labels(data, plugin_entry)

            plugin_entry.attrs['NX_class'] = 'NXdata'
    def __output_axis_labels(self, data, entry):
        axis_labels = data.data_info.get("axis_labels")
        ddict = data.meta_data.get_dictionary()

        axes = []
        count = 0
        dims_temp = self.parameters['proj_data_dims'].copy()
        if data.data_info.get('name') == 'phantom':
            dims_temp[0] = dims_temp[1]
            dims_temp[2] = dims_temp[1]
        dims = tuple(dims_temp)

        for labels in axis_labels:
            name = list(labels.keys())[0]
            axes.append(name)
            entry.attrs[name + '_indices'] = count

            # use the stored axis values if available, otherwise an index range
            mData = ddict[name] if name in list(ddict.keys()) \
                else np.arange(dims[count])

            if isinstance(mData, list):
                mData = np.array(mData)

            if 'U' in str(mData.dtype):
                mData = mData.astype(np.string_)
            if name not in list(entry.keys()):
                axis_entry = entry.require_dataset(name, mData.shape, mData.dtype)
                axis_entry[...] = mData[...]
                axis_entry.attrs['units'] = list(labels.values())[0]
            count += 1
        entry.attrs['axes'] = axes
    def __output_data_patterns(self, data, entry):
        data_patterns = data.data_info.get("data_patterns")
        entry = entry.require_group('patterns')
        entry.attrs['NX_class'] = 'NXcollection'
        for pattern in data_patterns:
            nx_data = entry.require_group(pattern)
            nx_data.attrs['NX_class'] = 'NXparameters'
            values = data_patterns[pattern]
            self.__output_data(nx_data, values['core_dims'], 'core_dims')
            self.__output_data(nx_data, values['slice_dims'], 'slice_dims')
    def _output_metadata_dict(self, entry, mData):
        entry.attrs['NX_class'] = 'NXcollection'
        for key, value in mData.items():
            if key != 'rotation_angle':
                nx_data = entry.require_group(key)
                if isinstance(value, dict):
                    self._output_metadata_dict(nx_data, value)
                else:
                    nx_data.attrs['NX_class'] = 'NXdata'
                    self.__output_data(nx_data, value, key)
    def __output_data(self, entry, data, name):
        if isinstance(data, dict):
            entry = entry.require_group(name)
            entry.attrs['NX_class'] = 'NXcollection'
            for key, value in data.items():
                self.__output_data(entry, value, key)
        else:
            try:
                self.__create_dataset(entry, name, data)
            except Exception:
                # fall back to a JSON-encoded string for types that h5py
                # cannot store directly
                try:
                    import json
                    data = np.array([json.dumps(data).encode("ascii")])
                    self.__create_dataset(entry, name, data)
                except Exception:
                    raise Exception('Unable to output %s to file.' % name)
    def __create_dataset(self, entry, name, data):
        if name not in list(entry.keys()):
            entry.create_dataset(name, data=data)
        else:
            entry[name][...] = data
    def _link_datafile_to_nexus_file(self, data):
        filename = self.exp.meta_data.get('nxs_filename')

        with h5py.File(filename, 'a', driver="mpio",
                       comm=MPI.COMM_WORLD) as nxs_file:
            # entry path in the nexus file
            name = data.get_name()
            group_name = self.exp.meta_data.get(['group_name', name])
            link = self.exp.meta_data.get(['link_type', name])
            name = data.get_name(orig=True)
            nxs_entry = self.__add_nxs_entry(nxs_file, link, group_name, name)
            self.__add_nxs_data(nxs_file, nxs_entry, link, group_name, data)
    def __add_nxs_entry(self, nxs_file, link, group_name, name):
        if name == 'phantom':
            nxs_entry = '/entry/' + link
        else:
            nxs_entry = ''
        nxs_entry += '_' + name if link == 'final_result' else "/" + group_name
        nxs_entry = nxs_file[nxs_entry]
        nxs_entry.attrs['signal'] = 'data'
        return nxs_entry
    def __add_nxs_data(self, nxs_file, nxs_entry, link, group_name, data):
        data_entry = nxs_entry.name + '/data'
        # output file path
        h5file = data.backing_file.filename

        if link == 'input_data':
            dataset = self.__is_h5dataset(data)
            if dataset:
                nxs_file[data_entry] = \
                    h5py.ExternalLink(os.path.abspath(h5file), dataset.name)
        else:
            # entry path in the output file
            m_data = self.exp.meta_data.get
            if not (link == 'intermediate' and
                    m_data('inter_path') != m_data('out_path')):
                h5file = h5file.split(m_data('out_folder') + '/')[-1]
            nxs_file[data_entry] = \
                h5py.ExternalLink(h5file, group_name + '/data')