Test Failed
Pull Request — master (#820), created by unknown, 03:58

Experiment._update_command() (A)

Complexity
    Conditions: 1

Size
    Total Lines: 8
    Code Lines:  6

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       1
eloc     6
nop      1
dl       0
loc      8
rs       10
c        0
b        0
f        0
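
For context, the flagged method (shown in full within the listing below) is
straight-line code with no branches, which is consistent with the reported
cyclomatic complexity (cc) of 1. A minimal sketch of the substitution it
performs, using hypothetical paths and a plain dict standing in for Savu's
MetaData object:

# Hypothetical stand-in for the experiment metadata, not Savu's MetaData API.
meta = {'process_file': '/scratch/tomo_list.nxs',   # original process list
        'process_list_path': '/out/tomo_list.nxs',  # copy saved to out_path
        'command': 'savu /data/raw.nxs /scratch/tomo_list.nxs /out'}

# _update_command() rewrites the stored command to point at the saved copy.
meta['command'] = meta['command'].replace(meta['process_file'],
                                          meta['process_list_path'])
print(meta['command'])  # -> savu /data/raw.nxs /out/tomo_list.nxs /out
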
# -*- coding: utf-8 -*-
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: experiment
   :platform: Unix
   :synopsis: Contains information specific to the entire experiment.

.. moduleauthor:: Nicola Wadeson <[email protected]>
"""

import os
import sys
import copy
import h5py
import logging
from mpi4py import MPI

from savu.data.meta_data import MetaData
from savu.data.plugin_list import PluginList
from savu.data.data_structures.data import Data
from savu.core.checkpointing import Checkpointing
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
import savu.plugins.loaders.utils.yaml_utils as yaml


class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)
        self.meta_data.set("nPlugin", 0)  # initialise

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should use this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        if name not in list(self.index[dtype].keys()) or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]
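
    # Hypothetical loader-side usage of create_data_object (illustrative
    # names only, not taken from a real Savu loader):
    #
    #     data_obj = exp.create_data_object('in_data', 'tomo')
    #     data_obj.set_axis_labels(...)  # continue configuring the dataset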

    def _setup(self, transport):
        self._set_nxs_file()
        self._set_process_list_path()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _finalise_setup(self, plugin_list):
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes')) - 1 and not checkpoint:
            # save the original process list
            plugin_list._save_plugin_list(
                self.meta_data.get('process_list_path'))
            # link the input data to the nexus file
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'),
                                          self.meta_data.get('command'))
            self._add_input_data_to_nxs_file(self._get_transport())
        self._set_dataset_names_complete()

    def _set_initial_datasets(self):
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        if sys_file is None:
            # look in the conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # If checkpointing has been set but the nxs file doesn't contain an
        # entry, then remove checkpointing (as the previous run didn't get
        # far enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _set_dataset_names_complete(self):
        """ Record that missing in/out_datasets fields have been populated.
        """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets
        # clear out dataset dictionaries
        for data_dict in self.collection['datasets']:
            for data in data_dict.values():
                data.meta_data._set_dictionary({})

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern
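
    # Illustration of the two methods above, with hypothetical entries:
    # given
    #     datasets_lists = [
    #         {'out_datasets': [{'name': 'tomo', 'pattern': 'PROJECTION'}]},
    #         {'in_datasets': [{'name': 'tomo', 'pattern': 'SINOGRAM'}]},
    #     ]
    # the 'current_and_next' metadata entry becomes
    #     {'tomo': {'current': 'PROJECTION', 'next': 'SINOGRAM'}}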

    def _set_process_list_path(self):
        """Create the path the process list should be saved to."""
        folder = self.meta_data.get('out_path')
        plname = os.path.basename(self.meta_data.get('process_file'))
        filename = os.path.join(
            folder, plname if plname else "process_list.nxs")
        self.meta_data.set('process_list_path', filename)
        self._update_command()

    def _update_command(self):
        """Update the input command, replacing the original process list
        path with the path to the saved copy."""
        pl_path = self.meta_data.get('process_file')
        new_pl_path = self.meta_data.get('process_list_path')
        input_command = self.meta_data.get('command')
        updated_command = input_command.replace(pl_path, new_pl_path)
        self.meta_data.set('command', updated_command)

    def _set_nxs_file(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                with open(log_folder_name, 'a') as log_folder:
                    log_folder.write(os.path.abspath(filename) + '\n')

        self._create_nxs_entry()

    def _create_nxs_entry(self):  # what if the file already exists?!
        logging.debug("Testing nexus file")
        if self.meta_data.get('process') == len(
                self.meta_data.get('processes')) - 1 and not self.checkpoint:
            with h5py.File(self.meta_data.get('nxs_filename'), 'w') as nxs_file:
                entry_group = nxs_file.create_group('entry')
                entry_group.attrs['NX_class'] = 'NXentry'

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self, plugin_dict):
        out_data = self.index['out_data'].copy()
        for key, data in out_data.items():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.collection['datasets'].append(out_data)
        self.collection['plugin_dict'].append(plugin_dict)
        self.index["out_data"] = {}
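
    # Sketch of the bookkeeping performed above: after plugin N runs, each
    # out_data object whose .remove flag is False becomes an in_data object
    # for plugin N + 1, the full out_data dict is archived as
    # self.collection['datasets'][N], and index['out_data'] is emptied
    # ready for the next plugin.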

    def _finalise_experiment_for_current_plugin(self):
        finalise = {'remove': [], 'keep': []}
        # populate the nexus file with out_dataset information and determine
        # which datasets to remove from the framework.

        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in_datasets to replace
        finalise['replace'] = []
        for out_name in list(self.index['out_data'].keys()):
            if out_name in list(self.index['in_data'].keys()):
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets marked for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in list(self.index["in_data"].keys()):
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            communicator.barrier()
        self._barrier_count += 1
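
    # Note on _barrier: when MPI is enabled, every rank must reach the same
    # _barrier() call; the running _barrier_count makes the matching debug
    # lines easy to pair up across per-process log files.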

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level.
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())