Test Failed
Pull Request — master (#729)
by Nicola
Duration: 05:14 (queued: 01:44)

Experiment._update()   A

Complexity
  Conditions: 2

Size
  Total Lines: 7
  Code Lines: 6

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric   Value
cc       2
eloc     6
nop      2
dl       0
loc      7
rs       10
c        0
b        0
f        0
# -*- coding: utf-8 -*-
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: experiment
   :platform: Unix
   :synopsis: Contains information specific to the entire experiment.

.. moduleauthor:: Nicola Wadeson <[email protected]>
"""
from __future__ import print_function, division, absolute_import

import os
import copy
import h5py
import logging
from mpi4py import MPI

from savu.data.meta_data import MetaData
from savu.data.plugin_list import PluginList
from savu.data.data_structures.data import Data
from savu.core.checkpointing import Checkpointing
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
import savu.plugins.loaders.utils.yaml_utils as yaml


class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False

    def get(self, entry):
        """ Get an entry from the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)
        self.meta_data.set("nPlugin", 0)  # initialise

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        """
        if name not in list(self.index[dtype].keys()) or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]

    def _setup(self, transport):
        self._set_nxs_file()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _finalise_setup(self, plugin_list):
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1 and not checkpoint:
            # links the input data to the nexus file
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
            self._add_input_data_to_nxs_file(self._get_transport())
        self._set_dataset_names_complete()

    def _set_initial_datasets(self):
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _set_dataset_names_complete(self):
        """ Missing in/out_datasets fields have been populated
        """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets
        # clear out dataset dictionaries
        for data_dict in self.collection['datasets']:
            for data in data_dict.values():
                data.meta_data._set_dictionary({})

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_file(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                log_folder = open(log_folder_name, 'a')
                log_folder.write(os.path.abspath(filename) + '\n')
                log_folder.close()

        self._create_nxs_entry()

    def _create_nxs_entry(self):  # what if the file already exists?!
        logging.debug("Testing nexus file")
        import h5py
        if self.meta_data.get('process') == len(
                self.meta_data.get('processes')) - 1 and not self.checkpoint:
            with h5py.File(self.meta_data.get('nxs_filename'), 'w') as nxs_file:
                entry_group = nxs_file.create_group('entry')
                entry_group.attrs['NX_class'] = 'NXentry'

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self, plugin_dict):
        out_data = self.index['out_data']
        for key, data in out_data.items():
            if data.remove is False:
                out_data[key] = data
        self.collection['datasets'].append(out_data)
        self.collection['plugin_dict'].append(plugin_dict)
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {'remove': [], 'keep': []}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.

        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in list(self.index['out_data'].keys()):
            if out_name in list(self.index['in_data'].keys()):
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in list(self.index["in_data"].keys()):
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            comm_dict['comm'].barrier()
        self._barrier_count += 1

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
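
For reviewers unfamiliar with how this class is driven, the following is a minimal sketch (not part of this PR) of how an Experiment instance might be constructed and used. It assumes a working Savu installation; the import path, the option values, and the dataset name 'tomo' are illustrative only, since in practice the options dictionary is assembled by Savu's run scripts rather than written by hand.

# Hypothetical driver sketch -- illustrative values, not part of the PR.
from savu.data.experiment_collection import Experiment  # import path assumed

options = {
    "process_file": "/path/to/process_list.nxs",  # plugin chain (illustrative path)
    "system_params": None,        # None -> locate system_parameters.yml in the install
    "template": None,             # no parameter template applied
    "transport": "hdf5",          # read by create_data_object() via _set_transport_data()
    "out_path": "/tmp/savu_out",  # _set_nxs_file() writes <datafile_name>_processed.nxs here
    "datafile_name": "sample",
    "process": 0,                 # index of this process
    "processes": ["CPU0"],        # one entry per running process
    "mpi": False,                 # _barrier() only synchronises when this is True
    "checkpoint": None,
    "bllog": None,
}

exp = Experiment(options)                          # builds MetaData, Checkpointing, PluginList
tomo = exp.create_data_object("in_data", "tomo")   # how a loader would register a dataset
exp.log("after loader setup")                      # logs the shape of each in/out dataset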