Test Failed
Pull Request — master (#878)
by
unknown
04:38
created

BaseTomophantomLoader.__init__()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_tomophantom_loader
17
   :platform: Unix
18
   :synopsis: A loader that generates synthetic 3D projection full-field tomo data\
19
        as hdf5 dataset of any size.
20
21
.. moduleauthor:: Daniil Kazantsev <[email protected]>
22
"""
23
24
import os
25
import h5py
26
import logging
27
import numpy as np
28
from mpi4py import MPI
29
30
from savu.data.chunking import Chunking
31
from savu.plugins.utils import register_plugin
32
from savu.plugins.loaders.base_loader import BaseLoader
33
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
34
from savu.plugins.stats.statistics import Statistics
35
36
import tomophantom
37
from tomophantom import TomoP2D, TomoP3D
38
39
@register_plugin
40
class BaseTomophantomLoader(BaseLoader):
41
    def __init__(self, name='BaseTomophantomLoader'):
        """Construct the loader.

        :param name: plugin name passed on to the parent loader constructor.
        """
        super(BaseTomophantomLoader, self).__init__(name)
        # Placeholders only: ``setup`` assigns the real centre of rotation
        # array and the number of entries.
        self.cor = self.n_entries = None
45
46
    def setup(self):
        """Create the synthetic projection dataset and the phantom dataset.

        Two ``in_data`` objects are created through the experiment,
        ``synth_proj_data`` and ``phantom``.  Each is given axis labels and
        patterns and is attached to the hdf5 file returned by
        ``__get_backing_file``.  Volume-wide statistics are then calculated
        from the two ``Statistics`` objects and written out, and the rotation
        angles and centre of rotation are stored as metadata of the
        projection data.

        :returns: the tuple ``(data_obj, data_obj2)``: the projection data
            object followed by the phantom data object.
        """
        exp = self.exp
        data_obj = exp.create_data_object('in_data', 'synth_proj_data')

        # The statistics objects must exist before the backing files are
        # requested, because ``__get_backing_file`` feeds slice statistics
        # into them while it generates the data.
        self.proj_stats_obj = self._new_stats_obj("PROJECTION", 1)
        self.phantom_stats_obj = self._new_stats_obj("VOLUME_XY", 0)

        # Per the original author's note, these two mappings belong to the
        # whole statistics class rather than to this one object.
        self.proj_stats_obj.plugin_names[1] = "TomoPhantomLoader"
        self.proj_stats_obj.plugin_numbers["TomoPhantomLoader"] = 1

        data_obj.set_axis_labels(*self.parameters['axis_labels'])
        self.__convert_patterns(data_obj, 'synth_proj_data')
        self.__parameter_checks(data_obj)

        self.tomo_model = self.parameters['tomo_model']
        # setting angles for parallel beam geometry
        self.angles = np.linspace(0.0, 180.0-(1e-14), self.parameters['proj_data_dims'][0], dtype='float32')
        path = os.path.dirname(tomophantom.__file__)
        self.path_library3D = os.path.join(path, "Phantom3DLibrary.dat")

        data_obj.backing_file = self.__get_backing_file(data_obj, 'synth_proj_data')
        data_obj.data = data_obj.backing_file['/']['entry1']['tomo_entry']['data']['data']

        # create a phantom file
        data_obj2 = exp.create_data_object('in_data', 'phantom')
        data_obj2.set_axis_labels('voxel_x.voxel', 'voxel_y.voxel', 'voxel_z.voxel')
        self.__convert_patterns(data_obj2, 'phantom')
        self.__parameter_checks(data_obj2)

        data_obj2.backing_file = self.__get_backing_file(data_obj2, 'phantom')
        data_obj2.data = data_obj2.backing_file['/']['phantom']['data']
        # NOTE(review): the shape is only set on the projection object here;
        # confirm the phantom object does not need ``set_shape`` as well.
        data_obj.set_shape(data_obj.data.shape)

        self.n_entries = data_obj.get_shape()[0]
        cor_val = 0.5*(self.parameters['proj_data_dims'][2])
        # one (identical) centre of rotation value per entry of dimension 1
        self.cor = np.linspace(cor_val, cor_val, self.parameters['proj_data_dims'][1], dtype='float32')

        # Calculating volume-wide stats for projection and writing these to
        # file (stats/stats.h5)
        self.proj_stats_obj.volume_stats = self.proj_stats_obj.calc_volume_stats(self.proj_stats_obj.stats)
        Statistics.global_stats[1] = self.proj_stats_obj.volume_stats
        self.proj_stats_obj._write_stats_to_file(p_num=1, plugin_name="TomoPhantomLoader (synthetic projection)")

        # Calculating volume-wide stats for phantom and writing these to
        # file (stats/stats.h5)
        self.phantom_stats_obj.volume_stats = self.phantom_stats_obj.calc_volume_stats(self.phantom_stats_obj.stats)
        Statistics.global_stats[0] = self.phantom_stats_obj.volume_stats
        self.phantom_stats_obj._write_stats_to_file(p_num=0, plugin_name="TomoPhantomLoader (phantom)")

        self._set_metadata(data_obj, self._get_n_entries())

        return data_obj, data_obj2

    @staticmethod
    def _new_stats_obj(pattern, p_num):
        """Return a ``Statistics`` object configured for this loader.

        :param pattern: value stored in the object's ``pattern`` attribute.
        :param p_num: value stored in the object's ``p_num`` attribute.
        :returns: the new ``Statistics`` object with empty per-slice lists.
        """
        stats_obj = Statistics()
        stats_obj.pattern = pattern
        stats_obj.plugin_name = "TomoPhantomLoader"
        stats_obj.p_num = p_num
        stats_obj._iterative_group = None
        stats_obj.stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
        return stats_obj
106
107
    def __get_backing_file(self, data_obj, file_name):
        """Return a read-only hdf5 file holding the requested synthetic data.

        If ``<out_path>/<file_name>.h5`` already exists it is opened and
        returned as it is.  Otherwise the file is created, a float32 dataset
        called ``data`` is added (under ``/phantom`` for the phantom,
        ``/entry1/tomo_entry/data`` otherwise) and it is filled one frame at
        a time with data generated by TomoPhantom.  The frames are divided
        between the processes by ``__get_start_slice_list`` and the slice
        statistics of every generated frame are passed to the matching
        ``Statistics`` object.

        :param data_obj: data object supplying the patterns; its ``dtype``
            attribute is set to float32 here.
        :param file_name: ``'phantom'`` for the phantom volume, anything else
            (``'synth_proj_data'`` in this class) for the projection data.
        :returns: the backing file opened in read mode.
        """
        fname = '%s/%s.h5' % (self.exp.get('out_path'), file_name)

        # A file left over from an earlier call is reused without being
        # regenerated (so no slice statistics are collected in that case).
        if os.path.exists(fname):
            return h5py.File(fname, 'r')

        self.hdf5 = Hdf5Utils(self.exp)

        dims_temp = self.parameters['proj_data_dims'].copy()
        if file_name == 'phantom':
            # the phantom takes the size of dimension 1 in all three dims
            dims_temp[0] = dims_temp[1]
            dims_temp[2] = dims_temp[1]
        proj_data_dims = tuple(dims_temp)

        # Only the first pattern of the data object is used for chunking and
        # for choosing the slice dimension(s).
        patterns = data_obj.get_data_patterns()
        p_name = list(patterns.keys())[0]
        p_dict = patterns[p_name]
        p_dict['max_frames_transfer'] = 1
        nnext = {p_name: p_dict}

        pattern_idx = {'current': nnext, 'next': nnext}
        chunking = Chunking(self.exp, pattern_idx)
        # NOTE(review): chunking is calculated for int16 although the dataset
        # is created as float32 below - confirm this is intended.
        chunks = chunking._calculate_chunking(proj_data_dims, np.int16)

        h5file = self.hdf5._open_backing_h5(fname, 'w')

        if file_name == 'phantom':
            group = h5file.create_group('/phantom', track_order=None)
        else:
            group = h5file.create_group('/entry1/tomo_entry/data', track_order=None)

        data_obj.dtype = np.dtype('<f4')
        dset = self.hdf5.create_dataset_nofill(group, "data", proj_data_dims, data_obj.dtype, chunks=chunks)

        # need an mpi barrier after creating the file before populating it
        self.exp._barrier()

        slice_dirs = p_dict['slice_dims']
        total_frames = np.prod([dset.shape[i] for i in slice_dirs])

        # calculate the first slice and the number of frames of this process
        idx = 0
        sl, total_frames = \
            self.__get_start_slice_list(slice_dirs, dset.shape, total_frames)
        for _ in range(total_frames):
            frame = sl[slice_dirs[idx]]
            sub_index = (frame.start, frame.start + 1)
            if file_name == 'synth_proj_data':
                # generate projection data
                gen_data = TomoP3D.ModelSinoSub(self.tomo_model, proj_data_dims[1], proj_data_dims[2],
                                                proj_data_dims[1], sub_index, -self.angles,
                                                self.path_library3D)
                self.proj_stats_obj.set_slice_stats(gen_data, pad=None)  # getting slice stats for projection
            else:
                # generate phantom data
                gen_data = TomoP3D.ModelSub(self.tomo_model, proj_data_dims[1], sub_index,
                                            self.path_library3D)
                self.phantom_stats_obj.set_slice_stats(gen_data, pad=None)  # getting slice stats for phantom
            dset[tuple(sl)] = np.swapaxes(gen_data, 0, 1)

            # Test for the end of the dimension only AFTER the frame has been
            # written, so that the last frame is generated as well.
            if frame.stop == dset.shape[slice_dirs[idx]]:
                idx += 1
                if idx == len(slice_dirs):
                    break
                frame = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(frame.start + 1, frame.stop + 1)

        self.exp._barrier()

        try:
            h5file.close()
        except IOError as exc:
            logging.debug('There was a problem trying to close the file in '
                          'base_tomophantom_loader: %s', exc)

        return self.hdf5._open_backing_h5(fname, 'r')
190
191 View Code Duplication
    def __get_start_slice_list(self, slice_dirs, shape, n_frames):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
192
        n_processes = len(self.exp.get('processes'))
193
        rank = self.exp.get('process')
194
        frames = np.array_split(np.arange(n_frames), n_processes)[rank]
195
        f_range = list(range(0, frames[0])) if len(frames) else []
196
        sl = [slice(0, 1) if i in slice_dirs else slice(None)
197
              for i in range(len(shape))]
198
        idx = 0
199
        for i in f_range:
200
            if sl[slice_dirs[idx]] == shape[slice_dirs[idx]]-1:
201
                idx += 1
202
            tmp = sl[slice_dirs[idx]]
203
            sl[slice_dirs[idx]] = slice(tmp.start+1, tmp.stop+1)
204
205
        return sl, len(frames)
206
207
    def __convert_patterns(self, data_obj, object_type):
208
        if object_type == 'synth_proj_data':
209
            pattern_list = self.parameters['patterns']
210
        else:
211
            pattern_list = self.parameters['patterns_tomo2']
212
        for p in pattern_list:
213
            p_split = p.split('.')
214
            name = p_split[0]
215
            dims = p_split[1:]
216
            core_dims = tuple([int(i[0]) for i in [d.split('c') for d in dims]
217
                              if len(i) == 2])
218
            slice_dims = tuple([int(i[0]) for i in [d.split('s') for d in dims]
219
                               if len(i) == 2])
220
            data_obj.add_pattern(
221
                    name, core_dims=core_dims, slice_dims=slice_dims)
222
223
224
225
    def _set_metadata(self, data_obj, n_entries):
226
        n_angles = len(self.angles)
227
        data_angles = n_entries
228
        if data_angles != n_angles:
229
            raise Exception("The number of angles %s does not match the data "
230
                            "dimension length %s", n_angles, data_angles)
231
        data_obj.meta_data.set(['rotation_angle'], self.angles)
232
        data_obj.meta_data.set(['centre_of_rotation'], self.cor)
233
        data_obj
234
235
    def __parameter_checks(self, data_obj):
236
        if not self.parameters['proj_data_dims']:
237
            raise Exception(
238
                    'Please specifiy the dimensions of the dataset to create.')
239
240
    def _get_n_entries(self):
241
        return self.n_entries
242
243
244
    def post_process(self, data_obj, data_obj2):
        """Rename the nexus file and link both datasets into it.

        The last ``'/'``-separated component of the ``nxs_filename`` entry of
        the experiment metadata is replaced by ``synthetic_data.nxs`` when
        the plugin list holds exactly one plugin and by
        ``synthetic_data_processed.nxs`` otherwise.  The phantom is linked
        first, then the synthetic projection data.

        :param data_obj: the synthetic projection data object.
        :param data_obj2: the phantom data object.
        """
        meta = self.exp.meta_data
        old_path = meta.get('nxs_filename')

        if len(meta.plugin_list.plugin_list) == 1:
            new_tail = 'synthetic_data.nxs'
        else:
            new_tail = 'synthetic_data_processed.nxs'

        # keep everything up to and including the last '/', swap the rest
        head, sep, _ = old_path.rpartition('/')
        meta.set('nxs_filename', head + sep + new_tail)

        for obj, label in ((data_obj2, 'phantom'),
                           (data_obj, 'synth_proj_data')):
            self._link_nexus_file(obj, label)
257
258
259
260
    def _link_nexus_file(self, data_obj, name):
        """Record the nexus group and statistics of a dataset, then link it.

        The group name and link type are stored in the experiment metadata,
        the volume statistics of the matching ``Statistics`` object are
        copied into the ``stats`` metadata of ``data_obj``, and finally the
        nexus file is populated and the data file linked to it.

        :param data_obj: data object to record and link.
        :param name: ``'phantom'`` selects the phantom settings; any other
            value selects the synthetic projection settings.
        """
        if name == 'phantom':
            key = 'phantom'
            group, link_type = 'phantom', 'final_result'
            stats_obj = self.phantom_stats_obj
        else:
            key = 'synth_proj_data'
            group, link_type = 'entry1/tomo_entry/data', 'entry1'
            stats_obj = self.proj_stats_obj

        exp_meta = data_obj.exp.meta_data
        exp_meta.set(['group_name', key], group)
        exp_meta.set(['link_type', key], link_type)

        stats_dict = stats_obj._array_to_dict(stats_obj.volume_stats)
        for stat_name, stat_value in stats_dict.items():
            data_obj.meta_data.set(["stats", stat_name], stat_value)

        self._populate_nexus_file(data_obj)
        self._link_datafile_to_nexus_file(data_obj)
279
280
281
    def _populate_nexus_file(self, data):
        """Write the patterns, metadata and axis labels of ``data`` to the
        nexus file.

        The file named by the ``nxs_filename`` metadata entry is opened in
        append mode with the ``mpio`` driver on ``MPI.COMM_WORLD``.  For the
        phantom the target group lives under ``entry`` and any existing group
        of the same name is deleted first; for any other dataset the group
        named by its ``group_name`` metadata entry is created at the root.

        :param data: data object to describe in the nexus file.
        """
        meta = self.exp.meta_data
        nxs_path = meta.get('nxs_filename')
        name = data.data_info.get('name')

        with h5py.File(nxs_path, 'a', driver="mpio", comm=MPI.COMM_WORLD) as nxs_file:
            group_name = meta.get(['group_name', name])
            link_type = meta.get(['link_type', name])

            if name != 'phantom':
                plugin_entry = nxs_file.create_group('/{}'.format(group_name))
            else:
                if 'entry' in list(nxs_file.keys()):
                    parent = nxs_file['entry']
                else:
                    parent = nxs_file.create_group('entry')

                if link_type == 'final_result':
                    group_name = 'final_result_' + data.get_name()
                else:
                    collection = parent.require_group(link_type.encode("ascii"))
                    collection.attrs['NX_class'] = 'NXcollection'
                    parent = collection

                # delete the group if it already exists
                if group_name in parent:
                    del parent[group_name]

                plugin_entry = parent.require_group(group_name)

            self.__output_data_patterns(data, plugin_entry)
            self._output_metadata_dict(plugin_entry, data.meta_data.get_dictionary())
            self.__output_axis_labels(data, plugin_entry)

            # set last: ``_output_metadata_dict`` marks the group as an
            # NXcollection, this overrides it
            plugin_entry.attrs['NX_class'] = 'NXdata'
317
318
319
    def __output_axis_labels(self, data, entry):
320
        axis_labels = data.data_info.get("axis_labels")
321
        ddict = data.meta_data.get_dictionary()
322
323
        axes = []
324
        count = 0
325
        dims_temp = self.parameters['proj_data_dims'].copy()
326
        if data.data_info.get('name') == 'phantom':
327
            dims_temp[0] = dims_temp[1]
328
            dims_temp[2] = dims_temp[1]
329
        dims = tuple(dims_temp)
330
331
        for labels in axis_labels:
332
            name = list(labels.keys())[0]
333
            axes.append(name)
334
            entry.attrs[name + '_indices'] = count
335
336
            mData = ddict[name] if name in list(ddict.keys()) \
337
                else np.arange(dims[count])
338
339
            if isinstance(mData, list):
340
                mData = np.array(mData)
341
342
            if 'U' in str(mData.dtype):
343
                mData = mData.astype(np.string_)
344
            if name not in list(entry.keys()):
345
                axis_entry = entry.require_dataset(name, mData.shape, mData.dtype)
346
                axis_entry[...] = mData[...]
347
                axis_entry.attrs['units'] = list(labels.values())[0]
348
            count += 1
349
        entry.attrs['axes'] = axes
350
351 View Code Duplication
    def __output_data_patterns(self, data, entry):
        """Write the core and slice dimensions of every data pattern.

        A ``patterns`` group (``NXcollection``) is required under ``entry``
        and receives one ``NXparameters`` group per pattern, each holding
        the pattern's ``core_dims`` and ``slice_dims``.

        :param data: data object whose ``data_patterns`` are written.
        :param entry: parent group of the ``patterns`` group.
        """
        all_patterns = data.data_info.get("data_patterns")
        patterns_group = entry.require_group('patterns')
        patterns_group.attrs['NX_class'] = 'NXcollection'

        for pattern_name, dims in all_patterns.items():
            pattern_group = patterns_group.require_group(pattern_name)
            pattern_group.attrs['NX_class'] = 'NXparameters'
            for dim_kind in ('core_dims', 'slice_dims'):
                self.__output_data(pattern_group, dims[dim_kind], dim_kind)
361
362
    def _output_metadata_dict(self, entry, mData):
        """Write a (possibly nested) metadata dictionary below ``entry``.

        ``entry`` is marked as an ``NXcollection``.  Every key except
        ``rotation_angle`` gets its own group: nested dictionaries are
        written recursively, any other value is written into a group marked
        ``NXdata``.

        :param entry: group to write into.
        :param mData: dictionary of metadata to write.
        """
        entry.attrs['NX_class'] = 'NXcollection'
        for key, value in mData.items():
            if key == 'rotation_angle':
                # skipped here; nothing is written for this key
                continue
            child = entry.require_group(key)
            if not isinstance(value, dict):
                child.attrs['NX_class'] = 'NXdata'
                self.__output_data(child, value, key)
            else:
                self._output_metadata_dict(child, value)
372
373 View Code Duplication
    def __output_data(self, entry, data, name):
        """Write ``data`` below ``entry`` under the given name.

        A dictionary becomes an ``NXcollection`` group whose items are
        written recursively.  Any other value is written as a dataset; if
        that fails, the value is serialised to JSON and stored as a
        one-element array of ASCII bytes instead.

        :param entry: group to write into.
        :param data: dictionary or value to store.
        :param name: name of the group or dataset to create.
        :raises Exception: if the value can be stored neither directly nor
            in its JSON form; the underlying error is chained as the cause.
        """
        if isinstance(data, dict):
            entry = entry.require_group(name)
            entry.attrs['NX_class'] = 'NXcollection'
            for key, value in data.items():
                self.__output_data(entry, value, key)
            return

        try:
            self.__create_dataset(entry, name, data)
        except Exception:
            # The raw value could not be stored: fall back to its JSON text.
            import json
            try:
                encoded = np.array([json.dumps(data).encode("ascii")])
                self.__create_dataset(entry, name, encoded)
            except Exception as err:
                raise Exception('Unable to output %s to file.' % name) from err
392
393
    def __create_dataset(self, entry, name, data):
394
        if name not in list(entry.keys()):
395
            entry.create_dataset(name, data=data)
396
        else:
397
            entry[name][...] = data
398
399
    def _link_datafile_to_nexus_file(self, data):
        """Add a link to the backing file of ``data`` to the nexus file.

        The nexus file named by the ``nxs_filename`` metadata entry is opened
        in append mode with the ``mpio`` driver on ``MPI.COMM_WORLD``.  The
        group name and link type are looked up under the current name of the
        data object, while the nexus entry is located using its original
        name.

        :param data: data object whose backing file is linked.
        """
        meta = self.exp.meta_data
        nxs_path = meta.get('nxs_filename')

        with h5py.File(nxs_path, 'a', driver="mpio", comm=MPI.COMM_WORLD) as nxs_file:
            # entry path in nexus file
            current_name = data.get_name()
            group_name = meta.get(['group_name', current_name])
            link = meta.get(['link_type', current_name])
            original_name = data.get_name(orig=True)

            nxs_entry = self.__add_nxs_entry(
                nxs_file, link, group_name, original_name)
            self.__add_nxs_data(nxs_file, nxs_entry, link, group_name, data)
410
411
    def __add_nxs_entry(self, nxs_file, link, group_name, name):
412
        if name == 'phantom':
413
            nxs_entry = '/entry/' + link
414
        else:
415
            nxs_entry = ''
416
        nxs_entry += '_' + name if link == 'final_result' else "/" + group_name
417
        nxs_entry = nxs_file[nxs_entry]
418
        nxs_entry.attrs['signal'] = 'data'
419
        return nxs_entry
420
421 View Code Duplication
    def __add_nxs_data(self, nxs_file, nxs_entry, link, group_name, data):
        """Create the external ``data`` link below a nexus entry.

        For an ``input_data`` link the target is the absolute path of the
        backing file and the name of its dataset; nothing is written when no
        dataset is found.  For every other link type the target is
        ``<group_name>/data`` inside the backing file, whose path is reduced
        to the part after ``<out_folder>/`` unless the link is
        ``intermediate`` and the intermediate path differs from the output
        path.

        :param nxs_file: open nexus file receiving the link.
        :param nxs_entry: entry below which ``data`` is linked.
        :param link: link type of the dataset.
        :param group_name: group holding the data inside the backing file.
        :param data: data object providing the backing file.
        """
        link_path = nxs_entry.name + '/data'
        # output file path
        backing_path = data.backing_file.filename

        if link == 'input_data':
            # NOTE(review): ``__is_h5dataset`` is not defined in the visible
            # part of this file - confirm it exists on this class.
            dataset = self.__is_h5dataset(data)
            if dataset:
                nxs_file[link_path] = h5py.ExternalLink(
                    os.path.abspath(backing_path), dataset.name)
            return

        # entry path in output file path
        lookup = self.exp.meta_data.get
        if link != 'intermediate' or \
                lookup('inter_path') == lookup('out_path'):
            backing_path = backing_path.split(lookup('out_folder') + '/')[-1]
        nxs_file[link_path] = h5py.ExternalLink(
            backing_path, group_name + '/data')
439