| Metric | Value |
| --- | --- |
| Total Complexity | 144 |
| Total Lines | 692 |
| Duplicated Lines | 4.34 % |
| Changes | 0 |
Duplicate code is one of the most pungent code smells. A commonly used rule of thumb is to restructure code once it is duplicated in three or more places. Each common duplication problem has a corresponding refactoring solution.

Complex classes like savu.core.transports.base_transport often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields and methods belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the faster option.
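To make the idea concrete, here is a minimal, hypothetical sketch of Extract Class. The `ReportWriter` and `Processor` names are invented for illustration only and are not part of Savu; the point is that methods sharing a prefix move together onto a new collaborator:

```python
# Hypothetical Extract Class sketch (names invented, not taken from Savu).
import sys


class ReportWriter:
    """Cohesive component extracted from a larger class:
    all of the 'output_' methods move here together."""

    def __init__(self, stream):
        self.stream = stream

    def output_header(self, title):
        self.stream.write("== %s ==\n" % title)

    def output_row(self, key, value):
        self.stream.write("%s: %s\n" % (key, value))


class Processor:
    """The original class keeps its core job and delegates
    reporting to the extracted component."""

    def __init__(self, stream=sys.stdout):
        self.writer = ReportWriter(stream)

    def run(self, results):
        self.writer.output_header("results")
        for key, value in results.items():
            self.writer.output_row(key, value)


Processor().run({"complexity": 144, "lines": 692})
```

The full source of savu.core.transports.base_transport follows.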
# Copyright 2015 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_transport
   :platform: Unix
   :synopsis: A BaseTransport class which implements functions that control
       the interaction between the data and plugin layers.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""

import os
import time
import copy
import h5py
import math
import logging
import numpy as np

import savu.core.utils as cu
import savu.plugins.utils as pu
from savu.data.data_structures.data_types.base_type import BaseType
from savu.core.iterate_plugin_group_utils import \
    check_if_end_plugin_in_iterate_group

NX_CLASS = 'NX_class'

class BaseTransport(object):
    """
    Implements functions that control the interaction between the data and
    plugin layers.
    """

    def __init__(self):
        self.pDict = None
        self.no_processing = False

    def _transport_initialise(self, options):
        """
        Any initial setup required by the transport mechanism on start up.
        This is called before the experiment is initialised.
        """
        raise NotImplementedError("transport_control_setup needs to be "
                                  "implemented in %s" % self.__class__)

    def _transport_update_plugin_list(self):
        """
        This method provides an opportunity to add or remove items from the
        plugin list before the plugin list check.
        """

    def _transport_pre_plugin_list_run(self):
        """
        This method is called after all datasets have been created but BEFORE
        the plugin list is processed.
        """

    def _transport_load_plugin(self, exp, plugin_dict):
        """ This method is called before each plugin is loaded. """
        return pu.plugin_loader(exp, plugin_dict)

    def _transport_pre_plugin(self):
        """
        This method is called directly BEFORE each plugin is executed, but
        after the plugin is loaded.
        """

    def _transport_post_plugin(self):
        """
        This method is called directly AFTER each plugin is executed.
        """

    def _transport_post_plugin_list_run(self):
        """
        This method is called AFTER the full plugin list has been processed.
        """

    def _transport_terminate_dataset(self, data):
        """ Finalise a dataset that will subsequently be removed by the
        framework.

        :param Data data: A data object to finalise.
        """

    def process_setup(self, plugin):
        pDict = {}
        pDict['in_data'], pDict['out_data'] = plugin.get_datasets()
        pDict['in_sl'] = self._get_all_slice_lists(pDict['in_data'], 'in')
        pDict['out_sl'] = self._get_all_slice_lists(pDict['out_data'], 'out')
        pDict['nIn'] = list(range(len(pDict['in_data'])))
        pDict['nOut'] = list(range(len(pDict['out_data'])))
        pDict['nProc'] = len(pDict['in_sl']['process'])
        if 'transfer' in list(pDict['in_sl'].keys()):
            pDict['nTrans'] = len(pDict['in_sl']['transfer'][0])
        else:
            pDict['nTrans'] = 1
        pDict['squeeze'] = self._set_functions(pDict['in_data'], 'squeeze')
        pDict['expand'] = self._set_functions(pDict['out_data'], 'expand')

        frames = [f for f in pDict['in_sl']['frames']]
        self._set_global_frame_index(plugin, frames, pDict['nProc'])
        self.pDict = pDict

    def _transport_process(self, plugin):
        """ Organise required data and execute the main plugin processing.

        :param plugin plugin: The current plugin instance.
        """
        logging.info("transport_process initialise")
        pDict, result, nTrans = self._initialise(plugin)
        logging.info("transport_process get_checkpoint_params")
        cp, sProc, sTrans = self.__get_checkpoint_params(plugin)

        prange = list(range(sProc, pDict['nProc']))
        kill = False
        for count in range(sTrans, nTrans):
            end = True if count == nTrans-1 else False
            self._log_completion_status(count, nTrans, plugin.name)

            # get the transfer data
            logging.info("Transferring the data")
            transfer_data = self._transfer_all_data(count)

            if count == nTrans-1 and plugin.fixed_length == False:
                shape = [data.shape for data in transfer_data]
                prange = self.remove_extra_slices(prange, shape)

            # loop over the process data
            logging.info("process frames loop")
            result, kill = self._process_loop(
                plugin, prange, transfer_data, count, pDict, result, cp)

            logging.info("Returning the data")
            self._return_all_data(count, result, end)

            if kill:
                return 1

        if not kill:
            cu.user_message("%s - 100%% complete" % (plugin.name))

    def remove_extra_slices(self, prange, transfer_shape):
        # loop over datasets:
        for i, data in enumerate(self.pDict['in_data']):
            pData = data._get_plugin_data()
            mft = pData.meta_data.get("max_frames_transfer")
            mfp = pData.meta_data.get("max_frames_process")
            sdirs = data.get_slice_dimensions()
            finish = np.prod([transfer_shape[i][j] for j in sdirs])
            rem, full = math.modf((mft - finish)/mfp)
            full = int(full)

            if rem:
                rem = (mft-finish) - full
                self._update_slice_list("in_sl", i, full, sdirs[0], rem)
                for j, out_data in enumerate(self.pDict['out_data']):
                    out_pData = out_data._get_plugin_data()
                    out_mfp = out_pData.meta_data.get("max_frames_process")
                    out_sdir = data.get_slice_dimensions()[0]
                    out_rem = rem/(mfp/out_mfp)
                    if out_rem % 1:
                        raise Exception("'Fixed_length' plugin option is invalid")
                    self._update_slice_list("out_sl", j, full, out_sdir, int(out_rem))

        return list(range(prange[0], prange[-1]+1-full))

    def _update_slice_list(self, key, idx, remove, dim, amount):
        sl = list(self.pDict[key]['process'][idx][-remove])
        s = sl[dim]
        sl[dim] = slice(s.start, s.stop - amount*s.step, s.step)
        self.pDict[key]['process'][idx][-1] = sl

    def _process_loop(self, plugin, prange, tdata, count, pDict, result, cp):
        kill_signal = False
        for i in prange:
            if cp and cp.is_time_to_checkpoint(self, count, i):
                # kill signal sent so stop the processing
                return result, True
            data = self._get_input_data(plugin, tdata, i, count)
            res = self._get_output_data(
                plugin.plugin_process_frames(data), i)

            for j in pDict['nOut']:
                if res is not None:
                    out_sl = pDict['out_sl']['process'][i][j]
                    if any("res_norm" in s for s in self.data_flow):
                        # an exception when the metadata is created
                        # automatically by a parameter in the plugin;
                        # this is to fix CGLS_CUDA with res_norm metadata
                        result[j][out_sl] = res[0][j, ]
                    else:
                        result[j][out_sl] = res[j]
                else:
                    result[j] = None
        return result, kill_signal

    def __get_checkpoint_params(self, plugin):
        cp = self.exp.checkpoint
        if cp:
            cp._initialise(plugin.get_communicator())
            return cp, cp.get_proc_idx(), cp.get_trans_idx()
        return None, 0, 0

    def _initialise(self, plugin):
        self.process_setup(plugin)
        pDict = self.pDict
        result = [np.empty(d._get_plugin_data().get_shape_transfer(),
                           dtype=np.float32) for d in pDict['out_data']]
        # loop over the transfer data
        nTrans = pDict['nTrans']
        self.no_processing = True if not nTrans else False
        return pDict, result, nTrans

    def _log_completion_status(self, count, nTrans, name):
        percent_complete: float = count / (nTrans * 0.01)
        cu.user_message("%s - %3i%% complete" % (name, percent_complete))

    def _transport_checkpoint(self):
        """ The framework has determined it is time to checkpoint. What
        should the transport mechanism do? Override if appropriate. """
        return False

    def _transport_kill_signal(self):
        """
        An opportunity to send a kill signal to the framework. Return
        True or False.
        """
        return False

    def _get_all_slice_lists(self, data_list, dtype):
        """
        Get all slice lists for the current process.

        :param list(Data) data_list: Datasets
        :returns: A list of dictionaries containing slice lists for each
            dataset
        :rtype: list(dict)
        """
        sl_dict = {}
        for data in data_list:
            sl = data._get_transport_data().\
                _get_slice_lists_per_process(dtype)
            for key, value in sl.items():
                if key not in sl_dict:
                    sl_dict[key] = [value]
                else:
                    sl_dict[key].append(value)

        for key in [k for k in ['process', 'unpad'] if k in list(sl_dict.keys())]:
            nData = list(range(len(sl_dict[key])))
            #rep = range(len(sl_dict[key][0]))
            sl_dict[key] = [[sl_dict[key][i][j] for i in nData
                             if j < len(sl_dict[key][i])]
                            for j in range(len(sl_dict[key][0]))]
        return sl_dict

    def _transfer_all_data(self, count):
        """
        Transfer data from file and pad if required.

        :param int count: The current frame index.
        :returns: All data for this frame and associated padded slice lists
        :rtype: list(np.ndarray), list(tuple(slice))
        """
        pDict = self.pDict
        data_list = pDict['in_data']

        if 'transfer' in list(pDict['in_sl'].keys()):
            slice_list = \
                [pDict['in_sl']['transfer'][i][count] for i in pDict['nIn']]
        else:
            slice_list = [slice(None)]*len(pDict['nIn'])

        section = []
        for i, item in enumerate(data_list):
            section.append(data_list[i]._get_transport_data().
                           _get_padded_data(slice_list[i]))
        return section

    def _get_input_data(self, plugin, trans_data, nproc, ntrans):
        data = []
        current_sl = []
        for d in self.pDict['nIn']:
            in_sl = self.pDict['in_sl']['process'][nproc][d]
            data.append(self.pDict['squeeze'][d](trans_data[d][in_sl]))
            entry = ntrans*self.pDict['nProc'] + nproc
            if entry < len(self.pDict['in_sl']['current'][d]):
                current_sl.append(self.pDict['in_sl']['current'][d][entry])
            else:
                current_sl.append(self.pDict['in_sl']['current'][d][-1])
        plugin.set_current_slice_list(current_sl)
        return data

    def _get_output_data(self, result, count):
        if result is None:
            return
        unpad_sl = self.pDict['out_sl']['unpad'][count]
        result = result if isinstance(result, list) else [result]
        for j in self.pDict['nOut']:
            if any("res_norm" in s for s in self.data_flow):
                # an exception when the metadata is created automatically
                # by a parameter in the plugin; this is to fix CGLS_CUDA
                # with res_norm metadata
                result[0][j, ] = \
                    self.pDict['expand'][j](result[0][j, ])[unpad_sl[j]]
            else:
                result[j] = self.pDict['expand'][j](result[j])[unpad_sl[j]]
        return result

    def _return_all_data(self, count, result, end):
        """
        Transfer plugin results for the current frame to the backing files.

        :param int count: The current frame index.
        :param list(np.ndarray) result: plugin results
        :param bool end: True if this is the last entry in the slice list.
        """
        pDict = self.pDict
        data_list = pDict['out_data']

        slice_list = None
        if 'transfer' in list(pDict['out_sl'].keys()):
            slice_list = \
                [pDict['out_sl']['transfer'][i][count] for i in pDict['nOut']
                 if len(pDict['out_sl']['transfer'][i]) > count]

        result = [result] if type(result) is not list else result

        for i, item in enumerate(data_list):
            if result[i] is not None:
                if slice_list:
                    temp = self._remove_excess_data(
                        data_list[i], result[i], slice_list[i])
                    data_list[i].data[slice_list[i]] = temp
                else:
                    data_list[i].data = result[i]

    def _set_global_frame_index(self, plugin, frame_list, nProc):
        """ Convert the transfer global frame index to a process global frame
        index.
        """
        process_frames = []
        for f in frame_list:
            if len(f):
                process_frames.append(list(range(f[0]*nProc, (f[-1]+1)*nProc)))

        process_frames = np.array(process_frames)
        nframes = plugin.get_plugin_in_datasets()[0].get_total_frames()
        process_frames[process_frames >= nframes] = nframes - 1
        frames = process_frames[0] if process_frames.size else process_frames
        plugin.set_global_frame_index(frames)

    def _set_functions(self, data_list, name):
        """ Create a dictionary of functions to remove (squeeze) or re-add
        (expand) dimensions, of length 1, from each dataset in a list.

        :param list(Data) data_list: Datasets
        :param str name: 'squeeze' or 'expand'
        :returns: A dictionary of lambda functions
        :rtype: dict
        """
        str_name = 'self.' + name + '_output'
        function = {'expand': self.__create_expand_function,
                    'squeeze': self.__create_squeeze_function}
        ddict = {}
        for i, item in enumerate(data_list):
            ddict[i] = {i: str_name + str(i)}
            ddict[i] = function[name](data_list[i])
        return ddict

    def __create_expand_function(self, data):
        """ Create a function that re-adds missing dimensions of length 1.

        :param Data data: Dataset
        :returns: expansion function
        :rtype: lambda
        """
        slice_dirs = data.get_slice_dimensions()
        n_core_dirs = len(data.get_core_dimensions())
        new_slice = [slice(None)]*len(data.get_shape())
        possible_slices = [copy.copy(new_slice)]

        pData = data._get_plugin_data()
        if pData._get_rank_inc():
            possible_slices[0] += [0]*pData._get_rank_inc()

        if len(slice_dirs) > 1:
            for sl in slice_dirs[1:]:
                new_slice[sl] = None
                possible_slices.append(copy.copy(new_slice))
        new_slice[slice_dirs[0]] = None
        possible_slices.append(copy.copy(new_slice))
        possible_slices = possible_slices[::-1]
        return lambda x: x[tuple(possible_slices[len(x.shape)-n_core_dirs])]

    def __create_squeeze_function(self, data):
        """ Create a function that removes dimensions of length 1.

        :param Data data: Dataset
        :returns: squeeze function
        :rtype: lambda
        """
        pData = data._get_plugin_data()
        max_frames = pData._get_max_frames_process()

        pad = True if pData.padding and data.get_slice_dimensions()[0] in \
            list(pData.padding._get_padding_directions().keys()) else False

        n_core_dims = len(data.get_core_dimensions())
        squeeze_dims = data.get_slice_dimensions()
        if max_frames > 1 or pData._get_no_squeeze() or pad:
            squeeze_dims = squeeze_dims[1:]
            n_core_dims += 1
        if pData._get_rank_inc():
            sl = [(slice(None))]*n_core_dims + [None]*pData._get_rank_inc()
            return lambda x: np.squeeze(x[tuple(sl)], axis=squeeze_dims)
        return lambda x: np.squeeze(x, axis=squeeze_dims)

    def _remove_excess_data(self, data, result, slice_list):
        """ Remove any excess results due to padding for fixed length process
        frames. """

        mData = data._get_plugin_data().meta_data.get_dictionary()
        temp = np.where(np.array(mData['size_list']) > 1)[0]
        sdir = mData['sdir'][temp[-1] if temp.size else 0]

        # Not currently working for basic_transport
        if isinstance(slice_list, slice):
            return

        sl = slice_list[sdir]
        shape = result.shape

        if shape[sdir] - (sl.stop - sl.start):
            unpad_sl = [slice(None)]*len(shape)
            unpad_sl[sdir] = slice(0, sl.stop - sl.start)
            result = result[tuple(unpad_sl)]
        return result

    def _setup_h5_files(self):
        out_data_dict = self.exp.index["out_data"]

        current_and_next = False
        if 'current_and_next' in self.exp.meta_data.get_dictionary():
            current_and_next = self.exp.meta_data.get('current_and_next')

        count = 0
        for key in out_data_dict.keys():
            out_data = out_data_dict[key]
            filename = self.exp.meta_data.get(["filename", key])
            out_data.backing_file = self.hdf5._open_backing_h5(filename, 'a')
            c_and_n = 0 if not current_and_next else current_and_next[key]
            out_data.group_name, out_data.group = self.hdf5._create_entries(
                out_data, key, c_and_n)
            count += 1

    def _set_file_details(self, files):
        self.exp.meta_data.set('link_type', files['link_type'])
        self.exp.meta_data.set('link_type', {})
        self.exp.meta_data.set('filename', {})
        self.exp.meta_data.set('group_name', {})
        for key in list(self.exp.index['out_data'].keys()):
            self.exp.meta_data.set(['link_type', key], files['link_type'][key])
            self.exp.meta_data.set(['filename', key], files['filename'][key])
            self.exp.meta_data.set(['group_name', key],
                                   files['group_name'][key])

    def _get_filenames(self, plugin_dict):
        count = self.exp.meta_data.get('nPlugin') + 1
        files = {"filename": {}, "group_name": {}, "link_type": {}}
        for key in list(self.exp.index["out_data"].keys()):
            name = key + '_p' + str(count) + '_' + \
                plugin_dict['id'].split('.')[-1] + '.h5'
            link_type = self._get_link_type(key)
            files['link_type'][key] = link_type
            if link_type == 'final_result':
                out_path = self.exp.meta_data.get('out_path')
            else:
                out_path = self.exp.meta_data.get('inter_path')

            filename = os.path.join(out_path, name)
            group_name = "%i-%s-%s" % (count, plugin_dict['name'], key)
            files["filename"][key] = filename
            files["group_name"][key] = group_name

        return files

    def _get_link_type(self, name):
        idx = self.exp.meta_data.get('nPlugin')
        temp = [e for entry in self.data_flow[idx+1:] for e in entry]
        if name in temp or self.exp.index['out_data'][name].remove:
            return 'intermediate'
        return 'final_result'

    def _populate_nexus_file(self, data, iterate_group=None):
        filename = self.exp.meta_data.get('nxs_filename')

        with h5py.File(filename, 'a') as nxs_file:
            nxs_entry = nxs_file['entry']
            name = data.data_info.get('name')
            group_name = self.exp.meta_data.get(['group_name', name])
            link_type = self.exp.meta_data.get(['link_type', name])

            if link_type == 'final_result':
                if iterate_group is not None and \
                        check_if_end_plugin_in_iterate_group(self.exp):
                    is_clone_data = 'clone' in name
                    is_even_iterations = \
                        iterate_group._ip_fixed_iterations % 2 == 0
                    # don't need to create a group for:
                    # - the clone dataset, if running an odd number of
                    #   iterations
                    # - the original dataset, if running an even number of
                    #   iterations
                    if is_clone_data and not is_even_iterations:
                        return
                    elif not is_clone_data and is_even_iterations:
                        return
                # the group for the output of the iterative loop should be
                # named after the original dataset, regardless of whether
                # the link eventually points to the original or the clone,
                # so that the link name references the dataset name set in
                # savu_config
                group_name = 'final_result_' + data.get_name(orig=True)
            else:
                link = nxs_entry.require_group(link_type.encode("ascii"))
                link.attrs[NX_CLASS] = 'NXcollection'
                nxs_entry = link

            # delete the group if it already exists
            if group_name in nxs_entry:
                del nxs_entry[group_name]

            plugin_entry = nxs_entry.require_group(group_name)
            plugin_entry.attrs[NX_CLASS] = 'NXdata'
            if iterate_group is not None and \
                    check_if_end_plugin_in_iterate_group(self.exp):
                # always write the metadata under the name of the original
                # dataset, not the clone dataset
                self._output_metadata(data, plugin_entry,
                                      data.get_name(orig=False))
            else:
                self._output_metadata(data, plugin_entry, name)

    def _populate_pre_run_nexus_file(self, data):
        filename = self.exp.meta_data.get('nxs_filename')

        data_path = self.exp.meta_data["data_path"]
        image_key_path = self.exp.meta_data["image_key_path"]
        name = data.data_info.get('name')
        group_name = self.exp.meta_data.get(['group_name', name])
        with h5py.File(filename, 'a') as nxs_file:
            if data_path in nxs_file:
                del nxs_file[data_path]
            nxs_file[data_path] = h5py.ExternalLink(
                os.path.abspath(data.backing_file.filename),
                f"{group_name}/data")

            if image_key_path in nxs_file:
                nxs_file[image_key_path][::] = data.data.image_key[::]
            # nxs_file[data_path].attrs.create("pre_run", True)

    def _output_metadata(self, data, entry, name, dump=False):
        self.__output_data_type(entry, data, name)
        mDict = data.meta_data.get_dictionary()
        self._output_metadata_dict(entry.require_group('meta_data'), mDict)

        if not dump:
            self.__output_axis_labels(data, entry)
            self.__output_data_patterns(data, entry)
            if self.exp.meta_data.get('link_type')[name] == 'input_data':
                # output the filename
                entry['file_path'] = \
                    os.path.abspath(self.exp.meta_data.get('data_file'))

    def __output_data_type(self, entry, data, name):
        data = data.data if 'data' in list(data.__dict__.keys()) else data
        if isinstance(data, h5py.Dataset):
            return

        entry = entry.require_group('data_type')
        entry.attrs[NX_CLASS] = 'NXcollection'

        ltype = self.exp.meta_data.get('link_type')
        if name in list(ltype.keys()) and ltype[name] == 'input_data':
            self.__output_data(entry, data.__class__.__name__, 'cls')
            return

        args, kwargs, cls, extras = data._get_parameters(data.get_clone_args())

        for key, value in kwargs.items():
            gp = entry.require_group('kwargs')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for key, value in extras.items():
            gp = entry.require_group('extras')
            if isinstance(value, BaseType):
                self.__output_data_type(gp.require_group(key), value, key)
            else:
                self.__output_data(gp, value, key)

        for i, item in enumerate(args):
            gp = entry.require_group('args')
            self.__output_data(gp, args[i], ''.join(['args', str(i)]))

        self.__output_data(entry, cls, 'cls')

        if 'data' in list(data.__dict__.keys()) and not \
                isinstance(data.data, h5py.Dataset):
            gp = entry.require_group('data')
            self.__output_data_type(gp, data.data, 'data')

    def __output_data(self, entry, data, name):
        if isinstance(data, dict):
            entry = entry.require_group(name)
            entry.attrs[NX_CLASS] = 'NXcollection'
            for key, value in data.items():
                self.__output_data(entry, value, key)
        else:
            try:
                self.__create_dataset(entry, name, data)
            except Exception:
                try:
                    import json
                    data = np.array([json.dumps(data).encode("ascii")])
                    self.__create_dataset(entry, name, data)
                except Exception:
                    try:
                        data = cu._savu_encoder(data)
                        self.__create_dataset(entry, name, data)
                    except:
                        raise Exception('Unable to output %s to file.' % name)

    def __create_dataset(self, entry, name, data):
        if name not in list(entry.keys()):
            entry.create_dataset(name, data=data)
        else:
            entry[name][...] = data

    def __output_axis_labels(self, data, entry):
        axis_labels = data.data_info.get("axis_labels")
        ddict = data.meta_data.get_dictionary()

        axes = []
        count = 0
        for labels in axis_labels:
            name = list(labels.keys())[0]
            axes.append(name)
            entry.attrs[name + '_indices'] = count

            mData = ddict[name] if name in list(ddict.keys()) \
                else np.arange(data.get_shape()[count])
            if isinstance(mData, list):
                mData = np.array(mData)

            if 'U' in str(mData.dtype):
                mData = mData.astype(np.string_)

            axis_entry = entry.require_dataset(name, mData.shape, mData.dtype)
            axis_entry[...] = mData[...]
            axis_entry.attrs['units'] = list(labels.values())[0]
            count += 1
        entry.attrs['axes'] = axes

    def __output_data_patterns(self, data, entry):
        data_patterns = data.data_info.get("data_patterns")
        entry = entry.require_group('patterns')
        entry.attrs[NX_CLASS] = 'NXcollection'
        for pattern in data_patterns:
            nx_data = entry.require_group(pattern)
            nx_data.attrs[NX_CLASS] = 'NXparameters'
            values = data_patterns[pattern]
            self.__output_data(nx_data, values['core_dims'], 'core_dims')
            self.__output_data(nx_data, values['slice_dims'], 'slice_dims')

    def _output_metadata_dict(self, entry, mData):
        entry.attrs[NX_CLASS] = 'NXcollection'
        for key, value in mData.items():
            nx_data = entry.require_group(key)
            if isinstance(value, dict):
                self._output_metadata_dict(nx_data, value)
            else:
                nx_data.attrs[NX_CLASS] = 'NXdata'
                self.__output_data(nx_data, value, key)
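The `_transport_*` methods above are hooks that concrete transport mechanisms override. As a rough, hypothetical sketch of how such a subclass might look (it assumes the `BaseTransport` class and the `logging` import from the listing above; the `LoggingTransport` name and its log messages are invented for illustration and are not part of Savu):

```python
# Hypothetical sketch of a concrete transport; only hook methods
# defined by BaseTransport above are overridden.
class LoggingTransport(BaseTransport):
    """Minimal transport that just logs the framework callbacks."""

    def _transport_initialise(self, options):
        # called before the experiment is initialised
        logging.info("initialising transport with options: %s", options)

    def _transport_pre_plugin(self):
        logging.info("about to execute a plugin")

    def _transport_post_plugin(self):
        logging.info("finished executing a plugin")

    def _transport_kill_signal(self):
        # returning True would ask the framework to stop processing
        return False
```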