YamlConverter._check_for_inheritance()   A
last analyzed

Complexity

Conditions 5

Size

Total Lines 14
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 14
nop 4
dl 0
loc 14
rs 9.2333
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: yaml_converter
17
   :platform: Unix
18
   :synopsis: 'A class to load data from a non-standard nexus/hdf5 file using \
19
               descriptions loaded from a yaml file.'
20
21
.. moduleauthor:: Nicola Wadeson <[email protected]>
22
23
"""
24
25
import os
26
import h5py
27
import yaml
28
import copy
29
import logging
30
import collections.abc as collections
31
import numpy as np  # used in exec so do not delete
32
from ast import literal_eval
33
34
import savu.plugins.utils as pu
35
import savu.plugins.loaders.utils.yaml_utils as yu
36
from savu.plugins.loaders.base_loader import BaseLoader
37
from savu.data.experiment_collection import Experiment
38
39
40
class YamlConverter(BaseLoader):
    """Base loader that builds Savu data objects from a YAML description.

    The YAML file named by the ``yaml_file`` parameter maps entry names to
    dictionaries describing each dataset (``params``, ``data``,
    ``axis_labels``, ``patterns``, ``metadata``, ``exp_metadata``).
    Top-level ``inherit``, ``import`` and ``all`` keys are directives rather
    than entries. Subclasses must implement :meth:`set_data`.

    NOTE(review): string values containing ``$`` are executed with ``eval``
    (see :meth:`update_value`) and ``import`` entries are imported into this
    module's namespace, so YAML files must come from a trusted source.
    """

    def __init__(self, name='YamlConverter'):
        super(YamlConverter, self).__init__(name)

    def setup(self, template=False, metadata=True):
        """Read the YAML description and create the described data objects.

        :param bool template: if True, return the resolved YAML dictionary
            (after inheritance and imports) without creating data objects.
        :param bool metadata: not used by this implementation.
        :returns: the resolved dictionary when ``template`` is True, else None.
        """
        #  Read YAML file
        yfile = self.parameters['yaml_file']
        data_dict = yu.read_yaml(self._get_yaml_file(yfile))
        data_dict = self._check_for_inheritance(data_dict, {})
        self._check_for_imports(data_dict)
        # these are directives, not data entries
        data_dict.pop('inherit', None)
        data_dict.pop('import', None)
        if template:
            return data_dict

        data_dict = self._add_template_updates(data_dict)
        self._set_entries(data_dict)

    def _get_yaml_file(self, yaml_file):
        """Resolve ``yaml_file`` to a path that exists.

        Candidates are tried in order: the path made absolute; the part after
        ``'Savu/'`` joined onto the directory three levels above this module;
        the path joined onto this module's directory.

        :raises Exception: if ``yaml_file`` is None or no candidate exists.
        """
        if yaml_file is None:
            raise Exception('Please pass a yaml file to the yaml loader.')

        # try the absolute path
        yaml_abs = os.path.abspath(yaml_file)
        if os.path.exists(yaml_abs):
            return yaml_abs

        # try adding the path to savu
        if 'Savu/' in yaml_file:
            yaml_savu = os.path.join(os.path.dirname(__file__), "../../../",
                                     yaml_file.split('Savu/')[1])
            if os.path.exists(yaml_savu):
                return yaml_savu

        # try adding the path to the templates folder
        yaml_templ = os.path.join(os.path.dirname(__file__), yaml_file)
        if os.path.exists(yaml_templ):
            return yaml_templ

        raise Exception('The yaml file does not exist %s' % yaml_file)

    def _add_template_updates(self, ddict):
        """Apply the ``all`` section and ``template_param`` overrides.

        Each key/value in the optional top-level ``all`` mapping overwrites
        that key in every entry which already defines it. Then, for each
        entry named in ``self.parameters['template_param']``, the entry's
        ``params`` dictionary is updated with the supplied values.

        :returns: ``ddict`` (mutated in place, with ``all`` removed).
        """
        all_entries = ddict.pop('all', None) or {}
        # iterate (key, value) pairs and the entry dicts themselves - plain
        # iteration over a dict yields only its keys
        for key, value in all_entries.items():
            for entry in ddict.values():
                if isinstance(entry, collections.Mapping) and key in entry:
                    entry[key] = value

        for entry, updates in self.parameters['template_param'].items():
            ddict[entry].setdefault('params', {}).update(updates)
        return ddict

    def _check_for_imports(self, ddict):
        """Import the modules listed under the YAML ``import`` key.

        Each item is either ``'module'`` or ``'module as alias'``. The module
        is bound in this module's global namespace, which is the namespace
        :meth:`update_value` passes to ``eval``.

        :raises ValueError: for an item that is not in one of those forms.
        """
        import importlib

        for imp in ddict.get('import') or []:
            parts = imp.split()
            if len(parts) == 1:
                # like ``import a.b``: __import__ returns the top-level package
                mod = __import__(parts[0])
                globals()[mod.__name__] = mod
            elif len(parts) == 3 and parts[1] == 'as':
                # like ``import a.b as c``: c is bound to the submodule a.b
                globals()[parts[2]] = importlib.import_module(parts[0])
            else:
                raise ValueError("Invalid import entry in the yaml file: "
                                 "'%s'" % imp)

    def _check_for_inheritance(self, ddict, inherit, override=False):
        """Merge ``ddict`` on top of the YAML files it inherits from.

        ``ddict['inherit']`` may be one file name or a list of them; the
        string ``'None'`` is skipped. Each inherited file is read, has its
        top-level keys renamed according to ``override`` (see
        :meth:`__override`), is merged into ``inherit`` and is then
        processed recursively for its own ``inherit`` entries. Finally
        ``ddict`` is deep-merged into ``inherit``, so values from the
        inheriting file win.

        :returns: the merged dictionary (``inherit``, mutated in place).
        """
        if 'inherit' in ddict:
            idict = ddict['inherit']
            idict = idict if isinstance(idict, list) else [idict]
            for i in idict:
                if i != 'None':
                    new_dict = yu.read_yaml(self._get_yaml_file(i))
                    new_dict, isoverride = \
                        self.__override(inherit, new_dict, override)
                    inherit.update(new_dict)
                    inherit = self._check_for_inheritance(
                            new_dict, inherit, override=isoverride)
        self._update(inherit, ddict)
        return inherit

    def __override(self, inherit, ddict, override):
        """Rename top-level keys of a freshly inherited dictionary.

        :param dict inherit: dictionary accumulated so far.
        :param dict ddict: the inherited dictionary (mutated in place).
        :param override: mapping of old key -> new key, or a falsy value.
        :returns: ``(ddict, isoverride)``; ``isoverride`` is the ``override``
            mapping popped from ``ddict`` (False if absent), which the caller
            applies to the files that ``ddict`` itself inherits.
        """
        isoverride = False
        if 'override' in ddict:
            isoverride = ddict.pop('override')
        if override:
            for old, new in override.items():
                ddict[new] = ddict.pop(old)
                if new in inherit:
                    # values already accumulated in ``inherit`` take precedence
                    self._update(ddict[new], inherit[new])
        return ddict, isoverride

    def _update(self, d, u):
        """Recursively merge mapping ``u`` into ``d`` (``u`` wins); return ``d``."""
        for k, v in u.items():
            if isinstance(v, collections.Mapping):
                d[k] = self._update(d.get(k, {}), v)
            else:
                d[k] = v
        return d

    def _set_entries(self, ddict):
        """Create a data object for every entry in ``ddict``."""
        for name in list(ddict):
            self.get_description(ddict[name], name)

    def get_description(self, entry, name, metadata=True):
        """Create the data object ``name`` described by ``entry``.

        :param dict entry: the YAML description of one dataset.
        :param str name: name of the data object to create.
        :param bool metadata: if True, also set axis labels, patterns and
            metadata.
        :raises Exception: if ``entry`` has no ``data`` section.
        """
        # set params first as we may need them subsequently
        if 'params' in entry:
            self._set_params(entry['params'])

        # --------------- check for data entry -----------------------------
        if 'data' not in entry:
            emsg = 'Please specify the data information in the yaml file.'
            raise Exception(emsg)
        data_obj = self.exp.create_data_object("in_data", name)
        data_obj = self.set_data(data_obj, entry['data'])

        if metadata:
            self._get_meta_data_descriptions(entry, data_obj)

    def _get_meta_data_descriptions(self, entry, data_obj):
        """Set axis labels, patterns and optional metadata on ``data_obj``.

        :raises Exception: if ``axis_labels`` or ``patterns`` is missing.
        """
        # --------------- check for axis label information -----------------
        if 'axis_labels' in entry:
            self._set_axis_labels(data_obj, entry['axis_labels'])
        else:
            raise Exception('Please specify the axis labels in the yaml file.')

        # --------------- check for data access patterns -------------------
        if 'patterns' in entry:
            self._set_patterns(data_obj, entry['patterns'])
        else:
            raise Exception('Please specify the patterns in the yaml file.')

        # add any additional metadata
        if 'metadata' in entry:
            self._set_metadata(data_obj, entry['metadata'])
        self.set_data_reduction_params(data_obj)

        if 'exp_metadata' in entry:
            self._set_metadata(data_obj, entry['exp_metadata'], exp=True)

    def set_data(self, name, entry):
        """Populate a data object from its ``data`` entry; must be overridden.

        Called by :meth:`get_description` with the newly created data object
        as ``name`` and the ``data`` section as ``entry``; the return value
        replaces the data object.
        """
        raise NotImplementedError('Please implement "set_data" function'
                                  ' in the loader')

    def _set_keywords(self, dObj):
        """Return the backing file name and shape of ``dObj`` as strings."""
        filepath = str(dObj.backing_file.filename)
        shape = str(dObj.get_shape())
        return {'dfile': filepath, 'dshape': shape}

    def __get_wildcard_values(self, dObj):
        """Return ``dObj``'s ``wildcard_values`` data_info entry, or None."""
        if 'wildcard_values' in dObj.data_info.get_dictionary():
            return dObj.data_info.get('wildcard_values')
        return None

    def update_value(self, dObj, value, itr=0):
        """Evaluate a ``$``-prefixed string value from the YAML description.

        Non-strings, and strings without ``$``, are returned unchanged.
        Otherwise the text after the first ``$`` has parameter names
        substituted (see :meth:`_convert_string`) and is evaluated with
        ``eval`` using this module's globals and this method's locals; bytes
        results are converted to ``str``. When ``dObj`` is given, the names
        ``dshape``, ``dfile`` (also stored as a module global) and
        ``wildcard`` are available to the expression.

        NOTE(review): ``eval`` runs arbitrary code from the YAML file.

        :param itr: not used.
        :raises Exception: if the expression cannot be evaluated.
        """
        # setting the keywords - these locals are visible to eval() below
        if dObj is not None:
            dshape = dObj.get_shape()
            dfile = dObj.backing_file
            globals()['dfile'] = dfile
            wildcard = self.__get_wildcard_values(dObj)

        if isinstance(value, str):
            split = value.split('$')
            if len(split) > 1:
                value = self._convert_string(dObj, split[1])
                try:
                    value = eval(value, globals(), locals())
                    value = self._convert_bytes(value)
                except Exception as e:
                    # pure f-string: a '%' inside value must not be parsed
                    # as a format directive
                    msg = f"Error evaluating value: '{value}' \n {e}"
                    try:
                        # retry with the argument of index() as a bytes literal
                        value = value.replace("index(", "index(b")
                        value = eval(value, globals(), locals())
                        value = self._convert_bytes(value)
                    except Exception:
                        raise Exception(msg) from e
        return value

    def _convert_string(self, dObj, string):
        """Substitute plugin parameter names that appear in ``string``.

        Each parameter name found in ``string`` is replaced by its value: a
        ``$``-prefixed string value contributes the text after the ``$``, any
        other string is single-quoted, and non-strings go through ``str()``.
        Substitution is applied recursively to the result.

        NOTE(review): names are matched as plain substrings, and a value that
        contains its own parameter name recurses without end.
        """
        for old, new in self.parameters.items():
            if old in string:
                if isinstance(new, str):
                    split = new.split('$')
                    # expression after '$', otherwise quote as a literal
                    new = split[1] if len(split) > 1 else "'%s'" % new
                string = self._convert_string(
                        dObj, string.replace(old, str(new)))
        return string

    def _convert_bytes(self, value):
        """Convert bytes, or a numpy array of bytes, to ``str``.

        Kept for backward compatibility. Other values are returned unchanged.
        """
        if isinstance(value, bytes):
            return value.decode("ascii")
        if isinstance(value, np.ndarray):
            # fixed-width bytes dtype, of any shape (including empty / 0-d)
            if value.dtype.kind == 'S':
                return value.astype(str)
            # e.g. object arrays holding bytes; guard against empty arrays
            if value.size and isinstance(value.flat[0], bytes):
                return value.astype(str)
        return value

    def _set_params(self, params):
        """Merge YAML ``params`` into the plugin parameters.

        Template placeholders are resolved first. Parameters whose name ends
        with ``'file'`` are evaluated with :meth:`update_value`, bound as
        module globals of the same name and removed from ``params``.
        """
        # Update variable parameters that are revealed in the template
        params = self._update_template_params(params)
        self.parameters.update(params)
        # find files, open and add to the namespace then delete file params
        for f in [k for k in params if k.endswith('file')]:
            param = params[f]
            try:
                globals()[str(f)] = self.update_value(None, param)
            except IOError:
                self._check_for_test_data(f, param)
            del params[f]

    def _check_for_test_data(self, f, param):
        """Retry evaluating file parameter ``f`` with its path resolved.

        The first single-quoted substring of ``param`` is resolved with
        :meth:`_get_yaml_file` and substituted back into ``param``, which is
        then evaluated and bound as a module global named ``f``; ``f`` is
        removed from the plugin parameters.

        :raises IOError: if no quoted path in ``param`` can be resolved.
        """
        # check if this is Savu test data
        found = None
        for candidate in param.split("'")[1:2]:
            try:
                found = (candidate, self._get_yaml_file(candidate))
                break
            except Exception:
                continue
        if found is None:
            raise IOError("Unable to find the file referenced by parameter "
                          "'%s': %s" % (f, param))
        old_path, new_path = found
        param = param.replace(old_path, new_path)
        globals()[str(f)] = self.update_value(None, param)
        del self.parameters[f]

    def _update_template_params(self, params):
        """Replace template placeholders in ``params``.

        For each value recognised by ``pu.is_template_param``, the current
        plugin parameter of the same name is used if present; otherwise
        element 1 of the ``is_template_param`` result is used.
        """
        for k, v in params.items():
            template = pu.is_template_param(v)
            if template is not False:
                params[k] = \
                    self.parameters[k] if k in self.parameters else template[1]
        return params

    def _set_axis_labels(self, dObj, labels):
        """Set axis labels (and any per-axis values) on ``dObj``.

        ``labels`` is indexed by 0..len(labels)-1. Each label must provide
        ``dim``, ``name``, ``value`` and ``units``, and each field may be a
        ``$`` expression. A non-None ``value`` is stored in
        ``dObj.meta_data`` under the label's name.
        """
        axis_labels = [None] * len(labels)
        for d in range(len(labels)):
            label = labels[d]
            self._check_label_entry(label)
            for key in list(label):
                label[key] = self.update_value(dObj, label[key])
            axis_labels[label['dim']] = label['name'] + '.' + label['units']
            if label['value'] is not None:
                dObj.meta_data.set(label['name'], label['value'])
        dObj.set_axis_labels(*axis_labels)

    def _check_label_entry(self, label):
        """Raise if ``label`` lacks any required axis-label field."""
        required = ['dim', 'name', 'value', 'units']
        if not isinstance(label, collections.Mapping) or \
                any(key not in label for key in required):
            raise Exception("dim, name, value and units are required fields "
                            "for axis labels")

    def _set_patterns(self, dObj, patterns):
        """Add each access pattern (``core_dims``/``slice_dims``) to ``dObj``."""
        for key, dims in patterns.items():
            core_dims = self.__get_tuple(
                    self.update_value(dObj, dims['core_dims']))
            slice_dims = self.__get_tuple(
                    self.update_value(dObj, dims['slice_dims']))
            dObj.add_pattern(key, core_dims=core_dims, slice_dims=slice_dims)

    def __get_tuple(self, val):
        """Return ``val`` if it is a tuple, else ``literal_eval(val)``."""
        return val if isinstance(val, tuple) else literal_eval(val)

    def _set_metadata(self, dObj, mdata, exp=False):
        """Store each ``mdata[key]['value']`` (evaluated) as metadata.

        :param bool exp: if True, populate ``dObj.exp`` instead of ``dObj``.
        """
        populate = dObj.exp if exp else dObj
        for key, value in mdata.items():
            populate.meta_data.set(key, self.update_value(dObj, value['value']))
312