# Copyright 2015 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_transport
   :platform: Unix
   :synopsis: A BaseTransport class which implements functions that control\
   the interaction between the data and plugin layers.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""

import os
import time
import copy
import h5py
import logging
import numpy as np

import savu.core.utils as cu
import savu.plugins.utils as pu
from savu.data.data_structures.data_types.base_type import BaseType

NX_CLASS = 'NX_class'


class BaseTransport(object):
    """
    Implements functions that control the interaction between the data and
    plugin layers.
    """

    def __init__(self):
        self.pDict = None
        self.no_processing = False

    def _transport_initialise(self, options):
        """
        Any initial setup required by the transport mechanism on start up.\
        This is called before the experiment is initialised.
        """
        raise NotImplementedError("_transport_initialise needs to be "
                                  "implemented in %s" % self.__class__)

    def _transport_update_plugin_list(self):
        """
        This method provides an opportunity to add or remove items from the
        plugin list before the plugin list is checked.
        """
        pass

    def _transport_pre_plugin_list_run(self):
        """
        This method is called after all datasets have been created but BEFORE
        the plugin list is processed.
        """
        pass

    def _transport_load_plugin(self, exp, plugin_dict):
        """ This method is called before each plugin is loaded """
        return pu.plugin_loader(exp, plugin_dict)

    def _transport_pre_plugin(self):
        """
        This method is called directly BEFORE each plugin is executed, but \
        after the plugin is loaded.
        """
        pass

    def _transport_post_plugin(self):
        """
        This method is called directly AFTER each plugin is executed.
        """
        pass

    def _transport_post_plugin_list_run(self):
        """
        This method is called AFTER the full plugin list has been processed.
        """
        pass

    def _transport_terminate_dataset(self, data):
        """ Finalise a dataset that will subsequently be removed by the
        framework.

        :param Data data: A data object to finalise.
        """
        pass

    def process_setup(self, plugin):
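        # Gather everything needed to drive the processing loops: the
        # input/output datasets, their 'transfer' and 'process' slice lists,
        # how many process frames make up each transfer chunk (nProc), how
        # many transfer chunks there are (nTrans), and the squeeze/expand
        # functions that convert between transfer and process frame shapes.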
        pDict = {}
        pDict['in_data'], pDict['out_data'] = plugin.get_datasets()
        pDict['in_sl'] = self._get_all_slice_lists(pDict['in_data'], 'in')
        pDict['out_sl'] = self._get_all_slice_lists(pDict['out_data'], 'out')
        pDict['nIn'] = list(range(len(pDict['in_data'])))
        pDict['nOut'] = list(range(len(pDict['out_data'])))
        pDict['nProc'] = len(pDict['in_sl']['process'])
        if 'transfer' in list(pDict['in_sl'].keys()):
            pDict['nTrans'] = len(pDict['in_sl']['transfer'][0])
        else:
            pDict['nTrans'] = 1
        pDict['squeeze'] = self._set_functions(pDict['in_data'], 'squeeze')
        pDict['expand'] = self._set_functions(pDict['out_data'], 'expand')

        frames = [f for f in pDict['in_sl']['frames']]
        self._set_global_frame_index(plugin, frames, pDict['nProc'])
        self.pDict = pDict

    def _transport_process(self, plugin):
        """ Organise required data and execute the main plugin processing.

        :param plugin plugin: The current plugin instance.
        """
        logging.info("transport_process initialise")
        pDict, result, nTrans = self._initialise(plugin)
        logging.info("transport_process get_checkpoint_params")
        cp, sProc, sTrans = self.__get_checkpoint_params(plugin)

        count = 0  # temporary solution
        prange = list(range(sProc, pDict['nProc']))
        kill = False
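        # Outer loop: one iteration per transfer chunk read from file.  The
        # inner loop over process frames happens in _process_loop.  When
        # resuming from a checkpoint, processing restarts at (sTrans, sProc).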
        for count in range(sTrans, nTrans):
            end = True if count == nTrans-1 else False
            self._log_completion_status(count, nTrans, plugin.name)

            # get the transfer data
            logging.info("Transferring the data")
            transfer_data = self._transfer_all_data(count)

            # loop over the process data
            logging.info("process frames loop")
            result, kill = self._process_loop(
                plugin, prange, transfer_data, count, pDict, result, cp)

            logging.info("Returning the data")
            self._return_all_data(count, result, end)

            if kill:
                return 1

        if not kill:
            cu.user_message("%s - 100%% complete" % (plugin.name))

    def _process_loop(self, plugin, prange, tdata, count, pDict, result, cp):
        kill_signal = False
        for i in prange:
            if cp and cp.is_time_to_checkpoint(self, count, i):
                # kill signal sent so stop the processing
                return result, True
            data = self._get_input_data(plugin, tdata, i, count)
            res = self._get_output_data(
                plugin.plugin_process_frames(data), i)

            for j in pDict['nOut']:
                if res is not None:
                    out_sl = pDict['out_sl']['process'][i][j]
                    result[j][out_sl] = res[j]
                else:
                    result[j] = None
        return result, kill_signal

    def __get_checkpoint_params(self, plugin):
        cp = self.exp.checkpoint
        if cp:
            cp._initialise(plugin.get_communicator())
            return cp, cp.get_proc_idx(), cp.get_trans_idx()
        return None, 0, 0

    def _initialise(self, plugin):
        self.process_setup(plugin)
        pDict = self.pDict
        result = [np.empty(d._get_plugin_data().get_shape_transfer(),
                           dtype=np.float32) for d in pDict['out_data']]
        # loop over the transfer data
        nTrans = pDict['nTrans']
        self.no_processing = True if not nTrans else False
        return pDict, result, nTrans

    def _log_completion_status(self, count, nTrans, name):
        percent_complete = 100.0 * count / nTrans
        cu.user_message("%s - %3i%% complete" % (name, percent_complete))

    def _transport_checkpoint(self):
        """ The framework has determined it is time to checkpoint. What
        should the transport mechanism do? Override if appropriate. """
        return False

    def _transport_kill_signal(self):
        """
        An opportunity to send a kill signal to the framework. Return
        True or False. """
        return False

    def _get_all_slice_lists(self, data_list, dtype):
        """
        Get all slice lists for the current process.

        :param list(Data) data_list: Datasets
        :returns: A dictionary of slice lists, with one entry per dataset \
            for each key
        :rtype: dict
        """
        sl_dict = {}
        for data in data_list:
            sl = data._get_transport_data()._get_slice_lists_per_process(dtype)
            for key, value in sl.items():
                if key not in sl_dict:
                    sl_dict[key] = [value]
                else:
                    sl_dict[key].append(value)

        for key in [k for k in ['process', 'unpad'] if k in list(sl_dict.keys())]:
            nData = list(range(len(sl_dict[key])))
            #rep = range(len(sl_dict[key][0]))
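            # The lists collected above are indexed [dataset][frame];
            # transpose 'process' and 'unpad' so they are indexed
            # [frame][dataset], skipping datasets with fewer frames, e.g.
            # [[d0_f0, d0_f1], [d1_f0, d1_f1]] -> [[d0_f0, d1_f0],
            #                                      [d0_f1, d1_f1]]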
            sl_dict[key] = [[sl_dict[key][i][j] for i in nData
                             if j < len(sl_dict[key][i])]
                            for j in range(len(sl_dict[key][0]))]
        return sl_dict

    def _transfer_all_data(self, count):
        """
        Transfer data from file and pad if required.

        :param int count: The current transfer index.
        :returns: The padded data for this transfer, one entry per input \
            dataset
        :rtype: list(np.ndarray)
        """
        pDict = self.pDict
        data_list = pDict['in_data']

        if 'transfer' in list(pDict['in_sl'].keys()):
            slice_list = \
                [pDict['in_sl']['transfer'][i][count] for i in pDict['nIn']]
        else:
            slice_list = [slice(None)]*len(pDict['nIn'])

        section = []
        for idx in range(len(data_list)):
            section.append(data_list[idx]._get_transport_data().
                           _get_padded_data(slice_list[idx]))
        return section

    def _get_input_data(self, plugin, trans_data, nproc, ntrans):
        data = []
        current_sl = []
        for d in self.pDict['nIn']:
            in_sl = self.pDict['in_sl']['process'][nproc][d]
            data.append(self.pDict['squeeze'][d](trans_data[d][in_sl]))
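            # 'entry' is the global index into the 'current' slice list:
            # (transfer index * process frames per transfer) + process index.
            # Fall back to the last entry if the final transfer is short.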
            entry = ntrans*self.pDict['nProc'] + nproc
            if entry < len(self.pDict['in_sl']['current'][d]):
                current_sl.append(self.pDict['in_sl']['current'][d][entry])
            else:
                current_sl.append(self.pDict['in_sl']['current'][d][-1])
        plugin.set_current_slice_list(current_sl)
        return data

    def _get_output_data(self, result, count):
        if result is None:
            return
        unpad_sl = self.pDict['out_sl']['unpad'][count]
        result = result if isinstance(result, list) else [result]
        for j in self.pDict['nOut']:
            result[j] = self.pDict['expand'][j](result[j])[unpad_sl[j]]
        return result

    def _return_all_data(self, count, result, end):
        """
        Transfer plugin results for current frame to backing files.

        :param int count: The current frame index.
        :param list(np.ndarray) result: plugin results
        :param bool end: True if this is the last entry in the slice list.
        """
        pDict = self.pDict
        data_list = pDict['out_data']

        slice_list = None
        if 'transfer' in list(pDict['out_sl'].keys()):
            slice_list = \
                [pDict['out_sl']['transfer'][i][count] for i in pDict['nOut']
                 if len(pDict['out_sl']['transfer'][i]) > count]

        result = [result] if type(result) is not list else result

        for idx in range(len(data_list)):
            if result[idx] is not None:
                if slice_list:
                    temp = self._remove_excess_data(
                        data_list[idx], result[idx], slice_list[idx])
                    data_list[idx].data[slice_list[idx]] = temp
                else:
                    data_list[idx].data = result[idx]

    def _set_global_frame_index(self, plugin, frame_list, nProc):
        """ Convert the transfer global frame index to a process global frame
        index.
        """
        process_frames = []
        for f in frame_list:
            if len(f):
                process_frames.append(list(range(f[0]*nProc, (f[-1]+1)*nProc)))

        process_frames = np.array(process_frames)
        nframes = plugin.get_plugin_in_datasets()[0].get_total_frames()
        process_frames[process_frames >= nframes] = nframes - 1
        frames = process_frames[0] if process_frames.size else process_frames
        plugin.set_global_frame_index(frames)

    def _set_functions(self, data_list, name):
        """ Create a dictionary of functions to remove (squeeze) or re-add
        (expand) dimensions, of length 1, from each dataset in a list.

        :param list(Data) data_list: Datasets
        :param str name: 'squeeze' or 'expand'
        :returns: A dictionary of lambda functions
        :rtype: dict
        """
        str_name = 'self.' + name + '_output'
        function = {'expand': self.__create_expand_function,
                    'squeeze': self.__create_squeeze_function}
        ddict = {}
        for i in range(len(data_list)):
            ddict[i] = {i: str_name + str(i)}
            ddict[i] = function[name](data_list[i])
        return ddict

    def __create_expand_function(self, data):
        """ Create a function that re-adds missing dimensions of length 1.

        :param Data data: Dataset
        :returns: expansion function
        :rtype: lambda
        """
        slice_dirs = data.get_slice_dimensions()
        n_core_dirs = len(data.get_core_dimensions())
        new_slice = [slice(None)]*len(data.get_shape())
        possible_slices = [copy.copy(new_slice)]

        pData = data._get_plugin_data()
        if pData._get_rank_inc():
            possible_slices[0] += [0]*pData._get_rank_inc()

        if len(slice_dirs) > 1:
            for sl in slice_dirs[1:]:
                new_slice[sl] = None
                possible_slices.append(copy.copy(new_slice))
        new_slice[slice_dirs[0]] = None
        possible_slices.append(copy.copy(new_slice))
        possible_slices = possible_slices[::-1]
        return lambda x: x[tuple(possible_slices[len(x.shape)-n_core_dirs])]

    def __create_squeeze_function(self, data):
        """ Create a function that removes dimensions of length 1.

        :param Data data: Dataset
        :returns: squeeze function
        :rtype: lambda
        """
        pData = data._get_plugin_data()
        max_frames = pData._get_max_frames_process()

        pad = True if pData.padding and data.get_slice_dimensions()[0] in \
            list(pData.padding._get_padding_directions().keys()) else False

        n_core_dims = len(data.get_core_dimensions())
        squeeze_dims = data.get_slice_dimensions()
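        # If several frames are processed at once, squeezing is disabled, or
        # the first slice dimension is padded, that dimension is not length 1,
        # so keep it and only squeeze the remaining slice dimensions.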
        if max_frames > 1 or pData._get_no_squeeze() or pad:
            squeeze_dims = squeeze_dims[1:]
            n_core_dims += 1
        if pData._get_rank_inc():
            sl = [slice(None)]*n_core_dims + [None]*pData._get_rank_inc()
            return lambda x: np.squeeze(x[tuple(sl)], axis=squeeze_dims)
        return lambda x: np.squeeze(x, axis=squeeze_dims)

    def _remove_excess_data(self, data, result, slice_list):
        """ Remove any excess results due to padding for fixed length process \
        frames. """

        mData = data._get_plugin_data().meta_data.get_dictionary()
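        # Trim along the last slice direction ('sdir') whose 'size_list'
        # entry is greater than one, falling back to the first direction.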
        temp = np.where(np.array(mData['size_list']) > 1)[0]
        sdir = mData['sdir'][temp[-1] if temp.size else 0]

        # Not currently working for basic_transport
        if isinstance(slice_list, slice):
            return

        sl = slice_list[sdir]
        shape = result.shape

        if shape[sdir] - (sl.stop - sl.start):
            unpad_sl = [slice(None)]*len(shape)
            unpad_sl[sdir] = slice(0, sl.stop - sl.start)
            result = result[tuple(unpad_sl)]
        return result

    def _setup_h5_files(self):
        out_data_dict = self.exp.index["out_data"]

        current_and_next = False
        if 'current_and_next' in self.exp.meta_data.get_dictionary():
            current_and_next = self.exp.meta_data.get('current_and_next')

        count = 0
        for key in out_data_dict.keys():
            out_data = out_data_dict[key]
            filename = self.exp.meta_data.get(["filename", key])
            out_data.backing_file = self.hdf5._open_backing_h5(filename, 'a')
            c_and_n = 0 if not current_and_next else current_and_next[key]
            out_data.group_name, out_data.group = self.hdf5._create_entries(
                out_data, key, c_and_n)
            count += 1

    def _set_file_details(self, files):
        self.exp.meta_data.set('link_type', files['link_type'])
        self.exp.meta_data.set('link_type', {})
        self.exp.meta_data.set('filename', {})
        self.exp.meta_data.set('group_name', {})
        for key in list(self.exp.index['out_data'].keys()):
            self.exp.meta_data.set(['link_type', key], files['link_type'][key])
            self.exp.meta_data.set(['filename', key], files['filename'][key])
            self.exp.meta_data.set(['group_name', key],
                                   files['group_name'][key])

    def _get_filenames(self, plugin_dict):
        count = self.exp.meta_data.get('nPlugin') + 1
        files = {"filename": {}, "group_name": {}, "link_type": {}}
        for key in list(self.exp.index["out_data"].keys()):
            name = key + '_p' + str(count) + '_' + \
                   plugin_dict['id'].split('.')[-1] + '.h5'
            link_type = self._get_link_type(key)
            files['link_type'][key] = link_type
            if link_type == 'final_result':
                out_path = self.exp.meta_data.get('out_path')
            else:
                out_path = self.exp.meta_data.get('inter_path')

            filename = os.path.join(out_path, name)
            group_name = "%i-%s-%s" % (count, plugin_dict['name'], key)
            files["filename"][key] = filename
            files["group_name"][key] = group_name

        return files

    def _get_link_type(self, name):
        idx = self.exp.meta_data.get('nPlugin')
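        # A dataset is 'intermediate' if any later plugin in the data flow
        # uses it, or if it is flagged for removal; otherwise it is a
        # 'final_result'.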
        temp = [e for entry in self.data_flow[idx+1:] for e in entry]
        if name in temp or self.exp.index['out_data'][name].remove:
            return 'intermediate'
        return 'final_result'

    def _populate_nexus_file(self, data):
        filename = self.exp.meta_data.get('nxs_filename')

        with h5py.File(filename, 'a') as nxs_file:
            nxs_entry = nxs_file['entry']
            name = data.data_info.get('name')
            group_name = self.exp.meta_data.get(['group_name', name])
            link_type = self.exp.meta_data.get(['link_type', name])
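
            # Final results are written to a 'final_result_<name>' group
            # directly under the NXentry; intermediate results are grouped
            # under a collection named after their link type.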
            if link_type == 'final_result':
                group_name = 'final_result_' + data.get_name()
            else:
                link = nxs_entry.require_group(link_type.encode("ascii"))
                link.attrs[NX_CLASS] = 'NXcollection'
                nxs_entry = link

            # delete the group if it already exists
            if group_name in nxs_entry:
                del nxs_entry[group_name]

            plugin_entry = nxs_entry.require_group(group_name)
            plugin_entry.attrs[NX_CLASS] = 'NXdata'
            self._output_metadata(data, plugin_entry, name)

    def _output_metadata(self, data, entry, name, dump=False):
        self.__output_data_type(entry, data, name)
        mDict = data.meta_data.get_dictionary()
        self._output_metadata_dict(entry.require_group('meta_data'), mDict)

        if not dump:
            self.__output_axis_labels(data, entry)
            self.__output_data_patterns(data, entry)
            if self.exp.meta_data.get('link_type')[name] == 'input_data':
                # output the filename
                entry['file_path'] = \
                    os.path.abspath(self.exp.meta_data.get('data_file'))

    def __output_data_type(self, entry, data, name):
        data = data.data if 'data' in list(data.__dict__.keys()) else data
        if isinstance(data, h5py.Dataset):
            return

        entry = entry.require_group('data_type')
        entry.attrs[NX_CLASS] = 'NXcollection'

        ltype = self.exp.meta_data.get('link_type')
        if name in list(ltype.keys()) and ltype[name] == 'input_data':
            self.__output_data(entry, data.__class__.__name__, 'cls')
            return

        args, kwargs, cls, extras = data._get_parameters(data.get_clone_args())

        for key, value in kwargs.items():
            gp = entry.require_group('kwargs')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for key, value in extras.items():
            gp = entry.require_group('extras')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for i in range(len(args)):
            gp = entry.require_group('args')
            self.__output_data(gp, args[i], ''.join(['args', str(i)]))

        self.__output_data(entry, cls, 'cls')

        if 'data' in list(data.__dict__.keys()) and not \
                isinstance(data.data, h5py.Dataset):
            gp = entry.require_group('data')
            self.__output_data_type(gp, data.data, 'data')

    def __output_data(self, entry, data, name):
        if isinstance(data, dict):
            entry = entry.require_group(name)
            entry.attrs[NX_CLASS] = 'NXcollection'
            for key, value in data.items():
                self.__output_data(entry, value, key)
        else:
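            # Fall back through progressively more forgiving encodings:
            # write the value directly, then as a JSON string, then via the
            # Savu encoder (cu._savu_encoder), before giving up.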
            try:
                self.__create_dataset(entry, name, data)
            except Exception:
                try:
                    import json
                    data = np.array([json.dumps(data).encode("ascii")])
                    self.__create_dataset(entry, name, data)
                except Exception:
                    try:
                        data = cu._savu_encoder(data)
                        self.__create_dataset(entry, name, data)
                    except:
                        raise Exception('Unable to output %s to file.' % name)

    def __create_dataset(self, entry, name, data):
        if name not in list(entry.keys()):
            entry.create_dataset(name, data=data)
        else:
            entry[name][...] = data

    def __output_axis_labels(self, data, entry):
        axis_labels = data.data_info.get("axis_labels")
        ddict = data.meta_data.get_dictionary()

        axes = []
        count = 0
        for labels in axis_labels:
            name = list(labels.keys())[0]
            axes.append(name)
            entry.attrs[name + '_indices'] = count

            mData = ddict[name] if name in list(ddict.keys()) \
                else np.arange(data.get_shape()[count])
            if isinstance(mData, list):
                mData = np.array(mData)

            if 'U' in str(mData.dtype):
                mData = mData.astype(np.string_)

            axis_entry = entry.require_dataset(name, mData.shape, mData.dtype)
            axis_entry[...] = mData[...]
            axis_entry.attrs['units'] = list(labels.values())[0]
            count += 1
        entry.attrs['axes'] = axes

    def __output_data_patterns(self, data, entry):
        data_patterns = data.data_info.get("data_patterns")
        entry = entry.require_group('patterns')
        entry.attrs[NX_CLASS] = 'NXcollection'
        for pattern in data_patterns:
            nx_data = entry.require_group(pattern)
            nx_data.attrs[NX_CLASS] = 'NXparameters'
            values = data_patterns[pattern]
            self.__output_data(nx_data, values['core_dims'], 'core_dims')
            self.__output_data(nx_data, values['slice_dims'], 'slice_dims')

    def _output_metadata_dict(self, entry, mData):
        entry.attrs[NX_CLASS] = 'NXcollection'
        for key, value in mData.items():
            nx_data = entry.require_group(key)
            if isinstance(value, dict):
                self._output_metadata_dict(nx_data, value)
            else:
                nx_data.attrs[NX_CLASS] = 'NXdata'
                self.__output_data(nx_data, value, key)