1
|
|
|
# Copyright 2014 Diamond Light Source Ltd. |
2
|
|
|
# |
3
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
4
|
|
|
# you may not use this file except in compliance with the License. |
5
|
|
|
# You may obtain a copy of the License at |
6
|
|
|
# |
7
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
8
|
|
|
# |
9
|
|
|
# Unless required by applicable law or agreed to in writing, software |
10
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
11
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12
|
|
|
# See the License for the specific language governing permissions and |
13
|
|
|
# limitations under the License. |
14
|
|
|
|
15
|
|
|
""" |
16
|
|
|
.. module:: base_tomophantom_loader |
17
|
|
|
:platform: Unix |
18
|
|
|
:synopsis: A loader that generates synthetic 3D projection full-field tomo data\ |
19
|
|
|
as hdf5 dataset of any size. |
20
|
|
|
|
21
|
|
|
.. moduleauthor:: Daniil Kazantsev <[email protected]> |
22
|
|
|
""" |
23
|
|
|
|
24
|
|
|
import os |
25
|
|
|
import h5py |
26
|
|
|
import logging |
27
|
|
|
import numpy as np |
28
|
|
|
|
29
|
|
|
from savu.data.chunking import Chunking |
30
|
|
|
from savu.plugins.utils import register_plugin |
31
|
|
|
from savu.plugins.loaders.base_loader import BaseLoader |
32
|
|
|
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils |
33
|
|
|
|
34
|
|
|
import tomophantom |
35
|
|
|
from tomophantom import TomoP2D, TomoP3D |
36
|
|
|
|
37
|
|
|
@register_plugin
class BaseTomophantomLoader(BaseLoader):
    """Loader that synthesises parallel-beam projection data and the matching
    phantom volume with TomoPhantom, storing both in hdf5 files.

    ``setup`` creates two data objects, ``synth_proj_data`` and ``phantom``.
    Each one is backed by ``<out_path>/<name>.h5``. If a file with that name
    already exists, it is reused rather than regenerated.
    """

    def __init__(self, name='BaseTomophantomLoader'):
        super(BaseTomophantomLoader, self).__init__(name)

    def setup(self):
        """Create, populate and return the projection and phantom objects.

        Reads the plugin parameters ``axis_labels``, ``axis_labels_phantom``,
        ``patterns``, ``patterns_tomo``, ``tomo_model`` and
        ``proj_data_dims``.

        :returns: ``(synth_proj_data, phantom)`` data objects.
        :raises Exception: if ``proj_data_dims`` is empty, or if the first
            dimension of the projection data does not match the number of
            angles (possible when a stale backing file is reused).
        """
        exp = self.exp
        data_obj = exp.create_data_object('in_data', 'synth_proj_data')

        data_obj.set_axis_labels(*self.parameters['axis_labels'])
        self.__convert_patterns(data_obj, 'synth_proj_data')
        self.__parameter_checks(data_obj)

        self.tomo_model = self.parameters['tomo_model']
        # Angles (degrees) for parallel-beam geometry, one per entry of
        # proj_data_dims[0].
        # NOTE(review): 1e-14 is below float64 resolution at 180.0, so the
        # subtraction is a no-op and the last angle is exactly 180.0;
        # presumably the intent was to exclude 180 - confirm before changing,
        # since that would alter the generated data.
        self.angles = np.linspace(0.0, 180.0 - (1e-14),
                                  self.parameters['proj_data_dims'][0],
                                  dtype='float32')
        path = os.path.dirname(tomophantom.__file__)
        self.path_library3D = os.path.join(path, "Phantom3DLibrary.dat")

        data_obj.backing_file = \
            self.__get_backing_file(data_obj, 'synth_proj_data')
        data_obj.data = data_obj.backing_file['/']['test']

        # create a phantom file
        data_obj2 = exp.create_data_object('in_data', 'phantom')
        data_obj2.set_axis_labels(*self.parameters['axis_labels_phantom'])
        self.__convert_patterns(data_obj2, 'phantom')
        self.__parameter_checks(data_obj2)

        data_obj2.backing_file = self.__get_backing_file(data_obj2, 'phantom')
        data_obj2.data = data_obj2.backing_file['/']['test']

        data_obj.set_shape(data_obj.data.shape)
        self.n_entries = data_obj.get_shape()[0]
        # Constant centre of rotation: half of proj_data_dims[2], repeated
        # proj_data_dims[1] times.
        cor_val = 0.5 * self.parameters['proj_data_dims'][2]
        self.cor = np.full(self.parameters['proj_data_dims'][1], cor_val,
                           dtype='float32')
        self._set_metadata(data_obj, self._get_n_entries())
        return data_obj, data_obj2

    def __get_backing_file(self, data_obj, file_name):
        """Return the hdf5 file backing ``data_obj``, opened read-only.

        If ``<out_path>/<file_name>.h5`` already exists, it is opened as-is.
        Otherwise a file holding a single dataset ``'test'`` is created and
        filled slice by slice:

        - For ``'synth_proj_data'`` the dataset shape is ``proj_data_dims``
          and it is filled with TomoPhantom projections.
        - For ``'phantom'`` dims 0 and 2 are replaced by dim 1, and it is
          filled with the phantom model.

        The slices are divided between processes by
        ``__get_start_slice_list``.

        :param data_obj: data object whose first pattern drives chunking and
            the slice order. NOTE: that pattern's dict is mutated in place
            (``max_frames_transfer`` is set to 1).
        :param file_name: ``'synth_proj_data'`` or ``'phantom'``.
        :returns: the open hdf5 file.
        """
        fname = os.path.join(self.exp.get('out_path'), '%s.h5' % file_name)

        if os.path.exists(fname):
            # NOTE(review): an existing file is reused regardless of the
            # current proj_data_dims / tomo_model parameters.
            return h5py.File(fname, 'r')

        self.hdf5 = Hdf5Utils(self.exp)

        # list() accepts both list- and tuple-valued parameters.
        dims = list(self.parameters['proj_data_dims'])
        if file_name == 'phantom':
            # The phantom is a cube with edge length proj_data_dims[1].
            dims[0] = dims[2] = dims[1]
        proj_data_dims = tuple(dims)

        patterns = data_obj.get_data_patterns()
        p_name = list(patterns.keys())[0]
        p_dict = patterns[p_name]
        p_dict['max_frames_transfer'] = 1
        nnext = {p_name: p_dict}

        pattern_idx = {'current': nnext, 'next': nnext}
        chunking = Chunking(self.exp, pattern_idx)
        # NOTE(review): chunks are sized for int16, but the dataset below is
        # created with h5py's default float32 dtype - confirm intent.
        chunks = chunking._calculate_chunking(proj_data_dims, np.int16)

        h5file = self.hdf5._open_backing_h5(fname, 'w')
        dset = h5file.create_dataset('test', proj_data_dims, chunks=chunks)

        # need an mpi barrier after creating the file before populating it
        self.exp._barrier()

        slice_dirs = p_dict['slice_dims']
        # int(): np.prod of an empty list is the float 1.0.
        total_frames = int(np.prod([dset.shape[i] for i in slice_dirs]))

        idx = 0
        # First slice owned by this process, and its number of frames.
        sl, total_frames = \
            self.__get_start_slice_list(slice_dirs, dset.shape, total_frames)
        for _ in range(total_frames):
            # Generate the frame at the slice's global position, not at the
            # loop counter: each process starts at its own offset.
            start = sl[slice_dirs[idx]].start
            gen_data = self.__generate_frame(file_name, proj_data_dims,
                                             (start, start + 1))
            dset[tuple(sl)] = np.swapaxes(gen_data, 0, 1)
            if sl[slice_dirs[idx]].stop == dset.shape[slice_dirs[idx]]:
                idx += 1
                if idx == len(slice_dirs):
                    break
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        self.exp._barrier()

        try:
            h5file.close()
        except IOError as exc:
            logging.debug('There was a problem trying to close the file in '
                          'base_tomophantom_loader: %s', exc)

        return self.hdf5._open_backing_h5(fname, 'r')

    def __generate_frame(self, file_name, dims, sub_index):
        """Generate one TomoPhantom sub-volume for the ``sub_index`` range.

        Returns projection data (with negated angles) for
        ``'synth_proj_data'``, and phantom data otherwise.
        """
        if file_name == 'synth_proj_data':
            return TomoP3D.ModelSinoSub(self.tomo_model, dims[1], dims[2],
                                        dims[1], sub_index, -self.angles,
                                        self.path_library3D)
        return TomoP3D.ModelSub(self.tomo_model, dims[1], sub_index,
                                self.path_library3D)

    def __get_start_slice_list(self, slice_dirs, shape, n_frames):
        """Return this process's first slice list and its frame count.

        The ``n_frames`` frame indices are divided between processes with
        ``np.array_split``. The slice list is then stepped from the origin
        to this process's first frame, using the same stepping rule as the
        write loop in ``__get_backing_file``.

        :returns: ``(sl, n)``. ``sl`` holds one slice per dim of ``shape``:
            a length-1 slice on ``slice_dirs`` and ``slice(None)`` elsewhere.
            ``n`` is the number of frames assigned to this process (may be 0).
        """
        n_processes = len(self.exp.get('processes'))
        rank = self.exp.get('process')
        frames = np.array_split(np.arange(n_frames), n_processes)[rank]
        n_skip = int(frames[0]) if len(frames) else 0
        sl = [slice(0, 1) if i in slice_dirs else slice(None)
              for i in range(len(shape))]
        idx = 0
        for _ in range(n_skip):
            # Compare the slice's stop with the dimension size. (Comparing
            # the slice object itself to an int is always False.)
            if sl[slice_dirs[idx]].stop == shape[slice_dirs[idx]]:
                idx += 1
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        return sl, len(frames)

    def __convert_patterns(self, data_obj, object_type):
        """Add the data patterns listed in the plugin parameters to data_obj.

        Reads ``patterns`` for ``'synth_proj_data'`` and ``patterns_tomo``
        otherwise. Each entry is dot-separated, e.g. ``NAME.0c.1s.2c``:

        - The first field is the pattern name.
        - A field that splits into two parts on ``c`` adds the integer before
          the ``c`` to the core dims.
        - A field that splits into two parts on ``s`` adds the integer before
          the ``s`` to the slice dims.
        """
        if object_type == 'synth_proj_data':
            pattern_list = self.parameters['patterns']
        else:
            pattern_list = self.parameters['patterns_tomo']
        for p in pattern_list:
            p_split = p.split('.')
            name = p_split[0]
            dims = p_split[1:]
            core_dims = tuple([int(i[0]) for i in [d.split('c') for d in dims]
                               if len(i) == 2])
            slice_dims = tuple([int(i[0]) for i in [d.split('s') for d in dims]
                                if len(i) == 2])
            data_obj.add_pattern(
                name, core_dims=core_dims, slice_dims=slice_dims)

    def _set_metadata(self, data_obj, n_entries):
        """Store rotation angles and centre of rotation in data_obj metadata.

        :raises Exception: if ``n_entries`` differs from the number of angles.
        """
        n_angles = len(self.angles)
        if n_entries != n_angles:
            raise Exception("The number of angles %s does not match the data "
                            "dimension length %s" % (n_angles, n_entries))
        data_obj.meta_data.set("rotation_angle", self.angles)
        data_obj.meta_data.set("centre_of_rotation", self.cor)

    def __parameter_checks(self, data_obj):
        """Raise if the ``proj_data_dims`` parameter is empty or missing."""
        if not self.parameters['proj_data_dims']:
            raise Exception(
                'Please specify the dimensions of the dataset to create.')

    def _get_n_entries(self):
        """Return the length of the first projection-data dimension."""
        return self.n_entries
192
|
|
|
|