Test Failed
Pull Request — master (#743), by Nicola, created 03:13

savu.plugins.plugin.Plugin.set_parameters()   A

Complexity:   Conditions 1
Size:         Total Lines 2, Code Lines 2
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Metric  Value
cc      1
eloc    2
nop     2
dl      0
loc     2
rs      10
c       0
b       0
f       0
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: plugin
   :platform: Unix
   :synopsis: Base class for all plugins used by Savu

.. moduleauthor:: Mark Basham <[email protected]>

"""

import ast
import copy
import logging
import inspect
import numpy as np

import savu.plugins.docstring_parser as doc
from savu.plugins.plugin_datasets import PluginDatasets


class Plugin(PluginDatasets):
    """
    The base class from which all plugins should inherit.
    :param in_datasets: Create a list of the dataset(s) to \
        process. Default: [].
    :param out_datasets: Create a list of the dataset(s) to \
        create. Default: [].
    """

    def __init__(self, name="Plugin"):
        super(Plugin, self).__init__(name)
        self.name = name
        self.parameters = {}
        self.parameters_types = {}
        self.parameters_desc = {}
        self.parameters_hide = []
        self.parameters_user = []
        self.chunk = False
        self.docstring_info = {}
        self.slice_list = None
        self.global_index = None
        self.pcount = 0
        self.exp = None
        self.check = False
        self.fixed_length = True
        self.parameters = {}
        self.tools = self._set_plugin_tools()

    def set_parameters(self, params):
        self.parameters = params

    def initialise(self, params, exp, check=False):
        self.check = check
        self.exp = exp
        self._set_parameters(copy.deepcopy(params))
        self._main_setup()

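Note: as a hedged illustration of the two entry points above, initialise() deep-copies the supplied parameters, validates them via _set_parameters() and then runs the full _main_setup(), whereas set_parameters() simply replaces the dictionary with no validation or setup. MyPlugin and exp below are hypothetical stand-ins for a concrete plugin subclass and a configured experiment object, not part of this module.

    # Hypothetical usage sketch; MyPlugin and exp are placeholders.
    params = {'in_datasets': [], 'out_datasets': []}

    plugin = MyPlugin()
    plugin.initialise(params, exp)   # validates params, then performs full setup
    plugin.set_parameters(params)    # direct assignment, no validation or setup
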
    def _main_setup(self):
        """ Performs all the required plugin setup.

        It sets the experiment, then the parameters and replaces the
        in/out_dataset strings in ``self.parameters`` with the relevant data
        objects. It then creates PluginData objects for each of these datasets.
        """
        self._set_plugin_datasets()
        self._reset_process_frames_counter()
        self.setup()
        self.set_filter_padding(*(self.get_plugin_datasets()))
        self._finalise_plugin_datasets()
        self._finalise_datasets()

    def _reset_process_frames_counter(self):
        self.pcount = 0

    def get_process_frames_counter(self):
        return self.pcount

    def _set_parameters_this_instance(self, indices):
        """ Determines the parameters for this instance of the plugin, in the
        case of parameter tuning.

        :param np.ndarray indices: the index of the current value in the
            parameter tuning list.
        """
        dims = set(self.multi_params_dict.keys())
        count = 0
        for dim in dims:
            info = self.multi_params_dict[dim]
            name = info['label'].split('_param')[0]
            self.parameters[name] = info['values'][indices[count]]
            count += 1

    def set_filter_padding(self, in_data, out_data):
        """
        Should be overridden to define how wide the frame should be for each
        input dataset.
        """
        return {}

    def setup(self):
        """
        This method is the first to be called after the plugin has been
        created. It determines input/output datasets and plugin specific
        dataset information such as the pattern (e.g. sinogram/projection).
        """
        logging.error("setup needs to be implemented")
        raise NotImplementedError("setup needs to be implemented")

    def _populate_default_parameters(self):
        """
        This method should populate all the required parameters with default
        values. It is used for checking to see if new parameter values are
        appropriate.

        It makes use of parameter information included in the class
        docstrings, such as:

        :param error_threshold: Convergence threshold. Default: 0.001.
        """
        hidden_items = []
        user_items = []
        params = []
        not_params = []
        for clazz in inspect.getmro(self.__class__)[::-1]:
            if clazz != object:
                desc = doc.find_args(clazz, self)
                self.docstring_info['warn'] = desc['warn']
                self.docstring_info['info'] = desc['info']
                self.docstring_info['synopsis'] = desc['synopsis']
                params.extend(desc['param'])
                if desc['hide_param']:
                    hidden_items.extend(desc['hide_param'])
                if desc['user_param']:
                    user_items.extend(desc['user_param'])
                if desc['not_param']:
                    not_params.extend(desc['not_param'])
        self._add_item(params, not_params)
        user_items = [u for u in user_items if u not in not_params]
        hidden_items = [h for h in hidden_items if h not in not_params]
        user_items = list(set(user_items).difference(set(hidden_items)))
        self.parameters_hide = hidden_items
        self.parameters_user = user_items
        self.final_parameter_updates()

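Note: a minimal sketch of the docstring-driven defaults described above, assuming a hypothetical subclass MyFilter; the error_threshold line is the exact form shown in the docstring, and docstring_parser.find_args() collects such entries from every class in the method resolution order.

    from savu.plugins.plugin import Plugin


    class MyFilter(Plugin):
        """
        A hypothetical filter plugin.

        :param error_threshold: Convergence threshold. Default: 0.001.
        :param iterations: Maximum number of iterations. Default: 10.
        """

        def __init__(self):
            super(MyFilter, self).__init__("MyFilter")

    # After initialise_parameters(), self.parameters would contain
    # 'error_threshold': 0.001 and 'iterations': 10, alongside the
    # in_datasets/out_datasets entries inherited from the Plugin docstring.
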
    def _add_item(self, item_list, not_list):
        true_list = [i for i in item_list if i['name'] not in not_list]
        for item in true_list:
            self.parameters[item['name']] = item['default']
            self.parameters_types[item['name']] = item['dtype']
            self.parameters_desc[item['name']] = item['desc']

    def delete_parameter_entry(self, param):
        if param in list(self.parameters.keys()):
            del self.parameters[param]
            del self.parameters_types[param]
            del self.parameters_desc[param]

    def initialise_parameters(self):
        self.parameters = {}
        self.parameters_types = {}
        self._populate_default_parameters()
        self.multi_params_dict = {}
        self.extra_dims = []

    def _set_parameters(self, parameters):
        """
        This method is called after the plugin has been created by the
        pipeline framework.  It replaces ``self.parameters`` default values
        with those given in the input process list.

        :param dict parameters: A dictionary of the parameters for this \
        plugin, or None if no customisation is required.
        """
        self.initialise_parameters()
        # reverse sorting added on Python 3 conversion to make the behaviour
        # similar (hopefully the same) as on Python 2
        for key in parameters.keys():
            if key in self.parameters.keys():
                value = self.__convert_multi_params(parameters[key], key)
                self.parameters[key] = value
            else:
                error = ("Parameter '%s' is not valid for plugin %s. \nTry "
                         "opening and re-saving the process list in the "
                         "configurator to auto remove \nobsolete parameters."
                         % (key, self.name))
                raise ValueError(error)

    def __convert_multi_params(self, value, key):
        """ Set up parameter tuning.

        Convert parameter value to a list if it uses parameter tuning and set
        associated parameters, so the framework knows the new size of the data
        and which plugins to re-run.
        """
        dtype = self.parameters_types[key]
        if isinstance(value, str) and ';' in value:
            value = value.split(';')
            if ":" in value[0]:
                seq = value[0].split(':')
                seq = [eval(s) for s in seq]
                value = list(np.arange(seq[0], seq[1], seq[2]))
                if len(value) == 0:
                    raise RuntimeError(
                        'No values for tuned parameter "{}", '
                        'ensure start:stop:step; values are valid.'.format(key))
            if not isinstance(value[0], dtype):
                try:
                    value.remove('')
                except Exception:
                    pass
                if isinstance(value[0], str):
                    value = [ast.literal_eval(i) for i in value]
                value = list(map(dtype, value))
            label = key + '_params.' + type(value[0]).__name__
            self.multi_params_dict[len(self.multi_params_dict)] = \
                {'label': label, 'values': value}
            self.extra_dims.append(len(value))
        return value

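Note: a short illustration of the tuning syntax handled above. A semicolon-separated string becomes a list of tuned values, and a start:stop:step; string is expanded with np.arange; the snippet below just reproduces that expansion outside the class.

    import numpy as np

    # "1.0;2.0;4.0" for a float parameter  ->  [1.0, 2.0, 4.0]
    # "1:10:2;" is expanded exactly as in the method above:
    seq = [eval(s) for s in "1:10:2".split(':')]        # [1, 10, 2]
    values = list(np.arange(seq[0], seq[1], seq[2]))    # [1, 3, 5, 7, 9]
    # Each tuned parameter adds an entry to multi_params_dict and appends
    # len(values) to extra_dims, giving the data an extra dimension.
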
    def get_parameters(self, name):
        """ Return a plugin parameter

        :param str name: parameter name (dictionary key)
        :returns: the associated value in ``self.parameters``
        :rtype: dict value
        """
        return self.parameters[name]

    def base_pre_process(self):
        """ This method is called after the plugin has been created by the
        pipeline framework as a pre-processing step.
        """
        pass

    def pre_process(self):
        """ This method is called immediately after base_pre_process(). """
        pass

    def base_process_frames_before(self, data):
        """ This method is called before each call to process_frames(). """
        return data

    def base_process_frames_after(self, data):
        """ This method is called directly after each call to \
        process_frames() and before returning the data to file."""
        return data

    def plugin_process_frames(self, data):
        frames = self.base_process_frames_after(self.process_frames(
                self.base_process_frames_before(data)))
        self.pcount += 1
        return frames

    def process_frames(self, data):
        """
        This method is called after the plugin has been created by the
        pipeline framework and forms the main processing step.

        :param data: A list of numpy arrays for each input dataset.
        :type data: list(np.array)
        """

        logging.error("process_frames needs to be implemented")
        raise NotImplementedError("process_frames needs to be implemented")

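Note: setup() and process_frames() are the two methods every concrete plugin must override (both raise NotImplementedError above). The sketch below is a hypothetical minimal plugin, not part of this module; the create_dataset() and plugin_data_setup() calls assume the dataset helpers commonly used by Savu plugins rather than anything defined in this file.

    import numpy as np

    from savu.plugins.plugin import Plugin


    class MedianSubtract(Plugin):
        """
        A hypothetical plugin that subtracts the median from each frame.

        :param offset: Value added back after the subtraction. Default: 0.0.
        """

        def __init__(self):
            super(MedianSubtract, self).__init__("MedianSubtract")

        def setup(self):
            # Assumed helpers: Data.create_dataset and PluginData.plugin_data_setup.
            in_dataset, out_dataset = self.get_datasets()
            out_dataset[0].create_dataset(in_dataset[0])
            in_pData, out_pData = self.get_plugin_datasets()
            in_pData[0].plugin_data_setup('PROJECTION', 'single')
            out_pData[0].plugin_data_setup('PROJECTION', 'single')

        def process_frames(self, data):
            frame = data[0]
            return frame - np.median(frame) + self.parameters['offset']
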
    def post_process(self):
        """
        This method is called after the process function in the pipeline
        framework as a post-processing step. All processes will have finished
        performing the main processing at this stage.
        """
        pass

    def base_post_process(self):
        """ This method is called immediately after post_process(). """
        pass

    def set_preview(self, data, params):
        if not params:
            return True
        preview = data.get_preview()
        orig_indices = preview.get_starts_stops_steps()
        nDims = len(orig_indices[0])
        no_preview = [[0]*nDims, data.get_shape(), [1]*nDims, [1]*nDims]

        # Set previewing params if previewing has not already been applied to
        # the dataset.
        if no_preview == orig_indices:
            data.get_preview().revert_shape = data.get_shape()
            data.get_preview().set_preview(params)
            return True
        return False

    def _clean_up(self):
        """ Perform necessary plugin clean up after the plugin has completed.
        """
        self._clone_datasets()
        self.__copy_meta_data()
        self.__set_previous_patterns()
        self.__clean_up_plugin_data()

    def __copy_meta_data(self):
        """
        Copy all metadata from input datasets to output datasets, except axis
        data that is no longer valid.
        """
        remove_keys = self.__remove_axis_data()
        in_meta_data, out_meta_data = self.get()
        copy_dict = {}
        for mData in in_meta_data:
            temp = copy.deepcopy(mData.get_dictionary())
            copy_dict.update(temp)

        for i in range(len(out_meta_data)):
            temp = copy_dict.copy()
            for key in remove_keys[i]:
                if temp.get(key, None) is not None:
                    del temp[key]
            temp.update(out_meta_data[i].get_dictionary())
            out_meta_data[i]._set_dictionary(temp)

    def __set_previous_patterns(self):
        for data in self.get_out_datasets():
            data._set_previous_pattern(
                copy.deepcopy(data._get_plugin_data().get_pattern()))

    def __remove_axis_data(self):
        """
        Returns a list of meta_data entries corresponding to axis labels that
        are not copied over to the output datasets.
        """
        in_datasets, out_datasets = self.get_datasets()
        all_in_labels = []
        for data in in_datasets:
            axis_keys = data.get_axis_label_keys()
            all_in_labels = all_in_labels + axis_keys

        remove_keys = []
        for data in out_datasets:
            axis_keys = data.get_axis_label_keys()
            remove_keys.append(set(all_in_labels).difference(set(axis_keys)))

        return remove_keys

    def __clean_up_plugin_data(self):
        """ Remove the PluginData object encapsulated in a dataset after
        plugin completion.
        """
        in_data, out_data = self.get_datasets()
        data_object_list = in_data + out_data
        for data in data_object_list:
            data._clear_plugin_data()

    def _revert_preview(self, in_data):
        """ Revert dataset back to original shape if previewing was used in a
        plugin to reduce the data shape but the original data shape should be
        used thereafter. Remove previewing if it was added in the plugin.
        """
        for data in in_data:
            if data.get_preview().revert_shape:
                data.get_preview()._unset_preview()

    def set_global_frame_index(self, frame_idx):
        self.global_index = frame_idx

    def get_global_frame_index(self):
        """ Get the global frame index. """
        return self.global_index

    def set_current_slice_list(self, sl):
        self.slice_list = sl

    def get_current_slice_list(self):
        """ Get the slice list of the current frame being processed. """
        return self.slice_list

    def get_slice_dir_reps(self, nData):
        """ Return the periodicity of the main slice direction.

        :param int nData: The number of the dataset in the list.
        """
        slice_dir = \
            self.get_plugin_in_datasets()[nData].get_slice_directions()[0]
        sl = [sl[slice_dir] for sl in self.slice_list]
        reps = [i for i in range(len(sl)) if sl[i] == sl[0]]
        return np.diff(reps)[0] if len(reps) > 1 else 1

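Note: a small worked example of the periodicity calculation in get_slice_dir_reps(); in the real method the entries of sl come from the current slice list, but plain integers illustrate the arithmetic.

    import numpy as np

    sl = [0, 1, 2, 0, 1, 2, 0, 1, 2]                      # illustrative values
    reps = [i for i in range(len(sl)) if sl[i] == sl[0]]  # [0, 3, 6]
    print(np.diff(reps)[0])                               # 3: repeats every third frame
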
    def nInput_datasets(self):
        """
        The number of datasets required as input to the plugin

        :returns:  Number of input datasets

        """
        return 1

    def nOutput_datasets(self):
        """
        The number of datasets created by the plugin

        :returns:  Number of output datasets

        """
        return 1

    def nClone_datasets(self):
        """ The number of output datasets that have a clone - i.e. they take\
        it in turns to be used as output in an iterative plugin.
        """
        return 0

    def nFrames(self):
        """ The number of frames to process during each call to process_frames.
        """
        return 'single'

    def final_parameter_updates(self):
        """ An opportunity to update the parameters after they have been set.
        """
        pass

    def get_citation_information(self):
        """
        Gets the Citation Information for a plugin

        :returns:  A populated savu.data.plugin_info.CitationInformation

        """
        return None

    def executive_summary(self):
        """ Provide a summary to the user for the result of the plugin.

        e.g.
         - Warning, the sample may have shifted during data collection
         - Filter operated normally

        :returns:  A list of string summaries
        """
        return ["Nothing to Report"]
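
Note: a hedged example of overriding executive_summary() in a subclass, following the warning wording suggested in the docstring above; the max_shift attribute is hypothetical and would have to be recorded by the plugin during processing.

    def executive_summary(self):
        # Hypothetical override: warn when a shift metric recorded during
        # processing exceeds an arbitrary threshold.
        if getattr(self, 'max_shift', 0) > 5:
            return ["Warning, the sample may have shifted during data collection"]
        return ["Filter operated normally"]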