Test Failed
Pull Request — master (#820)
created by unknown at 04:20

Experiment._reorganise_datasets()   A

Complexity:   Conditions 3
Size:         Total Lines 13, Code Lines 8
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Metric   Value
cc       3
eloc     8
nop      2
dl       0
loc      13
rs       10
c        0
b        0
f        0
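The cyclomatic complexity figure above (cc 3) follows directly from the flagged method's two for loops: 1 for the function body plus 1 per loop. The sketch below is illustrative only, assuming the third-party radon package is available; it is not part of this report or of the Savu codebase, and the quoted method body is a trimmed excerpt of the listing that follows.

# Illustrative sketch only (assumes the third-party "radon" package is
# installed); not part of the Savu repository or of this report.
from radon.complexity import cc_visit

# Trimmed excerpt of the flagged method, taken from the listing below.
source = '''
def _reorganise_datasets(self, finalise):
    for data in finalise['remove']:
        del self.index["out_data"][data.data_info.get('name')]
    for name, data in self.index['out_data'].items():
        self.index["in_data"][name] = copy.deepcopy(data)
    self.index['out_data'] = {}
'''

for block in cc_visit(source):
    print(block.name, block.complexity)  # -> _reorganise_datasets 3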
# -*- coding: utf-8 -*-
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: experiment
   :platform: Unix
   :synopsis: Contains information specific to the entire experiment.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""

import os
import copy
import h5py
import logging
from mpi4py import MPI

from savu.data.meta_data import MetaData
from savu.data.plugin_list import PluginList
from savu.data.data_structures.data import Data
from savu.core.checkpointing import Checkpointing
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
import savu.plugins.loaders.utils.yaml_utils as yaml


class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)
        self.meta_data.set("nPlugin", 0)  # initialise

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        if name not in list(self.index[dtype].keys()) or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]

    def _setup(self, transport):
        self._set_nxs_file()
        self._set_process_list_path()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _finalise_setup(self, plugin_list):
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1 and not checkpoint:
            # Save original process list
            plugin_list._save_plugin_list(self.meta_data.get('process_list_path'))
            # links the input data to the nexus file
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
            self._add_input_data_to_nxs_file(self._get_transport())
        self._set_dataset_names_complete()
        self._save_command_log()

    def _save_command_log(self):
        """Save the original Savu run command and a
        modified Savu run command to a log file for reproducibility
        """
        folder = self.meta_data.get('out_path')
        log_folder = f"{folder}/run_log"
        fname = "run_command.txt"
        filename = os.path.join(log_folder, fname)
        modified_command = self._get_modified_command()
        with open(filename, 'a') as command_log:
            command_log.write("# Original Savu run command\n")
            command_log.write(f"{self.meta_data.get('command')}\n")
            command_log.write("# A modified Savu command to use to "
                              "reproduce the obtained result\n")
            command_log.write(f"{modified_command}\n")

    def _get_modified_command(self):
        """Modify the input Savu run command, and replace the path to the
        process list

        :returns: modified Savu run command string
        """
        pl_path = self.meta_data.get('process_file')
        new_pl_path = self.meta_data.get('process_list_path')
        input_command = self.meta_data.get('command')
        updated_command = input_command.replace(pl_path, new_pl_path)
        return updated_command

    def _set_process_list_path(self):
        """Create the path the process list should be saved to"""
        log_folder = f"{self.meta_data.get('out_path')}/run_log"
        plname = os.path.basename(self.meta_data.get('process_file'))
        filename = os.path.join(log_folder, plname if plname
                                else "process_list.nxs")
        self.meta_data.set('process_list_path', filename)

    def _set_initial_datasets(self):
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _set_dataset_names_complete(self):
        """ Missing in/out_datasets fields have been populated
        """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets
        # clear out dataset dictionaries
        for data_dict in self.collection['datasets']:
            for data in data_dict.values():
                data.meta_data._set_dictionary({})

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_file(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                with open(log_folder_name, 'a') as log_folder:
                    log_folder.write(os.path.abspath(filename) + '\n')

        self._create_nxs_entry()

    def _create_nxs_entry(self):  # what if the file already exists?!
        logging.debug("Testing nexus file")
        if self.meta_data.get('process') == len(
                self.meta_data.get('processes')) - 1 and not self.checkpoint:
            with h5py.File(self.meta_data.get('nxs_filename'), 'w') as nxs_file:
                entry_group = nxs_file.create_group('entry')
                entry_group.attrs['NX_class'] = 'NXentry'

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self, plugin_dict):
        out_data = self.index['out_data'].copy()
        for key, data in out_data.items():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.collection['datasets'].append(out_data)
        self.collection['plugin_dict'].append(plugin_dict)
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {'remove': [], 'keep': []}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.

        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in list(self.index['out_data'].keys()):
            if out_name in list(self.index['in_data'].keys()):
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in list(self.index["in_data"].keys()):
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            comm_dict['comm'].barrier()
        self._barrier_count += 1

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
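As a worked illustration of _get_modified_command() above: the method performs a single string replacement, swapping the process-list path from the original command line for the copy saved under <out_path>/run_log. The sketch below is standalone and hedged: all paths and the command string are hypothetical, and only str.replace(), as in the listing, is used.

# Hypothetical values; only the string replacement shown in the listing above.
pl_path = "/home/user/lists/tomo_recon.nxs"             # original process list
new_pl_path = "/dls/tmp/output/run_log/tomo_recon.nxs"  # copy saved by Savu
input_command = f"savu /dls/data/scan_0001.nxs {pl_path} /dls/tmp/output"

modified_command = input_command.replace(pl_path, new_pl_path)
print(modified_command)
# savu /dls/data/scan_0001.nxs /dls/tmp/output/run_log/tomo_recon.nxs /dls/tmp/output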
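Similarly, the structure consumed by _get_current_and_next_patterns() and __find_next_pattern() can be pictured with a small hand-built datasets list. The field names below follow the listing above; the dataset name and pattern values are hypothetical examples.

# Hypothetical datasets_list, shaped as the listing above expects: one entry
# per remaining plugin, each with 'in_datasets' and 'out_datasets' records.
datasets_list = [
    {'in_datasets': [],
     'out_datasets': [{'name': 'tomo', 'pattern': 'PROJECTION'}]},
    {'in_datasets': [{'name': 'tomo', 'pattern': 'SINOGRAM'}],
     'out_datasets': [{'name': 'tomo', 'pattern': 'SINOGRAM'}]},
]

# 'tomo' is currently written with pattern 'PROJECTION' and is next read with
# pattern 'SINOGRAM', so the mapping stored under 'current_and_next' would be:
current_and_next = {'tomo': {'current': 'PROJECTION', 'next': 'SINOGRAM'}}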