savu.core.transports.base_transport   (rating: F)

Complexity

Total Complexity 144

Size/Duplication

Total Lines 692
Duplicated Lines 4.34 % (30 of 692 lines)

Importance

Changes 0
Metric Value
eloc 446
dl 30
loc 692
rs 2
c 0
b 0
f 0
wmc 144

42 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseTransport._transport_load_plugin() 0 3 1
A BaseTransport._setup_h5_files() 0 16 4
B BaseTransport.__create_squeeze_function() 0 22 8
A BaseTransport._get_link_type() 0 6 3
A BaseTransport._set_global_frame_index() 0 14 4
A BaseTransport._set_functions() 0 17 2
A BaseTransport._get_filenames() 0 19 3
A BaseTransport._remove_excess_data() 0 20 4
A BaseTransport._set_file_details() 0 10 2
A BaseTransport.__create_expand_function() 0 24 5
B BaseTransport._return_all_data() 0 27 6
A BaseTransport._transport_pre_plugin() 0 2 1
A BaseTransport._transport_checkpoint() 0 4 1
A BaseTransport._initialise() 0 9 2
A BaseTransport._transport_kill_signal() 0 5 1
B BaseTransport._transport_process() 0 37 7
A BaseTransport._transport_pre_plugin_list_run() 0 2 1
A BaseTransport.__get_checkpoint_params() 0 6 2
A BaseTransport._transfer_all_data() 0 22 3
A BaseTransport._get_input_data() 0 13 3
D BaseTransport._populate_nexus_file() 0 48 12
A BaseTransport._transport_post_plugin_list_run() 0 2 1
B BaseTransport.remove_extra_slices() 0 24 5
A BaseTransport.__init__() 0 3 1
A BaseTransport._get_all_slice_lists() 0 24 5
A BaseTransport._transport_post_plugin() 0 2 1
A BaseTransport._transport_terminate_dataset() 0 2 1
A BaseTransport._update_slice_list() 0 5 1
A BaseTransport._transport_initialise() 0 7 1
B BaseTransport._process_loop() 0 22 7
A BaseTransport.process_setup() 0 18 2
A BaseTransport._log_completion_status() 0 3 1
A BaseTransport._get_output_data() 0 13 5
A BaseTransport._transport_update_plugin_list() 0 2 1
A BaseTransport.__create_dataset() 0 5 2
B BaseTransport.__output_axis_labels() 0 24 5
A BaseTransport.__output_data_patterns() 10 10 2
A BaseTransport._output_metadata_dict() 0 9 3
A BaseTransport._populate_pre_run_nexus_file() 0 14 4
A BaseTransport._output_metadata() 0 12 3
D BaseTransport.__output_data_type() 0 39 12
B BaseTransport.__output_data() 20 20 6

How to fix

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

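As an illustration only (a hypothetical sketch, not code taken from this report), the usual remedy once a block is repeated in three or more places is to extract it into a single shared helper that every caller uses:

    # Hypothetical sketch: the same "create a group, tag it, write each item"
    # block was repeated in several output methods, so it becomes one helper.
    NX_CLASS = 'NX_class'

    def write_collection(parent, name, items, nx_class='NXcollection'):
        group = parent.setdefault(name, {})    # dict stand-in for an h5py group
        group[NX_CLASS] = nx_class
        for key, value in items.items():
            group[key] = value
        return group

    # Each former duplicate collapses to a single call:
    root = {}
    write_collection(root, 'patterns', {'core_dims': (1, 2), 'slice_dims': (0,)})
    write_collection(root, 'meta_data', {'rotation_angle': [0.0, 0.5, 1.0]})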

Complexity

Tip: Before tackling complexity, make sure that you eliminate any duplication first. This can often reduce the size of classes significantly.

Complex classes like savu.core.transports.base_transport often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
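In this class, for example, the __output_* and _output_metadata* methods share a prefix and all write NeXus metadata, so they form one cohesive candidate. A minimal sketch of the Extract Class pattern (hypothetical names, not the actual Savu refactoring):

    # Hypothetical sketch of Extract Class: the cohesive metadata-output code
    # moves onto a small collaborator that the transport class delegates to.
    class MetadataWriter:
        """Extracted component: everything that writes metadata lives here."""

        def __init__(self, store):
            self.store = store                 # any mapping, e.g. an h5py group

        def write(self, key, value):
            self.store[key] = value

        def write_dict(self, prefix, mapping):
            for key, value in mapping.items():
                self.write('%s/%s' % (prefix, key), value)


    class Transport:
        """The remaining class delegates instead of holding the helpers itself."""

        def __init__(self, store):
            self.metadata = MetadataWriter(store)

        def finish(self, results):
            self.metadata.write_dict('result', results)


    store = {}
    Transport(store).finish({'shape': (10, 10), 'dtype': 'float32'})
    # store is now {'result/shape': (10, 10), 'result/dtype': 'float32'}

Callers of the remaining class are unchanged, which is what makes Extract Class a safe first step before any behavioural refactoring.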

1
# Copyright 2015 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: base_transport
17
   :platform: Unix
18
   :synopsis: A BaseTransport class which implements functions that control\
19
   the interaction between the data and plugin layers.
20
21
.. moduleauthor:: Nicola Wadeson <[email protected]>
22
23
"""
24
25
import os
26
import time
27
import copy
28
import h5py
29
import math
30
import logging
31
import numpy as np
32
33
import savu.core.utils as cu
34
import savu.plugins.utils as pu
35
from savu.data.data_structures.data_types.base_type import BaseType
36
from savu.core.iterate_plugin_group_utils import \
37
    check_if_end_plugin_in_iterate_group
38
39
NX_CLASS = 'NX_class'
40
41
42
class BaseTransport(object):
43
    """
44
    Implements functions that control the interaction between the data and
45
    plugin layers.
46
    """
47
48
    def __init__(self):
49
        self.pDict = None
50
        self.no_processing = False
51
52
    def _transport_initialise(self, options):
53
        """
54
        Any initial setup required by the transport mechanism on start up.\
55
        This is called before the experiment is initialised.
56
        """
57
        raise NotImplementedError("transport_control_setup needs to be "
58
                                  "implemented in %s", self.__class__)
59
60
    def _transport_update_plugin_list(self):
61
        """
62
        This method provides an opportunity to add or remove items from the
63
        plugin list before plugin list check.
64
        """
65
66
    def _transport_pre_plugin_list_run(self):
67
        """
68
        This method is called after all datasets have been created but BEFORE
69
        the plugin list is processed.
70
        """
71
72
    def _transport_load_plugin(self, exp, plugin_dict):
73
        """ This method is called before each plugin is loaded """
74
        return pu.plugin_loader(exp, plugin_dict)
75
76
    def _transport_pre_plugin(self):
77
        """
78
        This method is called directly BEFORE each plugin is executed, but \
79
        after the plugin is loaded.
80
        """
81
82
    def _transport_post_plugin(self):
83
        """
84
        This method is called directly AFTER each plugin is executed.
85
        """
86
87
    def _transport_post_plugin_list_run(self):
88
        """
89
        This method is called AFTER the full plugin list has been processed.
90
        """
91
92
    def _transport_terminate_dataset(self, data):
93
        """ A dataset that will subequently be removed by the framework.
94
95
        :param Data data: A data object to finalise.
96
        """
97
98
    def process_setup(self, plugin):
99
        pDict = {}
100
        pDict['in_data'], pDict['out_data'] = plugin.get_datasets()
101
        pDict['in_sl'] = self._get_all_slice_lists(pDict['in_data'], 'in')
102
        pDict['out_sl'] = self._get_all_slice_lists(pDict['out_data'], 'out')
103
        pDict['nIn'] = list(range(len(pDict['in_data'])))
104
        pDict['nOut'] = list(range(len(pDict['out_data'])))
105
        pDict['nProc'] = len(pDict['in_sl']['process'])
106
        if 'transfer' in list(pDict['in_sl'].keys()):
107
            pDict['nTrans'] = len(pDict['in_sl']['transfer'][0])
108
        else:
109
            pDict['nTrans'] = 1
110
        pDict['squeeze'] = self._set_functions(pDict['in_data'], 'squeeze')
111
        pDict['expand'] = self._set_functions(pDict['out_data'], 'expand')
112
113
        frames = [f for f in pDict['in_sl']['frames']]
114
        self._set_global_frame_index(plugin, frames, pDict['nProc'])
115
        self.pDict = pDict
116
117
    def _transport_process(self, plugin):
118
        """ Organise required data and execute the main plugin processing.
119
120
        :param plugin plugin: The current plugin instance.
121
        """
122
        logging.info("transport_process initialise")
123
        pDict, result, nTrans = self._initialise(plugin)
124
        logging.info("transport_process get_checkpoint_params")
125
        cp, sProc, sTrans = self.__get_checkpoint_params(plugin)
126
127
        prange = list(range(sProc, pDict['nProc']))
128
        kill = False
129
        for count in range(sTrans, nTrans):
130
            end = True if count == nTrans-1 else False
131
            self._log_completion_status(count, nTrans, plugin.name)
132
133
            # get the transfer data
134
            logging.info("Transferring the data")
135
            transfer_data = self._transfer_all_data(count)
136
137
            if count == nTrans-1 and plugin.fixed_length == False:
138
                shape = [data.shape for data in transfer_data]
139
                prange = self.remove_extra_slices(prange, shape)
140
141
            # loop over the process data
142
            logging.info("process frames loop")
143
            result, kill = self._process_loop(
144
                    plugin, prange, transfer_data, count, pDict, result, cp)
145
146
            logging.info("Returning the data")
147
            self._return_all_data(count, result, end)
148
149
            if kill:
150
                return 1
151
152
        if not kill:
153
            cu.user_message("%s - 100%% complete" % (plugin.name))
154
155
    def remove_extra_slices(self, prange, transfer_shape):
156
        # loop over datasets:
157
        for i, data in enumerate(self.pDict['in_data']):
158
            pData = data._get_plugin_data()
159
            mft = pData.meta_data.get("max_frames_transfer")
160
            mfp = pData.meta_data.get("max_frames_process")
161
            sdirs = data.get_slice_dimensions()
162
            finish = np.prod([transfer_shape[i][j] for j in sdirs])
163
            rem, full = math.modf((mft - finish)/mfp)
164
            full = int(full)
165
166
            if rem:
167
                rem = (mft-finish) - full
168
                self._update_slice_list("in_sl", i, full, sdirs[0], rem)
169
                for j, out_data in enumerate(self.pDict['out_data']):
170
                    out_pData = out_data._get_plugin_data()
171
                    out_mfp = out_pData.meta_data.get("max_frames_process")
172
                    out_sdir = data.get_slice_dimensions()[0]
173
                    out_rem = rem/(mfp/out_mfp)
174
                    if out_rem%1:
175
                        raise Exception("'Fixed_length' plugin option is invalid")
176
                    self._update_slice_list("out_sl", j, full, out_sdir, int(out_rem))
177
178
        return list(range(prange[0], prange[-1]+1-full))
Issue: The variable full does not seem to be defined in case the for loop on line 157 is not entered. Are you sure this can never be the case?
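One conventional way to address this warning (a suggestion only, not the project's fix) is to give the loop-carried variable a defined fall-back before the loop, so the expression after the loop is valid even if in_data is empty:

    # Illustrative pattern only: initialise the value the loop would normally
    # set, so the code after the loop is well defined when the body never runs.
    def usable_range(prange, per_dataset_excess):
        full = 0                               # defined even for an empty list
        for excess in per_dataset_excess:
            full = int(excess)
        return list(range(prange[0], prange[-1] + 1 - full))

    print(usable_range([0, 1, 2, 3], []))      # [0, 1, 2, 3]
    print(usable_range([0, 1, 2, 3], [2]))     # [0, 1]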
179
180
    def _update_slice_list(self, key, idx, remove, dim, amount):
181
        sl = list(self.pDict[key]['process'][idx][-remove])
182
        s = sl[dim]
183
        sl[dim] = slice(s.start, s.stop - amount*s.step, s.step)
184
        self.pDict[key]['process'][idx][-1] = sl        
185
186
    def _process_loop(self, plugin, prange, tdata, count, pDict, result, cp):
187
        kill_signal = False
188
        for i in prange:
189
            if cp and cp.is_time_to_checkpoint(self, count, i):
190
                # kill signal sent so stop the processing
191
                return result, True
192
            data = self._get_input_data(plugin, tdata, i, count)
193
            res = self._get_output_data(
194
                    plugin.plugin_process_frames(data), i)
195
196
            for j in pDict['nOut']:
197
                if res is not None:
198
                    out_sl = pDict['out_sl']['process'][i][j]
199
                    if any("res_norm" in s for s in self.data_flow):
200
                        # an exception when the metadata is created automatically by a parameter in the plugin
201
                        # this is to fix CGLS_CUDA with a res_norm metadata
202
                        result[j][out_sl] = res[0][j, ]
203
                    else:
204
                        result[j][out_sl] = res[j]
205
                else:
206
                    result[j] = None
207
        return result, kill_signal
208
209
    def __get_checkpoint_params(self, plugin):
210
        cp = self.exp.checkpoint
211
        if cp:
212
            cp._initialise(plugin.get_communicator())
213
            return cp, cp.get_proc_idx(), cp.get_trans_idx()
214
        return None, 0, 0
215
216
    def _initialise(self, plugin):
217
        self.process_setup(plugin)
218
        pDict = self.pDict
219
        result = [np.empty(d._get_plugin_data().get_shape_transfer(),
220
                           dtype=np.float32) for d in pDict['out_data']]
221
        # loop over the transfer data
222
        nTrans = pDict['nTrans']
223
        self.no_processing = True if not nTrans else False
224
        return pDict, result, nTrans
225
226
    def _log_completion_status(self, count, nTrans, name):
227
        percent_complete: float = count / (nTrans * 0.01)
228
        cu.user_message("%s - %3i%% complete" % (name, percent_complete))
229
230
    def _transport_checkpoint(self):
231
        """ The framework has determined it is time to checkpoint.  What
232
        should the transport mechanism do? Override if appropriate. """
233
        return False
234
235
    def _transport_kill_signal(self):
236
        """ 
237
        An opportunity to send a kill signal to the framework.  Return
238
        True or False. """
239
        return False
240
241
    def _get_all_slice_lists(self, data_list, dtype):
242
        """ 
243
        Get all slice lists for the current process.
244
245
        :param list(Data) data_list: Datasets
246
        :returns: A list of dictionaries containing slice lists for each \
247
            dataset
248
        :rtype: list(dict)
249
        """
250
        sl_dict = {}
251
        for data in data_list:
252
            sl = data._get_transport_data().\
253
                    _get_slice_lists_per_process(dtype)
254
            for key, value in sl.items():
255
                if key not in sl_dict:
256
                    sl_dict[key] = [value]
257
                else:
258
                    sl_dict[key].append(value)
259
260
        for key in [k for k in ['process', 'unpad'] if k in list(sl_dict.keys())]:
261
            nData = list(range(len(sl_dict[key])))
262
            #rep = range(len(sl_dict[key][0]))
263
            sl_dict[key] = [[sl_dict[key][i][j] for i in nData if j < len(sl_dict[key][i])] for j in range(len(sl_dict[key][0]))]
264
        return sl_dict
265
266
    def _transfer_all_data(self, count):
267
        """ 
268
        Transfer data from file and pad if required.
269
270
        :param int count: The current frame index.
271
        :returns: All data for this frame and associated padded slice lists
272
        :rtype: list(np.ndarray), list(tuple(slice))
273
        """
274
        pDict = self.pDict
275
        data_list = pDict['in_data']
276
277
        if 'transfer' in list(pDict['in_sl'].keys()):
278
            slice_list = \
279
                [pDict['in_sl']['transfer'][i][count] for i in pDict['nIn']]
280
        else:
281
            slice_list = [slice(None)]*len(pDict['nIn'])
282
283
        section = []
284
        for i, item in enumerate(data_list):
285
            section.append(data_list[i]._get_transport_data().
286
                           _get_padded_data(slice_list[i]))
287
        return section
288
289
    def _get_input_data(self, plugin, trans_data, nproc, ntrans):
290
        data = []
291
        current_sl = []
292
        for d in self.pDict['nIn']:
293
            in_sl = self.pDict['in_sl']['process'][nproc][d]
294
            data.append(self.pDict['squeeze'][d](trans_data[d][in_sl]))
295
            entry = ntrans*self.pDict['nProc'] + nproc
296
            if entry < len(self.pDict['in_sl']['current'][d]):
297
                current_sl.append(self.pDict['in_sl']['current'][d][entry])
298
            else:
299
                current_sl.append(self.pDict['in_sl']['current'][d][-1])
300
        plugin.set_current_slice_list(current_sl)
301
        return data
302
303
    def _get_output_data(self, result, count):
304
        if result is None:
305
            return
306
        unpad_sl = self.pDict['out_sl']['unpad'][count]
307
        result = result if isinstance(result, list) else [result]
308
        for j in self.pDict['nOut']:
309
            if any("res_norm" in s for s in self.data_flow):
310
                # an exception when the metadata is created automatically by a parameter in the plugin
311
                # this is to fix CGLS_CUDA with a res_norm metadata
312
                result[0][j, ] = self.pDict['expand'][j](result[0][j, ])[unpad_sl[j]]
313
            else:
314
                result[j] = self.pDict['expand'][j](result[j])[unpad_sl[j]]
315
        return result
316
317
    def _return_all_data(self, count, result, end):
318
        """ 
319
        Transfer plugin results for current frame to backing files.
320
321
        :param int count: The current frame index.
322
        :param list(np.ndarray) result: plugin results
323
        :param bool end: True if this is the last entry in the slice list.
324
        """
325
        pDict = self.pDict
326
        data_list = pDict['out_data']
327
328
        slice_list = None
329
        if 'transfer' in list(pDict['out_sl'].keys()):
330
            slice_list = \
331
                [pDict['out_sl']['transfer'][i][count] for i in pDict['nOut'] \
332
                     if len(pDict['out_sl']['transfer'][i]) > count]
333
334
        result = [result] if type(result) is not list else result
335
336
        for i, item in enumerate(data_list):
337
            if result[i] is not None:
338
                if slice_list:
339
                    temp = self._remove_excess_data(
340
                            data_list[i], result[i], slice_list[i])
341
                    data_list[i].data[slice_list[i]] = temp
342
                else:
343
                    data_list[i].data = result[i]
344
345
    def _set_global_frame_index(self, plugin, frame_list, nProc):
346
        """ Convert the transfer global frame index to a process global frame
347
            index.
348
        """
349
        process_frames = []
350
        for f in frame_list:
351
            if len(f):
352
                process_frames.append(list(range(f[0]*nProc, (f[-1]+1)*nProc)))
353
354
        process_frames = np.array(process_frames)
355
        nframes = plugin.get_plugin_in_datasets()[0].get_total_frames()
356
        process_frames[process_frames >= nframes] = nframes - 1
357
        frames = process_frames[0] if process_frames.size else process_frames
358
        plugin.set_global_frame_index(frames)
359
360
    def _set_functions(self, data_list, name):
361
        """ Create a dictionary of functions to remove (squeeze) or re-add
362
        (expand) dimensions, of length 1, from each dataset in a list.
363
364
        :param list(Data) data_list: Datasets
365
        :param str name: 'squeeze' or 'expand'
366
        :returns: A dictionary of lambda functions
367
        :rtype: dict
368
        """
369
        str_name = 'self.' + name + '_output'
370
        function = {'expand': self.__create_expand_function,
371
                    'squeeze': self.__create_squeeze_function}
372
        ddict = {}
373
        for i, item in enumerate(data_list):
374
            ddict[i] = {i: str_name + str(i)}
375
            ddict[i] = function[name](data_list[i])
376
        return ddict
377
378
    def __create_expand_function(self, data):
379
        """ Create a function that re-adds missing dimensions of length 1.
380
381
        :param Data data: Dataset
382
        :returns: expansion function
383
        :rtype: lambda
384
        """
385
        slice_dirs = data.get_slice_dimensions()
386
        n_core_dirs = len(data.get_core_dimensions())
387
        new_slice = [slice(None)]*len(data.get_shape())
388
        possible_slices = [copy.copy(new_slice)]
389
390
        pData = data._get_plugin_data()
391
        if pData._get_rank_inc():
392
            possible_slices[0] += [0]*pData._get_rank_inc()
393
394
        if len(slice_dirs) > 1:
395
            for sl in slice_dirs[1:]:
396
                new_slice[sl] = None
397
        possible_slices.append(copy.copy(new_slice))
398
        new_slice[slice_dirs[0]] = None
399
        possible_slices.append(copy.copy(new_slice))
400
        possible_slices = possible_slices[::-1]
401
        return lambda x: x[tuple(possible_slices[len(x.shape)-n_core_dirs])]
402
403
    def __create_squeeze_function(self, data):
404
        """ Create a function that removes dimensions of length 1.
405
406
        :param Data data: Dataset
407
        :returns: squeeze function
408
        :rtype: lambda
409
        """
410
        pData = data._get_plugin_data()
411
        max_frames = pData._get_max_frames_process()
412
413
        pad = True if pData.padding and data.get_slice_dimensions()[0] in \
414
            list(pData.padding._get_padding_directions().keys()) else False
415
416
        n_core_dims = len(data.get_core_dimensions())
417
        squeeze_dims = data.get_slice_dimensions()
418
        if max_frames > 1 or pData._get_no_squeeze() or pad:
419
            squeeze_dims = squeeze_dims[1:]
420
            n_core_dims +=1
421
        if pData._get_rank_inc():
422
            sl = [(slice(None))]*n_core_dims + [None]*pData._get_rank_inc()
423
            return lambda x: np.squeeze(x[tuple(sl)], axis=squeeze_dims)
Issue: The variable sl does not seem to be defined for all execution paths.
424
        return lambda x: np.squeeze(x, axis=squeeze_dims)
425
426
    def _remove_excess_data(self, data, result, slice_list):
427
        """ Remove any excess results due to padding for fixed length process \
428
        frames. """
429
430
        mData = data._get_plugin_data().meta_data.get_dictionary()
431
        temp = np.where(np.array(mData['size_list']) > 1)[0]
432
        sdir = mData['sdir'][temp[-1] if temp.size else 0]
433
434
        # Not currently working for basic_transport
435
        if isinstance(slice_list, slice):
436
            return
437
438
        sl = slice_list[sdir]
439
        shape = result.shape
440
441
        if shape[sdir] - (sl.stop - sl.start):
442
            unpad_sl = [slice(None)]*len(shape)
443
            unpad_sl[sdir] = slice(0, sl.stop - sl.start)
444
            result = result[tuple(unpad_sl)]
445
        return result
446
447
    def _setup_h5_files(self):
448
        out_data_dict = self.exp.index["out_data"]
449
450
        current_and_next = False
451
        if 'current_and_next' in self.exp.meta_data.get_dictionary():
452
            current_and_next = self.exp.meta_data.get('current_and_next')
453
454
        count = 0
455
        for key in out_data_dict.keys():
456
            out_data = out_data_dict[key]
457
            filename = self.exp.meta_data.get(["filename", key])
458
            out_data.backing_file = self.hdf5._open_backing_h5(filename, 'a')
459
            c_and_n = 0 if not current_and_next else current_and_next[key]
460
            out_data.group_name, out_data.group = self.hdf5._create_entries(
461
                out_data, key, c_and_n)
462
            count += 1
463
464
    def _set_file_details(self, files):
465
        self.exp.meta_data.set('link_type', files['link_type'])
466
        self.exp.meta_data.set('link_type', {})
467
        self.exp.meta_data.set('filename', {})
468
        self.exp.meta_data.set('group_name', {})
469
        for key in list(self.exp.index['out_data'].keys()):
470
            self.exp.meta_data.set(['link_type', key], files['link_type'][key])
471
            self.exp.meta_data.set(['filename', key], files['filename'][key])
472
            self.exp.meta_data.set(['group_name', key],
473
                                   files['group_name'][key])
474
475
    def _get_filenames(self, plugin_dict):
476
        count = self.exp.meta_data.get('nPlugin') + 1
477
        files = {"filename": {}, "group_name": {}, "link_type": {}}
478
        for key in list(self.exp.index["out_data"].keys()):
479
            name = key + '_p' + str(count) + '_' + \
480
                plugin_dict['id'].split('.')[-1] + '.h5'
481
            link_type = self._get_link_type(key)
482
            files['link_type'][key] = link_type
483
            if link_type == 'final_result':
484
                out_path = self.exp.meta_data.get('out_path')
485
            else:
486
                out_path = self.exp.meta_data.get('inter_path')
487
488
            filename = os.path.join(out_path, name)
489
            group_name = "%i-%s-%s" % (count, plugin_dict['name'], key)
490
            files["filename"][key] = filename
491
            files["group_name"][key] = group_name
492
493
        return files
494
495
    def _get_link_type(self, name):
496
        idx = self.exp.meta_data.get('nPlugin')
497
        temp = [e for entry in self.data_flow[idx+1:] for e in entry]
498
        if name in temp or self.exp.index['out_data'][name].remove:
499
            return 'intermediate'
500
        return 'final_result'
501
502
    def _populate_nexus_file(self, data, iterate_group=None):
503
        filename = self.exp.meta_data.get('nxs_filename')
504
505
        with h5py.File(filename, 'a') as nxs_file:
506
            nxs_entry = nxs_file['entry']
507
            name = data.data_info.get('name')
508
            group_name = self.exp.meta_data.get(['group_name', name])
509
            link_type = self.exp.meta_data.get(['link_type', name])
510
511
            if link_type == 'final_result':
512
                if iterate_group is not None and \
513
                    check_if_end_plugin_in_iterate_group(self.exp):
514
                    is_clone_data = 'clone' in name
515
                    is_even_iterations = \
516
                        iterate_group._ip_fixed_iterations % 2 == 0
517
                    # don't need to create group for:
518
                    # - clone dataset, if running an odd number of iterations
519
                    # - original dataset, if running an even number of
520
                    #   iterations
521
                    if is_clone_data and not is_even_iterations:
522
                        return
523
                    elif not is_clone_data and is_even_iterations:
524
                        return
525
                # the group name for the output of the iterative loop should be
526
                # named after the original dataset, regardless of if the link
527
                # eventually points to the original or the clone, for the sake
528
                # of the linkname referencing the dataset name set in
529
                # savu_config
530
                group_name = 'final_result_' + data.get_name(orig=True)
531
            else:
532
                link = nxs_entry.require_group(link_type.encode("ascii"))
533
                link.attrs[NX_CLASS] = 'NXcollection'
534
                nxs_entry = link
535
536
            # delete the group if it already exists
537
            if group_name in nxs_entry:
538
                del nxs_entry[group_name]
539
540
            plugin_entry = nxs_entry.require_group(group_name)
541
            plugin_entry.attrs[NX_CLASS] = 'NXdata'
542
            if iterate_group is not None and \
543
                check_if_end_plugin_in_iterate_group(self.exp):
544
                # always write the metadata under the name of the original
545
                # dataset, not the clone dataset
546
                self._output_metadata(data, plugin_entry,
547
                    data.get_name(orig=False))
548
            else:
549
                self._output_metadata(data, plugin_entry, name)
550
551
    def _populate_pre_run_nexus_file(self, data):
552
        filename = self.exp.meta_data.get('nxs_filename')
553
554
        data_path = self.exp.meta_data["data_path"]
555
        image_key_path = self.exp.meta_data["image_key_path"]
556
        name = data.data_info.get('name')
557
        group_name = self.exp.meta_data.get(['group_name', name])
558
        with h5py.File(filename, 'a') as nxs_file:
559
            if data_path in nxs_file:
560
                del nxs_file[data_path]
561
            nxs_file[data_path] = h5py.ExternalLink(os.path.abspath(data.backing_file.filename), f"{group_name}/data")
562
563
            if image_key_path in nxs_file:
564
                nxs_file[image_key_path][::] = data.data.image_key[::]
565
            #nxs_file[data_path].attrs.create("pre_run", True)
566
567
    def _output_metadata(self, data, entry, name, dump=False):
568
        self.__output_data_type(entry, data, name)
569
        mDict = data.meta_data.get_dictionary()
570
        self._output_metadata_dict(entry.require_group('meta_data'), mDict)
571
572
        if not dump:
573
            self.__output_axis_labels(data, entry)
574
            self.__output_data_patterns(data, entry)
575
            if self.exp.meta_data.get('link_type')[name] == 'input_data':
576
                # output the filename
577
                entry['file_path'] = \
578
                    os.path.abspath(self.exp.meta_data.get('data_file'))
579
580
    def __output_data_type(self, entry, data, name):
581
        data = data.data if 'data' in list(data.__dict__.keys()) else data
582
        if isinstance(data, h5py.Dataset):
583
            return
584
585
        entry = entry.require_group('data_type')
586
        entry.attrs[NX_CLASS] = 'NXcollection'
587
588
        ltype = self.exp.meta_data.get('link_type')
589
        if name in list(ltype.keys()) and ltype[name] == 'input_data':
590
            self.__output_data(entry, data.__class__.__name__, 'cls')
591
            return
592
593
        args, kwargs, cls, extras = data._get_parameters(data.get_clone_args())
594
595
        for key, value in kwargs.items():
596
            gp = entry.require_group('kwargs')
597
            if isinstance(value, BaseType):
598
                self.__output_data_type(gp.require_group(key), value, key)
599
            else:
600
                self.__output_data(gp, value, key)
601
602
        for key, value in extras.items():
603
            gp = entry.require_group('extras')
604
            if isinstance(value, BaseType):
605
                self.__output_data_type(gp.require_group(key), value, key)
606
            else:
607
                self.__output_data(gp, value, key)
608
609
        for i, item in enumerate(args):
610
            gp = entry.require_group('args')
611
            self.__output_data(gp, args[i], ''.join(['args', str(i)]))
612
613
        self.__output_data(entry, cls, 'cls')
614
615
        if 'data' in list(data.__dict__.keys()) and not \
616
                isinstance(data.data, h5py.Dataset):
617
            gp = entry.require_group('data')
618
            self.__output_data_type(gp, data.data, 'data')
619
620
    def __output_data(self, entry, data, name):
Duplication issue: This code seems to be duplicated in your project.
621
        if isinstance(data, dict):
622
            entry = entry.require_group(name)
623
            entry.attrs[NX_CLASS] = 'NXcollection'
624
            for key, value in data.items():
625
                self.__output_data(entry, value, key)
626
        else:
627
            try:
628
                self.__create_dataset(entry, name, data)
629
            except Exception:
630
                try:
631
                    import json
632
                    data = np.array([json.dumps(data).encode("ascii")])
633
                    self.__create_dataset(entry, name, data)
634
                except Exception:
635
                    try:
636
                        data = cu._savu_encoder(data)
637
                        self.__create_dataset(entry, name, data)
638
                    except:
639
                        raise Exception('Unable to output %s to file.' % name)
640
641
    def __create_dataset(self, entry, name, data):
642
        if name not in list(entry.keys()):
643
            entry.create_dataset(name, data=data)
644
        else:
645
            entry[name][...] = data
646
647
    def __output_axis_labels(self, data, entry):
648
        axis_labels = data.data_info.get("axis_labels")
649
        ddict = data.meta_data.get_dictionary()
650
651
        axes = []
652
        count = 0
653
        for labels in axis_labels:
654
            name = list(labels.keys())[0]
655
            axes.append(name)
656
            entry.attrs[name + '_indices'] = count
657
658
            mData = ddict[name] if name in list(ddict.keys()) \
659
                else np.arange(data.get_shape()[count])
660
            if isinstance(mData, list):
661
                mData = np.array(mData)
662
663
            if 'U' in str(mData.dtype):
664
                mData = mData.astype(np.string_)
665
666
            axis_entry = entry.require_dataset(name, mData.shape, mData.dtype)
667
            axis_entry[...] = mData[...]
668
            axis_entry.attrs['units'] = list(labels.values())[0]
669
            count += 1
670
        entry.attrs['axes'] = axes
671
672
    def __output_data_patterns(self, data, entry):
Duplication issue: This code seems to be duplicated in your project.
673
        data_patterns = data.data_info.get("data_patterns")
674
        entry = entry.require_group('patterns')
675
        entry.attrs[NX_CLASS] = 'NXcollection'
676
        for pattern in data_patterns:
677
            nx_data = entry.require_group(pattern)
678
            nx_data.attrs[NX_CLASS] = 'NXparameters'
679
            values = data_patterns[pattern]
680
            self.__output_data(nx_data, values['core_dims'], 'core_dims')
681
            self.__output_data(nx_data, values['slice_dims'], 'slice_dims')
682
683
    def _output_metadata_dict(self, entry, mData):
684
        entry.attrs[NX_CLASS] = 'NXcollection'
685
        for key, value in mData.items():
686
            nx_data = entry.require_group(key)
687
            if isinstance(value, dict):
688
                self._output_metadata_dict(nx_data, value)
689
            else:
690
                nx_data.attrs[NX_CLASS] = 'NXdata'
691
                self.__output_data(nx_data, value, key)
692