| Metric | Value |
| --- | --- |
| Total Complexity | 144 |
| Total Lines | 692 |
| Duplicated Lines | 4.34 % |
| Changes | 0 |
Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places. The usual remedy for common duplication problems is to extract the repeated logic into a single, shared method or class, as sketched below.
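For example, this report flags `__output_data` and `__output_data_patterns` in the listing below with duplication markers. The following is a minimal, purely illustrative sketch of one possible remedy, assuming the repeated write-with-fallback block were pulled into a shared helper; the name `create_dataset_with_fallback` and its signature are invented for the example and are not part of Savu.

```python
# Hypothetical shared helper (not Savu code): the try/except serialisation
# chain from __output_data, written once so any transport can reuse it.
import json

import numpy as np


def create_dataset_with_fallback(entry, name, data, encoder=None):
    """Write `data` under `name` in the h5py group `entry`, trying the raw
    value first, then a JSON-encoded byte string, then an optional
    project-specific encoder."""
    attempts = [lambda d: d,
                lambda d: np.array([json.dumps(d).encode("ascii")])]
    if encoder is not None:
        attempts.append(encoder)
    last_error = None
    for convert in attempts:
        try:
            converted = convert(data)
            if name not in entry:
                entry.create_dataset(name, data=converted)
            else:
                entry[name][...] = converted
            return
        except Exception as exc:  # fall through to the next representation
            last_error = exc
    raise Exception('Unable to output %s to file.' % name) from last_error
```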
Complex classes like savu.core.transports.base_transport often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring, as sketched below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the quicker refactoring.
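As a purely illustrative sketch (not Savu code), the NeXus metadata-output methods of savu.core.transports.base_transport, which share the `_output_`/`__output_` prefix, could be moved onto a dedicated collaborator. The `NexusMetadataWriter` name, the simplified method bodies, and the constructor injection below are all assumptions made for the example.

```python
# Hypothetical Extract Class sketch: metadata output moves to its own class,
# and the transport keeps only a thin delegation point.
import h5py
import numpy as np

NX_CLASS = 'NX_class'


class NexusMetadataWriter(object):
    """Owns the NeXus/HDF5 metadata output currently mixed into the
    transport class."""

    def output_metadata_dict(self, entry, mdata):
        # Mirror a (possibly nested) metadata dictionary as NXcollection
        # groups with one dataset per leaf value.
        entry.attrs[NX_CLASS] = 'NXcollection'
        for key, value in mdata.items():
            if isinstance(value, dict):
                self.output_metadata_dict(entry.require_group(key), value)
            else:
                self._write(entry, key, value)

    def _write(self, entry, name, data):
        # Create the dataset on first write, overwrite in place afterwards.
        if name not in entry:
            entry.create_dataset(name, data=np.asarray(data))
        else:
            entry[name][...] = np.asarray(data)


class TransportSketch(object):
    """Simplified stand-in for BaseTransport: transport logic only, with
    metadata output delegated to the injected writer."""

    def __init__(self, writer=None):
        self.writer = writer or NexusMetadataWriter()

    def populate_nexus_file(self, filename, mdata):
        with h5py.File(filename, 'a') as nxs:
            self.writer.output_metadata_dict(nxs.require_group('entry'),
                                             mdata)
```

Injecting the writer keeps the transport testable without touching real HDF5 files, which is one reason Extract Class tends to pay off here.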
```python
# Copyright 2015 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_transport
   :platform: Unix
   :synopsis: A BaseTransport class which implements functions that control\
       the interaction between the data and plugin layers.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""

import os
import time
import copy
import h5py
import math
import logging
import numpy as np

import savu.core.utils as cu
import savu.plugins.utils as pu
from savu.data.data_structures.data_types.base_type import BaseType
from savu.core.iterate_plugin_group_utils import \
    check_if_end_plugin_in_iterate_group

NX_CLASS = 'NX_class'


class BaseTransport(object):
    """
    Implements functions that control the interaction between the data and
    plugin layers.
    """

    def __init__(self):
        self.pDict = None
        self.no_processing = False

    def _transport_initialise(self, options):
        """
        Any initial setup required by the transport mechanism on start up.\
        This is called before the experiment is initialised.
        """
        raise NotImplementedError("transport_control_setup needs to be "
                                  "implemented in %s" % self.__class__)

    def _transport_update_plugin_list(self):
        """
        This method provides an opportunity to add or remove items from the
        plugin list before plugin list check.
        """

    def _transport_pre_plugin_list_run(self):
        """
        This method is called after all datasets have been created but BEFORE
        the plugin list is processed.
        """

    def _transport_load_plugin(self, exp, plugin_dict):
        """ This method is called before each plugin is loaded """
        return pu.plugin_loader(exp, plugin_dict)

    def _transport_pre_plugin(self):
        """
        This method is called directly BEFORE each plugin is executed, but \
        after the plugin is loaded.
        """

    def _transport_post_plugin(self):
        """
        This method is called directly AFTER each plugin is executed.
        """

    def _transport_post_plugin_list_run(self):
        """
        This method is called AFTER the full plugin list has been processed.
        """

    def _transport_terminate_dataset(self, data):
        """ A dataset that will subsequently be removed by the framework.

        :param Data data: A data object to finalise.
        """

    def process_setup(self, plugin):
        pDict = {}
        pDict['in_data'], pDict['out_data'] = plugin.get_datasets()
        pDict['in_sl'] = self._get_all_slice_lists(pDict['in_data'], 'in')
        pDict['out_sl'] = self._get_all_slice_lists(pDict['out_data'], 'out')
        pDict['nIn'] = list(range(len(pDict['in_data'])))
        pDict['nOut'] = list(range(len(pDict['out_data'])))
        pDict['nProc'] = len(pDict['in_sl']['process'])
        if 'transfer' in list(pDict['in_sl'].keys()):
            pDict['nTrans'] = len(pDict['in_sl']['transfer'][0])
        else:
            pDict['nTrans'] = 1
        pDict['squeeze'] = self._set_functions(pDict['in_data'], 'squeeze')
        pDict['expand'] = self._set_functions(pDict['out_data'], 'expand')

        frames = [f for f in pDict['in_sl']['frames']]
        self._set_global_frame_index(plugin, frames, pDict['nProc'])
        self.pDict = pDict

    def _transport_process(self, plugin):
        """ Organise required data and execute the main plugin processing.

        :param plugin plugin: The current plugin instance.
        """
        logging.info("transport_process initialise")
        pDict, result, nTrans = self._initialise(plugin)
        logging.info("transport_process get_checkpoint_params")
        cp, sProc, sTrans = self.__get_checkpoint_params(plugin)

        prange = list(range(sProc, pDict['nProc']))
        kill = False
        for count in range(sTrans, nTrans):
            end = True if count == nTrans-1 else False
            self._log_completion_status(count, nTrans, plugin.name)

            # get the transfer data
            logging.info("Transferring the data")
            transfer_data = self._transfer_all_data(count)

            if count == nTrans-1 and plugin.fixed_length == False:
                shape = [data.shape for data in transfer_data]
                prange = self.remove_extra_slices(prange, shape)

            # loop over the process data
            logging.info("process frames loop")
            result, kill = self._process_loop(
                plugin, prange, transfer_data, count, pDict, result, cp)

            logging.info("Returning the data")
            self._return_all_data(count, result, end)

            if kill:
                return 1

        if not kill:
            cu.user_message("%s - 100%% complete" % (plugin.name))

    def remove_extra_slices(self, prange, transfer_shape):
        # loop over datasets:
        for i, data in enumerate(self.pDict['in_data']):
            pData = data._get_plugin_data()
            mft = pData.meta_data.get("max_frames_transfer")
            mfp = pData.meta_data.get("max_frames_process")
            sdirs = data.get_slice_dimensions()
            finish = np.prod([transfer_shape[i][j] for j in sdirs])
            rem, full = math.modf((mft - finish)/mfp)
            full = int(full)

            if rem:
                rem = (mft-finish) - full
                self._update_slice_list("in_sl", i, full, sdirs[0], rem)
                for j, out_data in enumerate(self.pDict['out_data']):
                    out_pData = out_data._get_plugin_data()
                    out_mfp = out_pData.meta_data.get("max_frames_process")
                    out_sdir = data.get_slice_dimensions()[0]
                    out_rem = rem/(mfp/out_mfp)
                    if out_rem % 1:
                        raise Exception("'Fixed_length' plugin option is invalid")
                    self._update_slice_list("out_sl", j, full, out_sdir,
                                            int(out_rem))

        return list(range(prange[0], prange[-1]+1-full))

    def _update_slice_list(self, key, idx, remove, dim, amount):
        sl = list(self.pDict[key]['process'][idx][-remove])
        s = sl[dim]
        sl[dim] = slice(s.start, s.stop - amount*s.step, s.step)
        self.pDict[key]['process'][idx][-1] = sl

    def _process_loop(self, plugin, prange, tdata, count, pDict, result, cp):
        kill_signal = False
        for i in prange:
            if cp and cp.is_time_to_checkpoint(self, count, i):
                # kill signal sent so stop the processing
                return result, True
            data = self._get_input_data(plugin, tdata, i, count)
            res = self._get_output_data(
                plugin.plugin_process_frames(data), i)

            for j in pDict['nOut']:
                if res is not None:
                    out_sl = pDict['out_sl']['process'][i][j]
                    if any("res_norm" in s for s in self.data_flow):
                        # an exception for metadata created automatically by a
                        # parameter in the plugin; this fixes CGLS_CUDA with
                        # its res_norm metadata
                        result[j][out_sl] = res[0][j, ]
                    else:
                        result[j][out_sl] = res[j]
                else:
                    result[j] = None
        return result, kill_signal

    def __get_checkpoint_params(self, plugin):
        cp = self.exp.checkpoint
        if cp:
            cp._initialise(plugin.get_communicator())
            return cp, cp.get_proc_idx(), cp.get_trans_idx()
        return None, 0, 0

    def _initialise(self, plugin):
        self.process_setup(plugin)
        pDict = self.pDict
        result = [np.empty(d._get_plugin_data().get_shape_transfer(),
                           dtype=np.float32) for d in pDict['out_data']]
        # loop over the transfer data
        nTrans = pDict['nTrans']
        self.no_processing = True if not nTrans else False
        return pDict, result, nTrans

    def _log_completion_status(self, count, nTrans, name):
        percent_complete: float = count / (nTrans * 0.01)
        cu.user_message("%s - %3i%% complete" % (name, percent_complete))

    def _transport_checkpoint(self):
        """ The framework has determined it is time to checkpoint. What
        should the transport mechanism do? Override if appropriate. """
        return False

    def _transport_kill_signal(self):
        """
        An opportunity to send a kill signal to the framework. Return
        True or False. """
        return False

    def _get_all_slice_lists(self, data_list, dtype):
        """
        Get all slice lists for the current process.

        :param list(Data) data_list: Datasets
        :returns: A list of dictionaries containing slice lists for each \
            dataset
        :rtype: list(dict)
        """
        sl_dict = {}
        for data in data_list:
            sl = data._get_transport_data().\
                _get_slice_lists_per_process(dtype)
            for key, value in sl.items():
                if key not in sl_dict:
                    sl_dict[key] = [value]
                else:
                    sl_dict[key].append(value)

        for key in [k for k in ['process', 'unpad'] if k in list(sl_dict.keys())]:
            nData = list(range(len(sl_dict[key])))
            #rep = range(len(sl_dict[key][0]))
            sl_dict[key] = [[sl_dict[key][i][j] for i in nData
                             if j < len(sl_dict[key][i])]
                            for j in range(len(sl_dict[key][0]))]
        return sl_dict

    def _transfer_all_data(self, count):
        """
        Transfer data from file and pad if required.

        :param int count: The current frame index.
        :returns: All data for this frame and associated padded slice lists
        :rtype: list(np.ndarray), list(tuple(slice))
        """
        pDict = self.pDict
        data_list = pDict['in_data']

        if 'transfer' in list(pDict['in_sl'].keys()):
            slice_list = \
                [pDict['in_sl']['transfer'][i][count] for i in pDict['nIn']]
        else:
            slice_list = [slice(None)]*len(pDict['nIn'])

        section = []
        for i, item in enumerate(data_list):
            section.append(data_list[i]._get_transport_data().
                           _get_padded_data(slice_list[i]))
        return section

    def _get_input_data(self, plugin, trans_data, nproc, ntrans):
        data = []
        current_sl = []
        for d in self.pDict['nIn']:
            in_sl = self.pDict['in_sl']['process'][nproc][d]
            data.append(self.pDict['squeeze'][d](trans_data[d][in_sl]))
            entry = ntrans*self.pDict['nProc'] + nproc
            if entry < len(self.pDict['in_sl']['current'][d]):
                current_sl.append(self.pDict['in_sl']['current'][d][entry])
            else:
                current_sl.append(self.pDict['in_sl']['current'][d][-1])
        plugin.set_current_slice_list(current_sl)
        return data

    def _get_output_data(self, result, count):
        if result is None:
            return
        unpad_sl = self.pDict['out_sl']['unpad'][count]
        result = result if isinstance(result, list) else [result]
        for j in self.pDict['nOut']:
            if any("res_norm" in s for s in self.data_flow):
                # an exception for metadata created automatically by a
                # parameter in the plugin; this fixes CGLS_CUDA with its
                # res_norm metadata
                result[0][j, ] = self.pDict['expand'][j](result[0][j, ])[unpad_sl[j]]
            else:
                result[j] = self.pDict['expand'][j](result[j])[unpad_sl[j]]
        return result

    def _return_all_data(self, count, result, end):
        """
        Transfer plugin results for current frame to backing files.

        :param int count: The current frame index.
        :param list(np.ndarray) result: plugin results
        :param bool end: True if this is the last entry in the slice list.
        """
        pDict = self.pDict
        data_list = pDict['out_data']

        slice_list = None
        if 'transfer' in list(pDict['out_sl'].keys()):
            slice_list = \
                [pDict['out_sl']['transfer'][i][count] for i in pDict['nOut'] \
                 if len(pDict['out_sl']['transfer'][i]) > count]

        result = [result] if type(result) is not list else result

        for i, item in enumerate(data_list):
            if result[i] is not None:
                if slice_list:
                    temp = self._remove_excess_data(
                        data_list[i], result[i], slice_list[i])
                    data_list[i].data[slice_list[i]] = temp
                else:
                    data_list[i].data = result[i]

    def _set_global_frame_index(self, plugin, frame_list, nProc):
        """ Convert the transfer global frame index to a process global frame
        index.
        """
        process_frames = []
        for f in frame_list:
            if len(f):
                process_frames.append(list(range(f[0]*nProc, (f[-1]+1)*nProc)))

        process_frames = np.array(process_frames)
        nframes = plugin.get_plugin_in_datasets()[0].get_total_frames()
        process_frames[process_frames >= nframes] = nframes - 1
        frames = process_frames[0] if process_frames.size else process_frames
        plugin.set_global_frame_index(frames)

    def _set_functions(self, data_list, name):
        """ Create a dictionary of functions to remove (squeeze) or re-add
        (expand) dimensions, of length 1, from each dataset in a list.

        :param list(Data) data_list: Datasets
        :param str name: 'squeeze' or 'expand'
        :returns: A dictionary of lambda functions
        :rtype: dict
        """
        str_name = 'self.' + name + '_output'
        function = {'expand': self.__create_expand_function,
                    'squeeze': self.__create_squeeze_function}
        ddict = {}
        for i, item in enumerate(data_list):
            ddict[i] = {i: str_name + str(i)}
            ddict[i] = function[name](data_list[i])
        return ddict

    def __create_expand_function(self, data):
        """ Create a function that re-adds missing dimensions of length 1.

        :param Data data: Dataset
        :returns: expansion function
        :rtype: lambda
        """
        slice_dirs = data.get_slice_dimensions()
        n_core_dirs = len(data.get_core_dimensions())
        new_slice = [slice(None)]*len(data.get_shape())
        possible_slices = [copy.copy(new_slice)]

        pData = data._get_plugin_data()
        if pData._get_rank_inc():
            possible_slices[0] += [0]*pData._get_rank_inc()

        if len(slice_dirs) > 1:
            for sl in slice_dirs[1:]:
                new_slice[sl] = None
            possible_slices.append(copy.copy(new_slice))
        new_slice[slice_dirs[0]] = None
        possible_slices.append(copy.copy(new_slice))
        possible_slices = possible_slices[::-1]
        return lambda x: x[tuple(possible_slices[len(x.shape)-n_core_dirs])]

    def __create_squeeze_function(self, data):
        """ Create a function that removes dimensions of length 1.

        :param Data data: Dataset
        :returns: squeeze function
        :rtype: lambda
        """
        pData = data._get_plugin_data()
        max_frames = pData._get_max_frames_process()

        pad = True if pData.padding and data.get_slice_dimensions()[0] in \
            list(pData.padding._get_padding_directions().keys()) else False

        n_core_dims = len(data.get_core_dimensions())
        squeeze_dims = data.get_slice_dimensions()
        if max_frames > 1 or pData._get_no_squeeze() or pad:
            squeeze_dims = squeeze_dims[1:]
            n_core_dims += 1
        if pData._get_rank_inc():
            sl = [(slice(None))]*n_core_dims + [None]*pData._get_rank_inc()
            return lambda x: np.squeeze(x[tuple(sl)], axis=squeeze_dims)
        return lambda x: np.squeeze(x, axis=squeeze_dims)

    def _remove_excess_data(self, data, result, slice_list):
        """ Remove any excess results due to padding for fixed length process \
        frames. """

        mData = data._get_plugin_data().meta_data.get_dictionary()
        temp = np.where(np.array(mData['size_list']) > 1)[0]
        sdir = mData['sdir'][temp[-1] if temp.size else 0]

        # Not currently working for basic_transport
        if isinstance(slice_list, slice):
            return

        sl = slice_list[sdir]
        shape = result.shape

        if shape[sdir] - (sl.stop - sl.start):
            unpad_sl = [slice(None)]*len(shape)
            unpad_sl[sdir] = slice(0, sl.stop - sl.start)
            result = result[tuple(unpad_sl)]
        return result

    def _setup_h5_files(self):
        out_data_dict = self.exp.index["out_data"]

        current_and_next = False
        if 'current_and_next' in self.exp.meta_data.get_dictionary():
            current_and_next = self.exp.meta_data.get('current_and_next')

        count = 0
        for key in out_data_dict.keys():
            out_data = out_data_dict[key]
            filename = self.exp.meta_data.get(["filename", key])
            out_data.backing_file = self.hdf5._open_backing_h5(filename, 'a')
            c_and_n = 0 if not current_and_next else current_and_next[key]
            out_data.group_name, out_data.group = self.hdf5._create_entries(
                out_data, key, c_and_n)
            count += 1

    def _set_file_details(self, files):
        self.exp.meta_data.set('link_type', files['link_type'])
        self.exp.meta_data.set('link_type', {})
        self.exp.meta_data.set('filename', {})
        self.exp.meta_data.set('group_name', {})
        for key in list(self.exp.index['out_data'].keys()):
            self.exp.meta_data.set(['link_type', key], files['link_type'][key])
            self.exp.meta_data.set(['filename', key], files['filename'][key])
            self.exp.meta_data.set(['group_name', key],
                                   files['group_name'][key])

    def _get_filenames(self, plugin_dict):
        count = self.exp.meta_data.get('nPlugin') + 1
        files = {"filename": {}, "group_name": {}, "link_type": {}}
        for key in list(self.exp.index["out_data"].keys()):
            name = key + '_p' + str(count) + '_' + \
                plugin_dict['id'].split('.')[-1] + '.h5'
            link_type = self._get_link_type(key)
            files['link_type'][key] = link_type
            if link_type == 'final_result':
                out_path = self.exp.meta_data.get('out_path')
            else:
                out_path = self.exp.meta_data.get('inter_path')

            filename = os.path.join(out_path, name)
            group_name = "%i-%s-%s" % (count, plugin_dict['name'], key)
            files["filename"][key] = filename
            files["group_name"][key] = group_name

        return files

    def _get_link_type(self, name):
        idx = self.exp.meta_data.get('nPlugin')
        temp = [e for entry in self.data_flow[idx+1:] for e in entry]
        if name in temp or self.exp.index['out_data'][name].remove:
            return 'intermediate'
        return 'final_result'

    def _populate_nexus_file(self, data, iterate_group=None):
        filename = self.exp.meta_data.get('nxs_filename')

        with h5py.File(filename, 'a') as nxs_file:
            nxs_entry = nxs_file['entry']
            name = data.data_info.get('name')
            group_name = self.exp.meta_data.get(['group_name', name])
            link_type = self.exp.meta_data.get(['link_type', name])

            if link_type == 'final_result':
                if iterate_group is not None and \
                        check_if_end_plugin_in_iterate_group(self.exp):
                    is_clone_data = 'clone' in name
                    is_even_iterations = \
                        iterate_group._ip_fixed_iterations % 2 == 0
                    # don't need to create group for:
                    # - clone dataset, if running an odd number of iterations
                    # - original dataset, if running an even number of
                    #   iterations
                    if is_clone_data and not is_even_iterations:
                        return
                    elif not is_clone_data and is_even_iterations:
                        return
                # the group name for the output of the iterative loop should be
                # named after the original dataset, regardless of if the link
                # eventually points to the original or the clone, for the sake
                # of the linkname referencing the dataset name set in
                # savu_config
                group_name = 'final_result_' + data.get_name(orig=True)
            else:
                link = nxs_entry.require_group(link_type.encode("ascii"))
                link.attrs[NX_CLASS] = 'NXcollection'
                nxs_entry = link

            # delete the group if it already exists
            if group_name in nxs_entry:
                del nxs_entry[group_name]

            plugin_entry = nxs_entry.require_group(group_name)
            plugin_entry.attrs[NX_CLASS] = 'NXdata'
            if iterate_group is not None and \
                    check_if_end_plugin_in_iterate_group(self.exp):
                # always write the metadata under the name of the original
                # dataset, not the clone dataset
                self._output_metadata(data, plugin_entry,
                                      data.get_name(orig=False))
            else:
                self._output_metadata(data, plugin_entry, name)

    def _populate_pre_run_nexus_file(self, data):
        filename = self.exp.meta_data.get('nxs_filename')

        data_path = self.exp.meta_data["data_path"]
        image_key_path = self.exp.meta_data["image_key_path"]
        name = data.data_info.get('name')
        group_name = self.exp.meta_data.get(['group_name', name])
        with h5py.File(filename, 'a') as nxs_file:
            if data_path in nxs_file:
                del nxs_file[data_path]
            nxs_file[data_path] = h5py.ExternalLink(
                os.path.abspath(data.backing_file.filename),
                f"{group_name}/data")

            if image_key_path in nxs_file:
                nxs_file[image_key_path][::] = data.data.image_key[::]
                #nxs_file[data_path].attrs.create("pre_run", True)

    def _output_metadata(self, data, entry, name, dump=False):
        self.__output_data_type(entry, data, name)
        mDict = data.meta_data.get_dictionary()
        self._output_metadata_dict(entry.require_group('meta_data'), mDict)

        if not dump:
            self.__output_axis_labels(data, entry)
            self.__output_data_patterns(data, entry)
            if self.exp.meta_data.get('link_type')[name] == 'input_data':
                # output the filename
                entry['file_path'] = \
                    os.path.abspath(self.exp.meta_data.get('data_file'))

    def __output_data_type(self, entry, data, name):
        data = data.data if 'data' in list(data.__dict__.keys()) else data
        if isinstance(data, h5py.Dataset):
            return

        entry = entry.require_group('data_type')
        entry.attrs[NX_CLASS] = 'NXcollection'

        ltype = self.exp.meta_data.get('link_type')
        if name in list(ltype.keys()) and ltype[name] == 'input_data':
            self.__output_data(entry, data.__class__.__name__, 'cls')
            return

        args, kwargs, cls, extras = data._get_parameters(data.get_clone_args())

        for key, value in kwargs.items():
            gp = entry.require_group('kwargs')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for key, value in extras.items():
            gp = entry.require_group('extras')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for i, item in enumerate(args):
            gp = entry.require_group('args')
            self.__output_data(gp, args[i], ''.join(['args', str(i)]))

        self.__output_data(entry, cls, 'cls')

        if 'data' in list(data.__dict__.keys()) and not \
                isinstance(data.data, h5py.Dataset):
            gp = entry.require_group('data')
            self.__output_data_type(gp, data.data, 'data')

    def __output_data(self, entry, data, name):
        if isinstance(data, dict):
            entry = entry.require_group(name)
            entry.attrs[NX_CLASS] = 'NXcollection'
            for key, value in data.items():
                self.__output_data(entry, value, key)
        else:
            try:
                self.__create_dataset(entry, name, data)
            except Exception:
                try:
                    import json
                    data = np.array([json.dumps(data).encode("ascii")])
                    self.__create_dataset(entry, name, data)
                except Exception:
                    try:
                        data = cu._savu_encoder(data)
                        self.__create_dataset(entry, name, data)
                    except:
                        raise Exception('Unable to output %s to file.' % name)

    def __create_dataset(self, entry, name, data):
        if name not in list(entry.keys()):
            entry.create_dataset(name, data=data)
        else:
            entry[name][...] = data

    def __output_axis_labels(self, data, entry):
        axis_labels = data.data_info.get("axis_labels")
        ddict = data.meta_data.get_dictionary()

        axes = []
        count = 0
        for labels in axis_labels:
            name = list(labels.keys())[0]
            axes.append(name)
            entry.attrs[name + '_indices'] = count

            mData = ddict[name] if name in list(ddict.keys()) \
                else np.arange(data.get_shape()[count])
            if isinstance(mData, list):
                mData = np.array(mData)

            if 'U' in str(mData.dtype):
                mData = mData.astype(np.string_)

            axis_entry = entry.require_dataset(name, mData.shape, mData.dtype)
            axis_entry[...] = mData[...]
            axis_entry.attrs['units'] = list(labels.values())[0]
            count += 1
        entry.attrs['axes'] = axes

    def __output_data_patterns(self, data, entry):
        data_patterns = data.data_info.get("data_patterns")
        entry = entry.require_group('patterns')
        entry.attrs[NX_CLASS] = 'NXcollection'
        for pattern in data_patterns:
            nx_data = entry.require_group(pattern)
            nx_data.attrs[NX_CLASS] = 'NXparameters'
            values = data_patterns[pattern]
            self.__output_data(nx_data, values['core_dims'], 'core_dims')
            self.__output_data(nx_data, values['slice_dims'], 'slice_dims')

    def _output_metadata_dict(self, entry, mData):
        entry.attrs[NX_CLASS] = 'NXcollection'
        for key, value in mData.items():
            nx_data = entry.require_group(key)
            if isinstance(value, dict):
                self._output_metadata_dict(nx_data, value)
            else:
                nx_data.attrs[NX_CLASS] = 'NXdata'
                self.__output_data(nx_data, value, key)
```