Experiment.__find_next_pattern()   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 8
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 8
nop 3
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# Copyright 2014 Diamond Light Source Ltd.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
"""
17
.. module:: experiment
18
   :platform: Unix
19
   :synopsis: Contains information specific to the entire experiment.
20
21
.. moduleauthor:: Nicola Wadeson <[email protected]>
22
"""
23
24
import os
25
import copy
26
import h5py
27
import logging
28
from mpi4py import MPI
29
30
from savu.data.meta_data import MetaData
31
from savu.data.plugin_list import PluginList
32
from savu.data.data_structures.data import Data
33
from savu.core.checkpointing import Checkpointing
34
from savu.core.iterative_plugin_runner import IteratePluginGroup
35
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
36
from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop
37
import savu.plugins.loaders.utils.yaml_utils as yaml
38
39
40
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        """Build the experiment metadata and load the process list.

        :param dict options: run options, passed to ``MetaData``; must
            contain the key ``"process_file"``.
        """
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        # Data objects keyed by dataset name, split into input and output.
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False

    def get(self, entry):
        """Return the metadata value stored under ``entry``."""
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        """Attach a ``PluginList`` to the metadata and populate it.

        When ``run_type`` is ``'test'`` the plugin list is taken directly
        from the ``plugin_list`` metadata entry.  If the metadata lookup
        raises ``KeyError`` (e.g. no ``run_type`` entry), the list is
        populated from ``process_file`` using the ``template`` entry.

        :param str process_file: path to the process list file.
        :raises Exception: if ``run_type`` is present but not ``'test'``.
        """
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)
        self.meta_data.set("nPlugin", 0)  # initialise
        self.meta_data.set('iterate_groups', [])

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        :params str name: name of the dataset.
        :params bool override: if True (default) an existing object with the
            same name is replaced; otherwise the existing one is returned.
        :returns: the ``Data`` object stored under ``self.index[dtype][name]``.
        """
        datasets = self.index[dtype]
        if override or name not in datasets:
            data_obj = Data(name, self)
            datasets[name] = data_obj
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return datasets[name]

    def _setup(self, transport):
        """Prepare output files, transport and iterative groups for a run.

        :param transport: the transport object used for this run.
        """
        self._set_nxs_file()
        self._set_process_list_path()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}
        self._setup_iterate_plugin_groups(transport)

        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _setup_iterate_plugin_groups(self, transport):
        '''
        Create all the necessary instances of IteratePluginGroup and store
        them in the ``iterate_groups`` metadata entry.
        '''
        iterate_group_dicts = self.meta_data.plugin_list.iterate_plugin_groups
        iterate_plugin_groups = [
            IteratePluginGroup(transport,
                               group['start_index'],
                               group['end_index'],
                               group['iterations'])
            for group in iterate_group_dicts
        ]
        self.meta_data.set('iterate_groups', iterate_plugin_groups)

    def _finalise_setup(self, plugin_list):
        """Save the process list and link the input data to the output file.

        Only the last process (``process == len(processes) - 1``) saves, and
        only when the ``checkpoint`` metadata entry is falsy.  It always
        finishes by writing the command log.

        :param plugin_list: the plugin list object to save.
        """
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        is_last_process = self.meta_data.get('process') == \
            len(self.meta_data.get('processes')) - 1
        if is_last_process and not checkpoint:
            # Save original process list
            plugin_list._save_plugin_list(
                self.meta_data.get('process_list_path'))
            if self.meta_data.get("pre_run"):
                self._create_pre_run_nxs_file()
            else:
                # links the input data to the nexus file
                plugin_list._save_plugin_list(
                    self.meta_data.get('nxs_filename'))
                self._add_input_data_to_nxs_file(self._get_transport())
        self._save_command_log()

    def _save_command_log(self):
        """Save the original Savu run command and a
        modified Savu run command to a log file for reproducibility.

        Writes ``<out_path>/run_log/run_command.txt`` only if that file does
        not already exist.
        """
        current_path = os.getcwd()
        folder = self.meta_data.get('out_path')
        log_folder = os.path.join(folder, "run_log")
        filename = os.path.join(log_folder, "run_command.txt")
        modified_command = self._get_modified_command()
        if not os.path.isfile(filename):
            # Only write savu command if savu_mpi command has not been saved
            with open(filename, 'w') as command_log:
                command_log.write("# The directory the command was executed from\n")
                command_log.write(f"{current_path}\n")
                command_log.write("# Original Savu run command\n")
                command_log.write(f"{self.meta_data.get('command')}\n")
                command_log.write("# A modified Savu command to use to "
                                  "reproduce the  obtained result\n")
                command_log.write(f"{modified_command}\n")

    def _get_modified_command(self):
        """Modify the input Savu run command, and replace the path to the
        process list with the path it is saved to in the output folder.

        :returns: modified Savu run command string.
        """
        pl_path = self.meta_data.get('process_file')
        new_pl_path = self.meta_data.get('process_list_path')
        input_command = self.meta_data.get('command')
        return input_command.replace(pl_path, new_pl_path)

    def _save_pre_run_log(self):
        """Write a pre-run summary to ``<out_path>/run_log/pre_run_log.txt``.

        The file is only written if it does not already exist.  It records
        the process list that was run, any ``pre_run_stats`` entries, the
        calculated ``pre_run_preview`` value and any accumulated warnings.
        """
        folder = self.meta_data.get('out_path')
        log_folder = os.path.join(folder, "run_log")
        filename = os.path.join(log_folder, "pre_run_log.txt")
        if os.path.isfile(filename):
            return
        with open(filename, 'w') as pre_run_log:
            pre_run_log.write("# SAVU PRE-RUN\n")
            pre_run_log.write("# During the pre-run, the following process list was run:\n")
            pre_run_log.write(f"{self.meta_data.get('process_file_name')}\n")
            pre_run_log.write("# The following statistics were calculated on the input data:\n")
            meta_dict = self.meta_data.get_dictionary()
            if "pre_run_stats" in meta_dict:
                for key, value in self.meta_data.get("pre_run_stats").items():
                    pre_run_log.write(f"   {key}: {value}\n")
            if "pre_run_preview" in meta_dict:
                pre_run_log.write("# The following value for the preview parameter was calculated from the input data:\n")
                # Terminate the line so the following header is not glued on.
                pre_run_log.write(f"   {self.meta_data.get('pre_run_preview')}\n")
            warnings = self.meta_data.get("warnings")
            if warnings:
                pre_run_log.write("# Please read the following warnings before deciding whether to continue:\n")
                for warning in warnings:
                    # One warning per line.
                    pre_run_log.write(f" ~ {warning}\n")

    def _set_process_list_path(self):
        """Create the path the process list should be saved to.

        The path is ``<out_path>/run_log/<basename of process_file>``, or
        ``process_list.nxs`` when that basename is empty.  It is stored
        under the ``process_list_path`` metadata entry.
        """
        log_folder = os.path.join(self.meta_data.get('out_path'), "run_log")
        plname = os.path.basename(self.meta_data.get('process_file'))
        filename = os.path.join(log_folder, plname or "process_list.nxs")
        self.meta_data.set('process_list_path', filename)

    def _set_initial_datasets(self):
        """Snapshot (deep copy) the current input datasets."""
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        """Load the system parameters YAML file into the metadata.

        If no ``system_params`` path was supplied, the file is looked up in
        the ``system_files`` folder next to the installed ``savu`` package.
        The ``dls`` subfolder is used when several subfolders exist,
        otherwise the only one present.
        """
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        """Record link metadata for each input dataset and link it into the
        output nexus file.

        :param transport: transport whose ``_populate_nexus_file`` is called
            for each input dataset.
        """
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _create_pre_run_nxs_file(self):
        """Copy each input file to ``<out_path>/<datafile_name>_pre_run.nxs``.

        If the entry at ``data_path`` in the input file is an HDF5 external
        link, the copied link is replaced by one pointing at the linked file
        in the directory of the original ``data_file``.  This assumes
        ``data.backing_file`` is an open h5py file/group (TODO confirm).
        """
        data_path = self.meta_data["data_path"]

        for data in self.index["in_data"].values():
            raw_data = data.backing_file
            folder = self.meta_data['out_path']
            fname = self.meta_data.get('datafile_name') + '_pre_run.nxs'
            # NOTE(review): the filename does not depend on the dataset, so
            # with several input datasets each copy overwrites the previous.
            filename = os.path.join(folder, fname)
            self.meta_data.set("pre_run_filename", filename)
            self.__copy_input_file_to_output_folder(raw_data, filename)

            link = raw_data.get(data_path, getlink=True)
            if isinstance(link, h5py.ExternalLink):
                # Re-point the link at the original file's directory,
                # presumably so it still resolves from the new output location.
                source_dir = "/".join(
                    self.meta_data.get("data_file").split("/")[:-1])
                location = f'{source_dir}/{link.filename}'
                with h5py.File(filename, "r+") as new_file:
                    del new_file[data_path]
                    new_file[data_path] = h5py.ExternalLink(location, link.path)

    def __copy_input_file_to_output_folder(self, file, new_filename):
        """Copy every top-level group of ``file`` into a new HDF5 file."""
        with h5py.File(new_filename, "w") as new_file:
            for group_name in file.keys():
                file.copy(file[group_name], new_file["/"], group_name)

    def _set_dataset_names_complete(self):
        """ Missing in/out_datasets fields have been populated
        """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        """Restore the input-dataset snapshot and clear dataset metadata in
        every collected dataset dictionary."""
        self.index['in_data'] = self.initial_datasets
        # clear out dataset dictionaries
        for data_dict in self.collection['datasets']:
            for data in data_dict.values():
                data.meta_data._set_dictionary({})

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        """Point ``out_data`` at the datasets collected for plugin ``count``
        and record the current/next patterns for the remaining chain."""
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.

        Stores ``{out_name: {'current': pattern, 'next': pattern}}`` for
        every out_dataset of the first entry under ``current_and_next``.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_data['pattern'], 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        """Return the pattern of the first later in_dataset named
        ``current_name``, or ``[]`` if no later plugin uses it."""
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    return next_data['pattern']
        return []

    def _set_nxs_file(self):
        """Set ``nxs_filename`` (``_pre_run.nxs`` or ``_processed.nxs``),
        optionally append it to the ``bllog`` file, and create its entry."""
        folder = self.meta_data.get('out_path')
        if self.meta_data.get("pre_run") == True:
            fname = self.meta_data.get('datafile_name') + '_pre_run.nxs'
        else:
            fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        # NOTE(review): other single-writer checks use the last process
        # (len(processes) - 1); confirm that process 1 is intended here.
        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                with open(log_folder_name, 'a') as log_folder:
                    log_folder.write(os.path.abspath(filename) + '\n')

        self._create_nxs_entry()

    def _create_nxs_entry(self):  # what if the file already exists?!
        """Create the nexus file with an ``NXentry`` group (last process
        only)."""
        logging.debug("Testing nexus file")
        # NOTE(review): ``self.checkpoint`` is the Checkpointing instance set
        # in __init__, not the 'checkpoint' metadata flag tested elsewhere;
        # unless Checkpointing defines __bool__/__len__, this condition is
        # always False - confirm intent.
        if self.meta_data.get('process') == len(
                self.meta_data.get('processes')) - 1 and not self.checkpoint:
            with h5py.File(self.meta_data.get('nxs_filename'), 'w') as nxs_file:
                entry_group = nxs_file.create_group('entry')
                entry_group.attrs['NX_class'] = 'NXentry'

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self, plugin_dict):
        """Move non-removed output datasets into the inputs and record the
        plugin's datasets and dictionary in the collection."""
        out_data = self.index['out_data'].copy()
        for key, data in out_data.items():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.collection['datasets'].append(out_data)
        self.collection['plugin_dict'].append(plugin_dict)
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        """Classify the current output datasets.

        :returns: dict with ``'remove'`` and ``'keep'`` (output datasets by
            their ``remove`` flag) and ``'replace'`` (input datasets sharing
            a name with an output; always empty inside an iterative loop).
        """
        finalise = {'remove': [], 'keep': []}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.
        for data in self.index['out_data'].values():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        if not check_if_in_iterative_loop(self):
            in_data = self.index['in_data']
            for out_name in self.index['out_data']:
                if out_name in in_data:
                    finalise['replace'].append(in_data[out_name])
        else:
            # temporary workaround to
            # https://jira.diamond.ac.uk/browse/SCI-10216: don't mark any
            # datasets as "to replace" if the given plugin is in an iterative
            # loop
            logging.debug('Not marking any datasets in a loop as '
                          '\"to replace\"')

        return finalise

    def _reorganise_datasets(self, finalise):
        """Drop removed outputs and turn the remaining outputs into inputs.

        :param dict finalise: result of
            ``_finalise_experiment_for_current_plugin``.
        """
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        """Reset any input dataset whose data is a ``Replicate`` wrapper."""
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        """Return the input dataset names, excluding ``itr_clone`` copies.

        ``name`` is not used.
        """
        return [key for key in self.index["in_data"]
                if 'itr_clone' not in key]

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        """Call ``communicator.barrier()`` when the ``mpi`` metadata entry is
        True.  The barrier counter is incremented on every call."""
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            communicator.barrier()
        self._barrier_count += 1

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the shapes of the input and output datasets at the specified
        level.
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
438