# Copyright 2015 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_transport
   :platform: Unix
   :synopsis: A BaseTransport class which implements functions that control\
   the interaction between the data and plugin layers.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""

import os
import time
import copy
import h5py
import logging
import numpy as np

import savu.core.utils as cu
import savu.plugins.utils as pu
from savu.data.data_structures.data_types.base_type import BaseType

NX_CLASS = 'NX_class'


class BaseTransport(object):
    """
    Implements functions that control the interaction between the data and
    plugin layers.
    """

    def __init__(self):
        self.pDict = None
        self.no_processing = False

    def _transport_initialise(self, options):
        """
        Any initial setup required by the transport mechanism on start up.\
        This is called before the experiment is initialised.
        """
        raise NotImplementedError("_transport_initialise needs to be "
                                  "implemented in %s" % self.__class__)

    def _transport_update_plugin_list(self):
        """
        This method provides an opportunity to add or remove items from the
        plugin list before plugin list check.
        """
        pass

    def _transport_pre_plugin_list_run(self):
        """
        This method is called after all datasets have been created but BEFORE
        the plugin list is processed.
        """
        pass

    def _transport_load_plugin(self, exp, plugin_dict):
        """ This method is called before each plugin is loaded """
        return pu.plugin_loader(exp, plugin_dict)

    def _transport_pre_plugin(self):
        """
        This method is called directly BEFORE each plugin is executed, but \
        after the plugin is loaded.
        """
        pass

    def _transport_post_plugin(self):
        """
        This method is called directly AFTER each plugin is executed.
        """
        pass

    def _transport_post_plugin_list_run(self):
        """
        This method is called AFTER the full plugin list has been processed.
        """
        pass

    def _transport_terminate_dataset(self, data):
        """ A dataset that will subsequently be removed by the framework.

        :param Data data: A data object to finalise.
        """
        pass

    def process_setup(self, plugin):
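        # pDict gathers everything the processing loops below need: the in/out
        # datasets, their transfer- and process-level slice lists, dataset
        # index ranges, loop counts (nProc process steps per transfer chunk,
        # nTrans transfer chunks) and the squeeze/expand helper functions.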
        pDict = {}
        pDict['in_data'], pDict['out_data'] = plugin.get_datasets()
        pDict['in_sl'] = self._get_all_slice_lists(pDict['in_data'], 'in')
        pDict['out_sl'] = self._get_all_slice_lists(pDict['out_data'], 'out')
        pDict['nIn'] = list(range(len(pDict['in_data'])))
        pDict['nOut'] = list(range(len(pDict['out_data'])))
        pDict['nProc'] = len(pDict['in_sl']['process'])
        if 'transfer' in list(pDict['in_sl'].keys()):
            pDict['nTrans'] = len(pDict['in_sl']['transfer'][0])
        else:
            pDict['nTrans'] = 1
        pDict['squeeze'] = self._set_functions(pDict['in_data'], 'squeeze')
        pDict['expand'] = self._set_functions(pDict['out_data'], 'expand')

        frames = [f for f in pDict['in_sl']['frames']]
        self._set_global_frame_index(plugin, frames, pDict['nProc'])
        self.pDict = pDict

    def _transport_process(self, plugin):
        """ Organise required data and execute the main plugin processing.

        :param plugin plugin: The current plugin instance.
        """
        logging.info("transport_process initialise")
        pDict, result, nTrans = self._initialise(plugin)
        logging.info("transport_process get_checkpoint_params")
        cp, sProc, sTrans = self.__get_checkpoint_params(plugin)
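        # sProc and sTrans are the process/transfer indices at which the loops
        # below start; both are zero unless resuming from a checkpoint.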

        count = 0  # temporary solution
        prange = list(range(sProc, pDict['nProc']))
        kill = False
        for count in range(sTrans, nTrans):
            end = True if count == nTrans-1 else False
            self._log_completion_status(count, nTrans, plugin.name)

            # get the transfer data
            logging.info("Transferring the data")
            transfer_data = self._transfer_all_data(count)

            # loop over the process data
            logging.info("process frames loop")
            result, kill = self._process_loop(
                plugin, prange, transfer_data, count, pDict, result, cp)

            logging.info("Returning the data")
            self._return_all_data(count, result, end)

            if kill:
                return 1

        if not kill:
            cu.user_message("%s - 100%% complete" % (plugin.name))

    def _process_loop(self, plugin, prange, tdata, count, pDict, result, cp):
        kill_signal = False
        for i in prange:
            if cp and cp.is_time_to_checkpoint(self, count, i):
                # kill signal sent so stop the processing
                return result, True
            data = self._get_input_data(plugin, tdata, i, count)
            res = self._get_output_data(
                plugin.plugin_process_frames(data), i)

            for j in pDict['nOut']:
                if res is not None:
                    out_sl = pDict['out_sl']['process'][i][j]
                    result[j][out_sl] = res[j]
                else:
                    result[j] = None
        return result, kill_signal

    def __get_checkpoint_params(self, plugin):
        cp = self.exp.checkpoint
        if cp:
            cp._initialise(plugin.get_communicator())
            return cp, cp.get_proc_idx(), cp.get_trans_idx()
        return None, 0, 0

    def _initialise(self, plugin):
        self.process_setup(plugin)
        pDict = self.pDict
        result = [np.empty(d._get_plugin_data().get_shape_transfer(),
                           dtype=np.float32) for d in pDict['out_data']]
        # loop over the transfer data
        nTrans = pDict['nTrans']
        self.no_processing = True if not nTrans else False
        return pDict, result, nTrans

    def _log_completion_status(self, count, nTrans, name):
        percent_complete: float = count / (nTrans * 0.01)
        cu.user_message("%s - %3i%% complete" % (name, percent_complete))

    def _transport_checkpoint(self):
        """ The framework has determined it is time to checkpoint. What
        should the transport mechanism do? Override if appropriate. """
        return False

    def _transport_kill_signal(self):
        """
        An opportunity to send a kill signal to the framework. Return
        True or False. """
        return False

    def _get_all_slice_lists(self, data_list, dtype):
        """
        Get all slice lists for the current process.

        :param list(Data) data_list: Datasets
        :returns: A list of dictionaries containing slice lists for each \
            dataset
        :rtype: list(dict)
        """
        sl_dict = {}
        for data in data_list:
            sl = data._get_transport_data()._get_slice_lists_per_process(dtype)
            for key, value in sl.items():
                if key not in sl_dict:
                    sl_dict[key] = [value]
                else:
                    sl_dict[key].append(value)

        for key in [k for k in ['process', 'unpad'] if k in list(sl_dict.keys())]:
            nData = list(range(len(sl_dict[key])))
            #rep = range(len(sl_dict[key][0]))
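            # Re-group from one slice list per dataset to one entry per frame,
            # so element j holds the slices for frame j across all datasets.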
            sl_dict[key] = [[sl_dict[key][i][j] for i in nData
                             if j < len(sl_dict[key][i])]
                            for j in range(len(sl_dict[key][0]))]
        return sl_dict

    def _transfer_all_data(self, count):
        """
        Transfer data from file and pad if required.

        :param int count: The current frame index.
        :returns: All data for this frame and associated padded slice lists
        :rtype: list(np.ndarray), list(tuple(slice))
        """
        pDict = self.pDict
        data_list = pDict['in_data']

        if 'transfer' in list(pDict['in_sl'].keys()):
            slice_list = \
                [pDict['in_sl']['transfer'][i][count] for i in pDict['nIn']]
        else:
            slice_list = [slice(None)]*len(pDict['nIn'])

        section = []
        for idx in range(len(data_list)):
            section.append(data_list[idx]._get_transport_data().
                           _get_padded_data(slice_list[idx]))
        return section

    def _get_input_data(self, plugin, trans_data, nproc, ntrans):
        data = []
        current_sl = []
        for d in self.pDict['nIn']:
            in_sl = self.pDict['in_sl']['process'][nproc][d]
            data.append(self.pDict['squeeze'][d](trans_data[d][in_sl]))
            entry = ntrans*self.pDict['nProc'] + nproc
            if entry < len(self.pDict['in_sl']['current'][d]):
                current_sl.append(self.pDict['in_sl']['current'][d][entry])
            else:
                current_sl.append(self.pDict['in_sl']['current'][d][-1])
        plugin.set_current_slice_list(current_sl)
        return data

    def _get_output_data(self, result, count):
        if result is None:
            return
        unpad_sl = self.pDict['out_sl']['unpad'][count]
        result = result if isinstance(result, list) else [result]
        for j in self.pDict['nOut']:
            result[j] = self.pDict['expand'][j](result[j])[unpad_sl[j]]
        return result

    def _return_all_data(self, count, result, end):
        """
        Transfer plugin results for current frame to backing files.

        :param int count: The current frame index.
        :param list(np.ndarray) result: plugin results
        :param bool end: True if this is the last entry in the slice list.
        """
        pDict = self.pDict
        data_list = pDict['out_data']

        slice_list = None
        if 'transfer' in list(pDict['out_sl'].keys()):
            slice_list = \
                [pDict['out_sl']['transfer'][i][count] for i in pDict['nOut']
                 if len(pDict['out_sl']['transfer'][i]) > count]

        result = [result] if type(result) is not list else result

        for idx in range(len(data_list)):
            if result[idx] is not None:
                if slice_list:
                    temp = self._remove_excess_data(
                        data_list[idx], result[idx], slice_list[idx])
                    data_list[idx].data[slice_list[idx]] = temp
                else:
                    data_list[idx].data = result[idx]

    def _set_global_frame_index(self, plugin, frame_list, nProc):
        """ Convert the transfer global frame index to a process global frame
        index.
        """
        process_frames = []
        for f in frame_list:
            if len(f):
                process_frames.append(list(range(f[0]*nProc, (f[-1]+1)*nProc)))

        process_frames = np.array(process_frames)
        nframes = plugin.get_plugin_in_datasets()[0].get_total_frames()
        process_frames[process_frames >= nframes] = nframes - 1
        frames = process_frames[0] if process_frames.size else process_frames
        plugin.set_global_frame_index(frames)

    def _set_functions(self, data_list, name):
        """ Create a dictionary of functions to remove (squeeze) or re-add
        (expand) dimensions, of length 1, from each dataset in a list.

        :param list(Data) data_list: Datasets
        :param str name: 'squeeze' or 'expand'
        :returns: A dictionary of lambda functions
        :rtype: dict
        """
        str_name = 'self.' + name + '_output'
        function = {'expand': self.__create_expand_function,
                    'squeeze': self.__create_squeeze_function}
        ddict = {}
        for i in range(len(data_list)):
            ddict[i] = {i: str_name + str(i)}
            ddict[i] = function[name](data_list[i])
        return ddict

    def __create_expand_function(self, data):
        """ Create a function that re-adds missing dimensions of length 1.

        :param Data data: Dataset
        :returns: expansion function
        :rtype: lambda
        """
        slice_dirs = data.get_slice_dimensions()
        n_core_dirs = len(data.get_core_dimensions())
        new_slice = [slice(None)]*len(data.get_shape())
        possible_slices = [copy.copy(new_slice)]

        pData = data._get_plugin_data()
        if pData._get_rank_inc():
            possible_slices[0] += [0]*pData._get_rank_inc()

        if len(slice_dirs) > 1:
            for sl in slice_dirs[1:]:
                new_slice[sl] = None
        possible_slices.append(copy.copy(new_slice))
        new_slice[slice_dirs[0]] = None
        possible_slices.append(copy.copy(new_slice))
        possible_slices = possible_slices[::-1]
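        # The lambda picks a slice according to how many slice dimensions
        # remain in x; None entries act as np.newaxis and restore the
        # dimensions that were squeezed out.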
        return lambda x: x[tuple(possible_slices[len(x.shape)-n_core_dirs])]

    def __create_squeeze_function(self, data):
        """ Create a function that removes dimensions of length 1.

        :param Data data: Dataset
        :returns: squeeze function
        :rtype: lambda
        """
        pData = data._get_plugin_data()
        max_frames = pData._get_max_frames_process()

        pad = True if pData.padding and data.get_slice_dimensions()[0] in \
            list(pData.padding._get_padding_directions().keys()) else False

        n_core_dims = len(data.get_core_dimensions())
        squeeze_dims = data.get_slice_dimensions()
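        # Keep the leading slice dimension if multiple frames are processed
        # per call, squeezing is disabled, or that dimension is padded.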
        if max_frames > 1 or pData._get_no_squeeze() or pad:
            squeeze_dims = squeeze_dims[1:]
            n_core_dims += 1
        if pData._get_rank_inc():
            sl = [(slice(None))]*n_core_dims + [None]*pData._get_rank_inc()
            return lambda x: np.squeeze(x[tuple(sl)], axis=squeeze_dims)

        return lambda x: np.squeeze(x, axis=squeeze_dims)

    def _remove_excess_data(self, data, result, slice_list):
        """ Remove any excess results due to padding for fixed length process \
        frames. """

        mData = data._get_plugin_data().meta_data.get_dictionary()
        temp = np.where(np.array(mData['size_list']) > 1)[0]
        sdir = mData['sdir'][temp[-1] if temp.size else 0]

        # Not currently working for basic_transport
        if isinstance(slice_list, slice):
            return

        sl = slice_list[sdir]
        shape = result.shape

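        # If the result is larger along sdir than the slice it has to fill
        # (the final chunk was padded to a fixed frame count), trim the excess
        # frames from the end.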
        if shape[sdir] - (sl.stop - sl.start):
            unpad_sl = [slice(None)]*len(shape)
            unpad_sl[sdir] = slice(0, sl.stop - sl.start)
            result = result[tuple(unpad_sl)]
        return result

    def _setup_h5_files(self):
        out_data_dict = self.exp.index["out_data"]

        current_and_next = False
        if 'current_and_next' in self.exp.meta_data.get_dictionary():
            current_and_next = self.exp.meta_data.get('current_and_next')

        count = 0
        for key in out_data_dict.keys():
            out_data = out_data_dict[key]
            filename = self.exp.meta_data.get(["filename", key])
            out_data.backing_file = self.hdf5._open_backing_h5(filename, 'a')
            c_and_n = 0 if not current_and_next else current_and_next[key]
            out_data.group_name, out_data.group = self.hdf5._create_entries(
                out_data, key, c_and_n)
            count += 1

    def _set_file_details(self, files):
        self.exp.meta_data.set('link_type', files['link_type'])
        self.exp.meta_data.set('link_type', {})
        self.exp.meta_data.set('filename', {})
        self.exp.meta_data.set('group_name', {})
        for key in list(self.exp.index['out_data'].keys()):
            self.exp.meta_data.set(['link_type', key], files['link_type'][key])
            self.exp.meta_data.set(['filename', key], files['filename'][key])
            self.exp.meta_data.set(['group_name', key],
                                   files['group_name'][key])

    def _get_filenames(self, plugin_dict):
        count = self.exp.meta_data.get('nPlugin') + 1
        files = {"filename": {}, "group_name": {}, "link_type": {}}
        for key in list(self.exp.index["out_data"].keys()):
            name = key + '_p' + str(count) + '_' + \
                plugin_dict['id'].split('.')[-1] + '.h5'
            link_type = self._get_link_type(key)
            files['link_type'][key] = link_type
            if link_type == 'final_result':
                out_path = self.exp.meta_data.get('out_path')
            else:
                out_path = self.exp.meta_data.get('inter_path')

            filename = os.path.join(out_path, name)
            group_name = "%i-%s-%s" % (count, plugin_dict['name'], key)
            files["filename"][key] = filename
            files["group_name"][key] = group_name

        return files

    def _get_link_type(self, name):
        idx = self.exp.meta_data.get('nPlugin')
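        # Datasets consumed by a later plugin in the data flow (or flagged for
        # removal) are written to intermediate files; anything else is a final
        # result.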
        temp = [e for entry in self.data_flow[idx+1:] for e in entry]
        if name in temp or self.exp.index['out_data'][name].remove:
            return 'intermediate'
        return 'final_result'

    def _populate_nexus_file(self, data):
        filename = self.exp.meta_data.get('nxs_filename')

        with h5py.File(filename, 'a') as nxs_file:
            nxs_entry = nxs_file['entry']
            name = data.data_info.get('name')
            group_name = self.exp.meta_data.get(['group_name', name])
            link_type = self.exp.meta_data.get(['link_type', name])

            if link_type == 'final_result':
                group_name = 'final_result_' + data.get_name()
            else:
                link = nxs_entry.require_group(link_type.encode("ascii"))
                link.attrs[NX_CLASS] = 'NXcollection'
                nxs_entry = link

            # delete the group if it already exists
            if group_name in nxs_entry:
                del nxs_entry[group_name]

            plugin_entry = nxs_entry.require_group(group_name)
            plugin_entry.attrs[NX_CLASS] = 'NXdata'
            self._output_metadata(data, plugin_entry, name)

    def _output_metadata(self, data, entry, name, dump=False):
        self.__output_data_type(entry, data, name)
        mDict = data.meta_data.get_dictionary()
        self._output_metadata_dict(entry.require_group('meta_data'), mDict)

        if not dump:
            self.__output_axis_labels(data, entry)
            self.__output_data_patterns(data, entry)
            if self.exp.meta_data.get('link_type')[name] == 'input_data':
                # output the filename
                entry['file_path'] = \
                    os.path.abspath(self.exp.meta_data.get('data_file'))

    def __output_data_type(self, entry, data, name):
        data = data.data if 'data' in list(data.__dict__.keys()) else data
        if isinstance(data, h5py.Dataset):
            return

        entry = entry.require_group('data_type')
        entry.attrs[NX_CLASS] = 'NXcollection'

        ltype = self.exp.meta_data.get('link_type')
        if name in list(ltype.keys()) and ltype[name] == 'input_data':
            self.__output_data(entry, data.__class__.__name__, 'cls')
            return

        args, kwargs, cls, extras = data._get_parameters(data.get_clone_args())

        for key, value in kwargs.items():
            gp = entry.require_group('kwargs')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for key, value in extras.items():
            gp = entry.require_group('extras')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for i in range(len(args)):
            gp = entry.require_group('args')
            self.__output_data(gp, args[i], ''.join(['args', str(i)]))

        self.__output_data(entry, cls, 'cls')

        if 'data' in list(data.__dict__.keys()) and not \
                isinstance(data.data, h5py.Dataset):
            gp = entry.require_group('data')
            self.__output_data_type(gp, data.data, 'data')

    def __output_data(self, entry, data, name):
        if isinstance(data, dict):
            entry = entry.require_group(name)
            entry.attrs[NX_CLASS] = 'NXcollection'
            for key, value in data.items():
                self.__output_data(entry, value, key)
        else:
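            # Try to write the value directly; fall back to a JSON string,
            # then to the savu encoder, for types h5py cannot store natively.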
            try:
                self.__create_dataset(entry, name, data)
            except Exception:
                try:
                    import json
                    data = np.array([json.dumps(data).encode("ascii")])
                    self.__create_dataset(entry, name, data)
                except Exception:
                    try:
                        data = cu._savu_encoder(data)
                        self.__create_dataset(entry, name, data)
                    except:
                        raise Exception('Unable to output %s to file.' % name)

    def __create_dataset(self, entry, name, data):
        if name not in list(entry.keys()):
            entry.create_dataset(name, data=data)
        else:
            entry[name][...] = data

    def __output_axis_labels(self, data, entry):
        axis_labels = data.data_info.get("axis_labels")
        ddict = data.meta_data.get_dictionary()

        axes = []
        count = 0
        for labels in axis_labels:
            name = list(labels.keys())[0]
            axes.append(name)
            entry.attrs[name + '_indices'] = count

            mData = ddict[name] if name in list(ddict.keys()) \
                else np.arange(data.get_shape()[count])
            if isinstance(mData, list):
                mData = np.array(mData)

            if 'U' in str(mData.dtype):
                mData = mData.astype(np.string_)

            axis_entry = entry.require_dataset(name, mData.shape, mData.dtype)
            axis_entry[...] = mData[...]
            axis_entry.attrs['units'] = list(labels.values())[0]
            count += 1
        entry.attrs['axes'] = axes

    def __output_data_patterns(self, data, entry):
        data_patterns = data.data_info.get("data_patterns")
        entry = entry.require_group('patterns')
        entry.attrs[NX_CLASS] = 'NXcollection'
        for pattern in data_patterns:
            nx_data = entry.require_group(pattern)
            nx_data.attrs[NX_CLASS] = 'NXparameters'
            values = data_patterns[pattern]
            self.__output_data(nx_data, values['core_dims'], 'core_dims')
            self.__output_data(nx_data, values['slice_dims'], 'slice_dims')

    def _output_metadata_dict(self, entry, mData):
        entry.attrs[NX_CLASS] = 'NXcollection'
        for key, value in mData.items():
            nx_data = entry.require_group(key)
            if isinstance(value, dict):
                self._output_metadata_dict(nx_data, value)
            else:
                nx_data.attrs[NX_CLASS] = 'NXdata'
                self.__output_data(nx_data, value, key)