Test Failed
Pull Request — master (#708)
by Daniil
03:14
created

savu.plugins.plugin.Plugin.post_process()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: plugin
17
   :platform: Unix
18
   :synopsis: Base class for all plugins used by Savu
19
20
.. moduleauthor:: Mark Basham <[email protected]>
21
22
"""
23
24
import ast
25
import copy
26
import logging
27
import inspect
28
import numpy as np
29
30
import savu.plugins.docstring_parser as doc
31
from savu.plugins.plugin_datasets import PluginDatasets
32
33
34
class Plugin(PluginDatasets):
    """
    The base class from which all plugins should inherit.
    :param in_datasets: Create a list of the dataset(s) to \
        process. Default: [].
    :param out_datasets: Create a list of the dataset(s) to \
        create. Default: [].
    """

    def __init__(self, name="Plugin"):
        super(Plugin, self).__init__(name)
        self.name = name
        # name -> value, name -> dtype and name -> description of parameters
        self.parameters = {}
        self.parameters_types = {}
        self.parameters_desc = {}
        # names of hidden / user-level parameters (from the docstrings)
        self.parameters_hide = []
        self.parameters_user = []
        self.chunk = False
        self.docstring_info = {}
        self.slice_list = None
        self.global_index = None
        # number of completed plugin_process_frames calls
        self.pcount = 0
        self.exp = None
        self.check = False

    def initialise(self, params, exp, check=False):
        """ Attach the experiment, apply ``params`` and run the plugin setup.

        :param dict params: parameter values for this plugin. The dictionary
            is deep-copied, so the caller's copy is never modified.
        :param exp: the experiment object.
        :param bool check: stored on ``self.check``.
        """
        self.check = check
        self.exp = exp
        self._set_parameters(copy.deepcopy(params))
        self._main_setup()

    def _main_setup(self):
        """ Perform the plugin setup that follows parameter initialisation.

        ``_set_plugin_datasets`` replaces the in/out_dataset strings in
        ``self.parameters`` with the relevant data objects and creates
        PluginData objects for them. The process-frames counter is then reset,
        the plugin-specific ``setup`` is called, filter padding is applied to
        the plugin datasets and finally the datasets are finalised.
        """
        self._set_plugin_datasets()
        self._reset_process_frames_counter()
        self.setup()
        self.set_filter_padding(*(self.get_plugin_datasets()))
        self._finalise_plugin_datasets()
        self._finalise_datasets()

    def _reset_process_frames_counter(self):
        """ Reset the number of processed frame chunks to zero. """
        self.pcount = 0

    def get_process_frames_counter(self):
        """ Return the number of completed ``plugin_process_frames`` calls. """
        return self.pcount

    def _set_parameters_this_instance(self, indices):
        """ Determine the parameters for this instance of the plugin, in the
        case of parameter tuning.

        :param np.ndarray indices: for each tuned parameter (in the order the
            tuned parameters were registered), the index of the current value
            in that parameter's tuning list.
        """
        # __convert_multi_params registers tuned parameters under the keys
        # 0, 1, 2, ...; iterate in that order so each key lines up with its
        # position in ``indices`` (a plain set gives no ordering guarantee).
        for count, dim in enumerate(sorted(self.multi_params_dict)):
            info = self.multi_params_dict[dim]
            # Labels look like '<name>_params.<type>'. Split on the LAST
            # '_param' so names that themselves contain '_param' survive.
            name = info['label'].rsplit('_param', 1)[0]
            self.parameters[name] = info['values'][indices[count]]

    def set_filter_padding(self, in_data, out_data):
        """
        Should be overridden to define how wide the frame should be for each
        input data set.
        """
        return {}

    def setup(self):
        """
        This method is first to be called after the plugin has been created.
        It determines input/output datasets and plugin specific dataset
        information such as the pattern (e.g. sinogram/projection).

        :raises NotImplementedError: always, unless overridden.
        """
        logging.error("set_up needs to be implemented")
        raise NotImplementedError("setup needs to be implemented")

    def _populate_default_parameters(self):
        """
        Populate all the required parameters with their default values. These
        are used to check whether new parameter values are appropriate.

        The parameter information is read from the docstring of every class
        in the plugin's method resolution order (base classes first), using
        entries such as this one:

        :param error_threshold: Convergence threshold. Default: 0.001.
        """
        hidden_items = []
        user_items = []
        params = []
        not_params = []
        # Walk from the most basic class to the most derived one, so entries
        # from derived classes are applied last and override base defaults.
        for clazz in inspect.getmro(self.__class__)[::-1]:
            if clazz is object:
                continue
            desc = doc.find_args(clazz, self)
            for info_key in ('warn', 'info', 'synopsis'):
                self.docstring_info[info_key] = desc[info_key]
            params.extend(desc['param'])
            hidden_items.extend(desc['hide_param'] or [])
            user_items.extend(desc['user_param'] or [])
            not_params.extend(desc['not_param'] or [])
        self._add_item(params, not_params)

        excluded = set(not_params)
        hidden_items = [h for h in hidden_items if h not in excluded]
        hidden = set(hidden_items)
        # De-duplicate user parameters while keeping a deterministic,
        # first-seen order (a set gave an arbitrary order).
        user_items = [u for u in dict.fromkeys(user_items)
                      if u not in excluded and u not in hidden]
        self.parameters_hide = hidden_items
        self.parameters_user = user_items
        self.final_parameter_updates()

    def _add_item(self, item_list, not_list):
        """ Register the default value, dtype and description of every
        parameter in ``item_list`` whose name is not listed in ``not_list``.
        """
        excluded = set(not_list)
        for item in item_list:
            name = item['name']
            if name in excluded:
                continue
            self.parameters[name] = item['default']
            self.parameters_types[name] = item['dtype']
            self.parameters_desc[name] = item['desc']

    def delete_parameter_entry(self, param):
        """ Remove ``param`` (if present) from the parameter value, type and
        description dictionaries.
        """
        if param in self.parameters:
            del self.parameters[param]
            self.parameters_types.pop(param, None)
            self.parameters_desc.pop(param, None)

    def initialise_parameters(self):
        """ Reset the parameter dictionaries, repopulate them with default
        values and clear any parameter-tuning state.
        """
        self.parameters = {}
        self.parameters_types = {}
        # Reset descriptions too, so entries left over from a previous call
        # do not linger for parameters that no longer exist.
        self.parameters_desc = {}
        self._populate_default_parameters()
        self.multi_params_dict = {}
        self.extra_dims = []

    def _set_parameters(self, parameters):
        """
        This method is called after the plugin has been created by the
        pipeline framework. It replaces ``self.parameters`` default values
        with those given in the input process list.

        :param dict parameters: A dictionary of the parameters for this
            plugin, or None if no customisation is required.
        :raises ValueError: if ``parameters`` contains a key that is not a
            valid parameter for this plugin.
        """
        self.initialise_parameters()
        # None means "no customisation": keep the defaults just populated.
        if parameters is None:
            return
        for key, value in parameters.items():
            if key not in self.parameters:
                error = ("Parameter '%s' is not valid for plugin %s. \nTry "
                         "opening and re-saving the process list in the "
                         "configurator to auto remove \nobsolete parameters."
                         % (key, self.name))
                raise ValueError(error)
            self.parameters[key] = self.__convert_multi_params(value, key)

    def __convert_multi_params(self, value, key):
        """ Set up parameter tuning.

        Convert parameter value to a list if it uses parameter tuning and set
        associated parameters, so the framework knows the new size of the data
        and which plugins to re-run.

        A tuned value is a string containing ';': either a list of values
        (e.g. ``'1;2;3'``) or a range whose first entry is
        ``start:stop:step``. Any other value is returned unchanged.

        :param value: the value supplied for the parameter.
        :param str key: the parameter name.
        :returns: the parameter value, converted to a list if it is tuned.
        :raises RuntimeError: if a tuned parameter yields no values.
        """
        if not (isinstance(value, str) and ';' in value):
            return value

        dtype = self.parameters_types[key]
        # Drop empty entries, e.g. from the trailing ';' used to mark a
        # tuned parameter or from an accidental ';;'.
        entries = [v for v in value.split(';') if v.strip()]
        if not entries:
            raise RuntimeError(
                'No values for tuned parameter "{}".'.format(key))

        if ':' in entries[0]:
            # NOTE(review): eval() runs arbitrary expressions taken from the
            # process list. ast.literal_eval would be safer, but it would
            # reject expression-valued bounds that users may rely on.
            seq = [eval(s) for s in entries[0].split(':')]
            values = list(np.arange(seq[0], seq[1], seq[2]))
            if not values:
                raise RuntimeError(
                    'No values for tuned parameter "{}", '
                    'ensure start:stop:step; values are valid.'.format(key))
        else:
            values = entries

        if not isinstance(values[0], dtype):
            if isinstance(values[0], str):
                values = [ast.literal_eval(v) for v in values]
            values = list(map(dtype, values))

        label = key + '_params.' + type(values[0]).__name__
        self.multi_params_dict[len(self.multi_params_dict)] = \
            {'label': label, 'values': values}
        self.extra_dims.append(len(values))
        return values

    def get_parameters(self, name):
        """ Return a plugin parameter.

        :param str name: parameter name (dictionary key)
        :returns: the associated value in ``self.parameters``
        :raises KeyError: if ``name`` is not a parameter of this plugin.
        """
        return self.parameters[name]

    def base_pre_process(self):
        """ This method is called after the plugin has been created by the
        pipeline framework as a pre-processing step.
        """
        pass

    def pre_process(self):
        """ This method is called immediately after base_pre_process(). """
        pass

    def base_process_frames_before(self, data):
        """ This method is called before each call to process frames. """
        return data

    def base_process_frames_after(self, data):
        """ This method is called directly after each call to process frames
        and before returning the data to file.
        """
        return data

    def plugin_process_frames(self, data):
        """ Run ``process_frames`` on ``data`` between the before/after hooks
        and increment the process-frames counter.

        :returns: the value returned by ``base_process_frames_after``.
        """
        frames = self.base_process_frames_after(
            self.process_frames(self.base_process_frames_before(data)))
        self.pcount += 1
        return frames

    def process_frames(self, data):
        """
        This method is called after the plugin has been created by the
        pipeline framework and forms the main processing step.

        :param data: A list of numpy arrays for each input dataset.
        :type data: list(np.array)
        :raises NotImplementedError: always, unless overridden.
        """
        logging.error("process frames needs to be implemented")
        raise NotImplementedError("process needs to be implemented")

    def post_process(self):
        """
        This method is called after the process function in the pipeline
        framework as a post-processing step. All processes will have finished
        performing the main processing at this stage.
        """
        pass

    def base_post_process(self):
        """ This method is called immediately after post_process(). """
        pass

    def set_preview(self, data, params):
        """ Apply the previewing ``params`` to ``data`` unless previewing has
        already been applied to it.

        :param data: the dataset to preview.
        :param params: previewing parameters; an empty value means none.
        :returns: True if ``params`` is empty or was applied, False if the
            dataset was already previewed (``params`` is then ignored).
        """
        if not params:
            return True
        preview = data.get_preview()
        orig_indices = preview.get_starts_stops_steps()
        n_dims = len(orig_indices[0])
        no_preview = [[0]*n_dims, data.get_shape(), [1]*n_dims, [1]*n_dims]

        # Set previewing params if previewing has not already been applied to
        # the dataset.
        if no_preview == orig_indices:
            data.get_preview().revert_shape = data.get_shape()
            data.get_preview().set_preview(params)
            return True
        return False

    def _clean_up(self):
        """ Perform necessary plugin clean up after the plugin has completed.
        """
        self._clone_datasets()
        self.__copy_meta_data()
        self.__set_previous_patterns()
        self.__clean_up_plugin_data()

    def __copy_meta_data(self):
        """
        Copy all metadata from input datasets to output datasets, except axis
        data that is no longer valid.
        """
        remove_keys = self.__remove_axis_data()
        in_meta_data, out_meta_data = self.get()
        # Later input datasets overwrite entries from earlier ones.
        copy_dict = {}
        for m_data in in_meta_data:
            copy_dict.update(copy.deepcopy(m_data.get_dictionary()))

        for i, out_md in enumerate(out_meta_data):
            temp = copy_dict.copy()
            for key in remove_keys[i]:
                # Only entries that hold a non-None value are removed.
                if temp.get(key) is not None:
                    del temp[key]
            # Metadata already on the output dataset takes precedence.
            temp.update(out_md.get_dictionary())
            out_md._set_dictionary(temp)

    def __set_previous_patterns(self):
        """ Record each output dataset's current plugin pattern as its
        previous pattern.
        """
        for data in self.get_out_datasets():
            data._set_previous_pattern(
                copy.deepcopy(data._get_plugin_data().get_pattern()))

    def __remove_axis_data(self):
        """
        Returns a list (one set per output dataset) of meta_data entries
        corresponding to axis labels that are not copied over to the output
        datasets.
        """
        in_datasets, out_datasets = self.get_datasets()
        in_labels = set()
        for data in in_datasets:
            in_labels.update(data.get_axis_label_keys())
        return [in_labels.difference(data.get_axis_label_keys())
                for data in out_datasets]

    def __clean_up_plugin_data(self):
        """ Remove pluginData object encapsulated in a dataset after plugin
        completion.
        """
        in_data, out_data = self.get_datasets()
        for data in in_data + out_data:
            data._clear_plugin_data()

    def _revert_preview(self, in_data):
        """ Revert dataset back to original shape if previewing was used in a
        plugin to reduce the data shape but the original data shape should be
        used thereafter. Remove previewing if it was added in the plugin.
        """
        for data in in_data:
            if data.get_preview().revert_shape:
                data.get_preview()._unset_preview()

    def set_global_frame_index(self, frame_idx):
        """ Set the global frame index. """
        self.global_index = frame_idx

    def get_global_frame_index(self):
        """ Get the global frame index. """
        return self.global_index

    def set_current_slice_list(self, sl):
        """ Set the slice list of the current frame being processed. """
        self.slice_list = sl

    def get_current_slice_list(self):
        """ Get the slice list of the current frame being processed. """
        return self.slice_list

    def get_slice_dir_reps(self, nData):
        """ Return the periodicity of the main slice direction.

        This is the distance, within the current slice list, between the
        first two entries whose main-slice-direction slice equals that of the
        first entry, or 1 if that slice occurs only once.

        :param int nData: The number of the dataset in the list.
        """
        slice_dir = \
            self.get_plugin_in_datasets()[nData].get_slice_directions()[0]
        dir_slices = [sl[slice_dir] for sl in self.slice_list]
        reps = [i for i, s in enumerate(dir_slices) if s == dir_slices[0]]
        return reps[1] - reps[0] if len(reps) > 1 else 1

    def nInput_datasets(self):
        """
        The number of datasets required as input to the plugin.

        :returns:  Number of input datasets
        """
        return 1

    def nOutput_datasets(self):
        """
        The number of datasets created by the plugin.

        :returns:  Number of output datasets
        """
        return 1

    def nClone_datasets(self):
        """ The number of output datasets that have a clone - i.e. they take
        it in turns to be used as output in an iterative plugin.
        """
        return 0

    def nFrames(self):
        """ The number of frames to process during each call to process_frames.
        """
        return 'single'

    def final_parameter_updates(self):
        """ An opportunity to update the parameters after they have been set.
        """
        pass

    def get_citation_information(self):
        """
        Gets the Citation Information for a plugin.

        :returns:  A populated savu.data.plugin_info.CitationInformation, or
            None (the base class provides no citation).
        """
        return None

    def executive_summary(self):
        """ Provide a summary to the user for the result of the plugin.

        e.g.
         - Warning, the sample may have shifted during data collection
         - Filter operated normally

        :returns:  A list of string summaries
        """
        return ["Nothing to Report"]
452