Test Failed
Pull Request — master (#827), by unknown, created 03:34

Experiment._set_initial_datasets()   A

Complexity
    Conditions: 1

Size
    Total Lines: 2
    Code Lines: 2

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       1
eloc     2
nop      1
dl       0
loc      2
rs       10
c        0
b        0
f        0
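
For context, per-function figures of this kind can be reproduced locally. Below is a minimal sketch using the radon library; radon is an assumption here (the report does not name its analysis tool), and the file path is illustrative:

    # Sketch: recompute cyclomatic complexity (cc) and raw line counts with
    # radon (https://radon.readthedocs.io). Illustrative only; the report's
    # exact tool and metric definitions are not named.
    from radon.complexity import cc_visit
    from radon.raw import analyze

    with open('experiment.py') as f:        # hypothetical path to this module
        source = f.read()

    for block in cc_visit(source):          # one entry per function/method
        print(block.name, 'cc =', block.complexity)

    raw = analyze(source)                   # module-level raw counts
    print('loc =', raw.loc, 'sloc =', raw.sloc)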
# -*- coding: utf-8 -*-
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: experiment
   :platform: Unix
   :synopsis: Contains information specific to the entire experiment.

.. moduleauthor:: Nicola Wadeson <[email protected]>
"""

import os
import copy
import h5py
import logging
from mpi4py import MPI

from savu.data.meta_data import MetaData
from savu.data.plugin_list import PluginList
from savu.data.data_structures.data import Data
from savu.core.checkpointing import Checkpointing
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
import savu.plugins.loaders.utils.yaml_utils as yaml


class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end. It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)
        self.meta_data.set("nPlugin", 0)  # initialise

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should use this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        if name not in list(self.index[dtype].keys()) or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]

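    # Illustrative usage (not from the source): a loader might create its
    # input dataset roughly as
    #     data_obj = exp.create_data_object("in_data", "tomo")
    # where "tomo" is a hypothetical dataset name and exp is the Experiment
    # instance available to the loader.
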
    def _setup(self, transport):
        self._set_nxs_file()
        self._set_process_list_path()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _finalise_setup(self, plugin_list):
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes')) - 1 and not checkpoint:
            # Save original process list
            plugin_list._save_plugin_list(
                self.meta_data.get('process_list_path'))
            # links the input data to the nexus file
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
            self._add_input_data_to_nxs_file(self._get_transport())
        self._save_command_log()

    def _save_command_log(self):
        """Save the original Savu run command and a modified Savu run
        command to a log file for reproducibility.
        """
        folder = self.meta_data.get('out_path')
        log_folder = os.path.join(folder, "run_log")
        filename = os.path.join(log_folder, "run_command.txt")
        modified_command = self._get_modified_command()
        if not os.path.isfile(filename):
            # Only write savu command if savu_mpi command has not been saved
            with open(filename, 'w') as command_log:
                command_log.write("# Original Savu run command\n")
                command_log.write(f"{self.meta_data.get('command')}\n")
                command_log.write("# A modified Savu command to reproduce "
                                  "the obtained result\n")
                command_log.write(f"{modified_command}\n")

    def _get_modified_command(self):
        """Modify the input Savu run command, replacing the path to the
        process list.

        :returns: the modified Savu run command string
        """
        pl_path = self.meta_data.get('process_file')
        new_pl_path = self.meta_data.get('process_list_path')
        input_command = self.meta_data.get('command')
        updated_command = input_command.replace(pl_path, new_pl_path)
        return updated_command

    def _set_process_list_path(self):
        """Create the path the process list should be saved to."""
        log_folder = os.path.join(self.meta_data.get('out_path'), "run_log")
        plname = os.path.basename(self.meta_data.get('process_file'))
        filename = os.path.join(
            log_folder, plname if plname else "process_list.nxs")
        self.meta_data.set('process_list_path', filename)

    def _set_initial_datasets(self):
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _set_dataset_names_complete(self):
        """ Flag that any missing in/out_datasets fields have been populated.
        """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets
        # clear out dataset dictionaries
        for data_dict in self.collection['datasets']:
            for data in data_dict.values():
                data.meta_data._set_dictionary({})

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_file(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                with open(log_folder_name, 'a') as log_folder:
                    log_folder.write(os.path.abspath(filename) + '\n')

        self._create_nxs_entry()

    def _create_nxs_entry(self):  # what if the file already exists?!
        logging.debug("Testing nexus file")
        # test the 'checkpoint' metadata (as in _finalise_setup); the
        # Checkpointing instance in self.checkpoint is always truthy
        if self.meta_data.get('process') == len(
                self.meta_data.get('processes')) - 1 and \
                not self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'w') as nxs_file:
                entry_group = nxs_file.create_group('entry')
                entry_group.attrs['NX_class'] = 'NXentry'

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self, plugin_dict):
        out_data = self.index['out_data'].copy()
        for key, data in out_data.items():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.collection['datasets'].append(out_data)
        self.collection['plugin_dict'].append(plugin_dict)
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {'remove': [], 'keep': []}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.
        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in list(self.index['out_data'].keys()):
            if out_name in list(self.index['in_data'].keys()):
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in list(self.index["in_data"].keys()):
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

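    # Note (added for clarity): MPI_Barrier is a collective call, so in MPI
    # mode every process in `communicator` must reach _barrier or the run
    # stalls; _barrier_count exists only to correlate the debug logs.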
    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            communicator.barrier()
        self._barrier_count += 1

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level.
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
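
For orientation, the Experiment constructor takes an options dictionary. A minimal sketch of what it might contain, inferred purely from the meta_data.get() calls in this file (key names appear in the source; the values shown are hypothetical, and real runs build this dictionary via the Savu launcher):

    # Hypothetical options dict; in practice the Savu run scripts assemble it.
    options = {
        'process_file': '/path/to/process_list.nxs',   # plugin/process list
        'out_path': '/path/to/output',                 # output directory
        'datafile_name': 'scan_0001',                  # stem for *_processed.nxs
        'command': 'savu <data> <process_list> <out>', # original run command
        'transport': 'hdf5',                           # transport mechanism
        'mpi': False,                                  # single-process run
        'process': 0,                                  # index of this process
        'processes': ['cpu0'],                         # all participating processes
        'system_params': None,                         # fall back to bundled file
    }
    # exp = Experiment(options)  # a full run also reads 'checkpoint', 'bllog', etc.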