| Metric | Value |
| --- | --- |
| Total Complexity | 63 |
| Total Lines | 403 |
| Duplicated Lines | 24.07 % |
| Changes | 0 |
Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.
The most common fix for duplication is to extract the repeated logic into a single method or class that every call site reuses; a sketch follows.
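For instance, this loader creates two HDF5 backing files (`synth_proj_data` and `phantom`) through near-identical "make group, make dataset, fill it" steps. A minimal dedup sketch, assuming a simplified `write_backing_dataset` helper that is hypothetical and not part of the Savu codebase:

```python
# Hypothetical dedup sketch: one helper serves both former call sites,
# so a bug fix lands in a single place.
import h5py
import numpy as np


def write_backing_dataset(fname, group_path, data, chunks=None):
    # create <group_path>/data inside fname and populate it in one place
    with h5py.File(fname, 'w') as h5file:
        group = h5file.require_group(group_path)
        dset = group.create_dataset('data', data.shape, data.dtype,
                                    chunks=chunks)
        dset[...] = data


# the two former duplicates become two calls
write_backing_dataset('synth_proj_data.h5', '/entry1/tomo_entry/data',
                      np.zeros((10, 16, 16), dtype=np.float32))
write_backing_dataset('phantom.h5', '/phantom',
                      np.zeros((16, 16, 16), dtype=np.float32))
```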
Complex classes like savu.plugins.loaders.base_tomophantom_loader often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster. A sketch of such an extraction follows.
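Here, the NeXus output methods (`_link_nexus_file`, `_populate_nexus_file`, and the `__output_*` / `__add_nxs_*` helpers) form one such cohesive group. A minimal Extract Class sketch; the `NexusFileWriter` name and its reduced interface are illustrative assumptions, not existing Savu API:

```python
# Hypothetical Extract Class: NeXus-output concerns move out of the loader
# into a dedicated collaborator that the loader delegates to.
from savu.plugins.loaders.base_loader import BaseLoader


class NexusFileWriter:
    """Owns everything that touches the output .nxs file."""

    def __init__(self, exp):
        self.exp = exp

    def link(self, data_obj, name):
        ...  # body of the current _link_nexus_file

    def populate(self, data_obj):
        ...  # body of the current _populate_nexus_file


class BaseTomophantomLoader(BaseLoader):
    def setup(self):
        writer = NexusFileWriter(self.exp)
        data_obj = self.exp.create_data_object('in_data', 'synth_proj_data')
        writer.link(data_obj, 'synth_proj_data')  # delegated, not inherited
        return data_obj
```

Delegation keeps the loader focused on constructing data objects, while the file-format details can be tested and reused independently.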
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 | """ |
||
16 | .. module:: base_tomophantom_loader |
||
17 | :platform: Unix |
||
18 | :synopsis: A loader that generates synthetic 3D projection full-field tomo data\ |
||
19 | as hdf5 dataset of any size. |
||
20 | |||
21 | .. moduleauthor:: Daniil Kazantsev <[email protected]> |
||
22 | """ |
||
23 | |||
import os
import h5py
import logging
import numpy as np

from savu.data.chunking import Chunking
from savu.plugins.utils import register_plugin
from savu.plugins.loaders.base_loader import BaseLoader
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils

import tomophantom
from tomophantom import TomoP3D

@register_plugin
class BaseTomophantomLoader(BaseLoader):
    def __init__(self, name='BaseTomophantomLoader'):
        super(BaseTomophantomLoader, self).__init__(name)

    def setup(self):
        exp = self.exp
        data_obj = exp.create_data_object('in_data', 'synth_proj_data')

        data_obj.set_axis_labels(*self.parameters['axis_labels'])
        self.__convert_patterns(data_obj, 'synth_proj_data')
        self.__parameter_checks(data_obj)

        self.tomo_model = self.parameters['tomo_model']
        # set the angles for parallel-beam geometry
        self.angles = np.linspace(0.0, 180.0 - 1e-14,
                                  self.parameters['proj_data_dims'][0],
                                  dtype='float32')
        path = os.path.dirname(tomophantom.__file__)
        self.path_library3D = os.path.join(path, "Phantom3DLibrary.dat")

        data_obj.backing_file = \
            self.__get_backing_file(data_obj, 'synth_proj_data')
        data_obj.data = data_obj.backing_file['/entry1/tomo_entry/data/data']

        # create a phantom file
        data_obj2 = exp.create_data_object('in_data', 'phantom')
        data_obj2.set_axis_labels(
            *['voxel_x.voxel', 'voxel_y.voxel', 'voxel_z.voxel'])
        self.__convert_patterns(data_obj2, 'phantom')
        self.__parameter_checks(data_obj2)

        data_obj2.backing_file = self.__get_backing_file(data_obj2, 'phantom')
        data_obj2.data = data_obj2.backing_file['/phantom/data']

        data_obj.set_shape(data_obj.data.shape)
        self.n_entries = data_obj.get_shape()[0]

        # a constant centre of rotation at the middle of the detector row
        cor_val = 0.5 * self.parameters['proj_data_dims'][2]
        self.cor = np.full(self.parameters['proj_data_dims'][1], cor_val,
                           dtype='float32')
        self._set_metadata(data_obj, self._get_n_entries())

        self._link_nexus_file(data_obj, 'synth_proj_data')
        self._link_nexus_file(data_obj2, 'phantom')
        return data_obj, data_obj2

    def __get_backing_file(self, data_obj, file_name):
        fname = '%s/%s.h5' % (self.exp.get('out_path'), file_name)

        if os.path.exists(fname):
            return h5py.File(fname, 'r')

        self.hdf5 = Hdf5Utils(self.exp)

        dims_temp = self.parameters['proj_data_dims'].copy()
        if file_name == 'phantom':
            # the phantom is a cube whose side equals the detector width
            dims_temp[0] = dims_temp[1]
            dims_temp[2] = dims_temp[1]
        proj_data_dims = tuple(dims_temp)

        patterns = data_obj.get_data_patterns()
        p_name = list(patterns.keys())[0]
        p_dict = patterns[p_name]
        p_dict['max_frames_transfer'] = 1
        nnext = {p_name: p_dict}

        pattern_idx = {'current': nnext, 'next': nnext}
        chunking = Chunking(self.exp, pattern_idx)
        chunks = chunking._calculate_chunking(proj_data_dims, np.int16)

        h5file = self.hdf5._open_backing_h5(fname, 'w')

        if file_name == 'phantom':
            group = h5file.create_group('/phantom', track_order=None)
        else:
            group = h5file.create_group('/entry1/tomo_entry/data',
                                        track_order=None)

        dset = self.hdf5.create_dataset_nofill(group, "data", proj_data_dims,
                                               data_obj.dtype, chunks=chunks)

        # an mpi barrier is needed after creating the file, before
        # populating it
        self.exp._barrier()

        slice_dirs = list(nnext.values())[0]['slice_dims']
        total_frames = np.prod([dset.shape[i] for i in slice_dirs])

        # calculate the first slice for this process
        idx = 0
        sl, total_frames = \
            self.__get_start_slice_list(slice_dirs, dset.shape, total_frames)

        for i in range(total_frames):
            if file_name == 'synth_proj_data':
                # generate a subset of the projection data
                gen_data = TomoP3D.ModelSinoSub(
                    self.tomo_model, proj_data_dims[1], proj_data_dims[2],
                    proj_data_dims[1], (i, i + 1), -self.angles,
                    self.path_library3D)
            else:
                # generate a subset of the phantom data
                gen_data = TomoP3D.ModelSub(
                    self.tomo_model, proj_data_dims[1], (i, i + 1),
                    self.path_library3D)
            dset[tuple(sl)] = np.swapaxes(gen_data, 0, 1)
            if sl[slice_dirs[idx]].stop == dset.shape[slice_dirs[idx]]:
                idx += 1
                if idx == len(slice_dirs):
                    break
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        self.exp._barrier()

        try:
            h5file.close()
        except IOError:
            logging.debug('There was a problem trying to close the file '
                          'in base_tomophantom_loader')

        return self.hdf5._open_backing_h5(fname, 'r')

    def __get_start_slice_list(self, slice_dirs, shape, n_frames):
        # split the frames between the mpi processes and advance the slice
        # list to the first frame owned by this process
        n_processes = len(self.exp.get('processes'))
        rank = self.exp.get('process')
        frames = np.array_split(np.arange(n_frames), n_processes)[rank]
        f_range = list(range(0, frames[0])) if len(frames) else []
        sl = [slice(0, 1) if i in slice_dirs else slice(None)
              for i in range(len(shape))]
        idx = 0
        for _ in f_range:
            if sl[slice_dirs[idx]].stop == shape[slice_dirs[idx]]:
                idx += 1
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        return sl, len(frames)

    def __convert_patterns(self, data_obj, object_type):
        if object_type == 'synth_proj_data':
            pattern_list = self.parameters['patterns']
        else:
            pattern_list = self.parameters['patterns_tomo']
        for p in pattern_list:
            # each pattern string is e.g. 'SINOGRAM.0c.1s.2c', where 'c'
            # marks a core dimension and 's' a slice dimension
            p_split = p.split('.')
            name = p_split[0]
            dims = p_split[1:]
            core_dims = tuple([int(i[0]) for i in [d.split('c') for d in dims]
                               if len(i) == 2])
            slice_dims = tuple([int(i[0]) for i in
                                [d.split('s') for d in dims] if len(i) == 2])
            data_obj.add_pattern(
                name, core_dims=core_dims, slice_dims=slice_dims)

    def _set_metadata(self, data_obj, n_entries):
        n_angles = len(self.angles)
        data_angles = n_entries
        if data_angles != n_angles:
            raise Exception("The number of angles %s does not match the data "
                            "dimension length %s" % (n_angles, data_angles))
        data_obj.meta_data.set(['rotation_angle'], self.angles)
        data_obj.meta_data.set(['centre_of_rotation'], self.cor)

214 | """ |
||
215 | stats = MetaData() |
||
216 | |||
217 | stats.set(["stats", "min"], [0]*180) |
||
218 | stats.set(["stats", "max"], [5]*180) |
||
219 | stats.set(["stats", "RMSE"], [1, 7]) |
||
220 | |||
221 | data_obj.meta_data.set(["stats", "min", "PROJECTION"], [0] * 150) |
||
222 | data_obj.meta_data.set(["stats", "max", "PROJECTION"], [5] * 140) |
||
223 | data_obj.meta_data.set(["stats", "RMSE", "VOLUME_XZ"], [3] * 120) |
||
224 | """ |
||
    def __parameter_checks(self, data_obj):
        if not self.parameters['proj_data_dims']:
            raise Exception(
                'Please specify the dimensions of the dataset to create.')

    def _get_n_entries(self):
        return self.n_entries

    def _link_nexus_file(self, data_obj, name):
        filename = self.exp.meta_data.get('nxs_filename')
        fsplit = filename.split('/')
        plugin_number = len(self.exp.meta_data.plugin_list.plugin_list)
        if plugin_number == 1:
            fsplit[-1] = 'synthetic_data.nxs'
        else:
            fsplit[-1] = 'synthetic_data_processed.nxs'
        filename = '/'.join(fsplit)
        self.exp.meta_data.set('nxs_filename', filename)

        if name == 'phantom':
            data_obj.exp.meta_data.set(['group_name', 'phantom'], 'phantom')
            data_obj.exp.meta_data.set(['link_type', 'phantom'],
                                       'final_result')
            data_obj.meta_data.set(["meta_data", "PLACEHOLDER", "VOLUME_XZ"],
                                   [10])
        else:
            data_obj.exp.meta_data.set(['group_name', 'synth_proj_data'],
                                       'entry1/tomo_entry/data')
            data_obj.exp.meta_data.set(['link_type', 'synth_proj_data'],
                                       'entry1')

        self._populate_nexus_file(data_obj)
        self._link_datafile_to_nexus_file(data_obj)

    def _populate_nexus_file(self, data):
        filename = self.exp.meta_data.get('nxs_filename')
        name = data.data_info.get('name')
        with h5py.File(filename, 'a') as nxs_file:
            group_name = self.exp.meta_data.get(['group_name', name])
            link_type = self.exp.meta_data.get(['link_type', name])

            if name == 'phantom':
                nxs_entry = nxs_file.create_group('entry')
                if link_type == 'final_result':
                    group_name = 'final_result_' + data.get_name()
                else:
                    link = nxs_entry.require_group(link_type.encode("ascii"))
                    link.attrs['NX_class'] = 'NXcollection'
                    nxs_entry = link

                # delete the group if it already exists
                if group_name in nxs_entry:
                    del nxs_entry[group_name]

                plugin_entry = nxs_entry.require_group(group_name)
            else:
                plugin_entry = nxs_file.create_group(f'/{group_name}')

            self.__output_data_patterns(data, plugin_entry)
            self._output_metadata_dict(plugin_entry,
                                       data.meta_data.get_dictionary())
            self.__output_axis_labels(data, plugin_entry)

            plugin_entry.attrs['NX_class'] = 'NXdata'

    def __output_axis_labels(self, data, entry):
        axis_labels = data.data_info.get("axis_labels")
        ddict = data.meta_data.get_dictionary()

        axes = []
        count = 0
        for labels in axis_labels:
            name = list(labels.keys())[0]
            axes.append(name)
            entry.attrs[name + '_indices'] = count

            mData = ddict[name] if name in list(ddict.keys()) \
                else np.arange(self.parameters['proj_data_dims'][count])
            if isinstance(mData, list):
                mData = np.array(mData)

            if 'U' in str(mData.dtype):
                mData = mData.astype(np.string_)
            if name not in list(entry.keys()):
                axis_entry = entry.require_dataset(name, mData.shape,
                                                   mData.dtype)
                axis_entry[...] = mData[...]
                axis_entry.attrs['units'] = list(labels.values())[0]
            count += 1
        entry.attrs['axes'] = axes

    def __output_data_patterns(self, data, entry):
        data_patterns = data.data_info.get("data_patterns")
        entry = entry.require_group('patterns')
        entry.attrs['NX_class'] = 'NXcollection'
        for pattern in data_patterns:
            nx_data = entry.require_group(pattern)
            nx_data.attrs['NX_class'] = 'NXparameters'
            values = data_patterns[pattern]
            self.__output_data(nx_data, values['core_dims'], 'core_dims')
            self.__output_data(nx_data, values['slice_dims'], 'slice_dims')

    def _output_metadata_dict(self, entry, mData):
        entry.attrs['NX_class'] = 'NXcollection'
        for key, value in mData.items():
            if key != 'rotation_angle':
                nx_data = entry.require_group(key)
                if isinstance(value, dict):
                    self._output_metadata_dict(nx_data, value)
                else:
                    nx_data.attrs['NX_class'] = 'NXdata'
                    self.__output_data(nx_data, value, key)

    def __output_data(self, entry, data, name):
        if isinstance(data, dict):
            entry = entry.require_group(name)
            entry.attrs['NX_class'] = 'NXcollection'
            for key, value in data.items():
                self.__output_data(entry, value, key)
        else:
            try:
                self.__create_dataset(entry, name, data)
            except Exception:
                try:
                    # fall back to a json-encoded string for values that
                    # h5py cannot store directly
                    import json
                    data = np.array([json.dumps(data).encode("ascii")])
                    self.__create_dataset(entry, name, data)
                except Exception:
                    raise Exception('Unable to output %s to file.' % name)

    def __create_dataset(self, entry, name, data):
        if name not in list(entry.keys()):
            entry.create_dataset(name, data=data)
        else:
            entry[name][...] = data

    def _link_datafile_to_nexus_file(self, data):
        filename = self.exp.meta_data.get('nxs_filename')

        with h5py.File(filename, 'a') as nxs_file:
            # entry path in the nexus file
            name = data.get_name()
            group_name = self.exp.meta_data.get(['group_name', name])
            link = self.exp.meta_data.get(['link_type', name])
            name = data.get_name(orig=True)
            nxs_entry = self.__add_nxs_entry(nxs_file, link, group_name, name)
            self.__add_nxs_data(nxs_file, nxs_entry, link, group_name, data)

    def __add_nxs_entry(self, nxs_file, link, group_name, name):
        if name == 'phantom':
            nxs_entry = '/entry/' + link
        else:
            nxs_entry = ''
        nxs_entry += '_' + name if link == 'final_result' else '/' + group_name
        nxs_entry = nxs_file[nxs_entry]
        nxs_entry.attrs['signal'] = 'data'
        return nxs_entry

    def __add_nxs_data(self, nxs_file, nxs_entry, link, group_name, data):
        data_entry = nxs_entry.name + '/data'
        # path of the backing output file
        h5file = data.backing_file.filename

        if link == 'input_data':
            dataset = self.__is_h5dataset(data)
            if dataset:
                nxs_file[data_entry] = \
                    h5py.ExternalLink(os.path.abspath(h5file), dataset.name)
        else:
            # entry path in the output file
            m_data = self.exp.meta_data.get
            if not (link == 'intermediate' and
                    m_data('inter_path') != m_data('out_path')):
                h5file = h5file.split(m_data('out_folder') + '/')[-1]
            nxs_file[data_entry] = \
                h5py.ExternalLink(h5file, group_name + '/data')