# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: plugin
   :platform: Unix
   :synopsis: Base class for all plugins used by Savu

.. moduleauthor:: Mark Basham <[email protected]>

"""

import ast
import copy
import inspect
import logging

import numpy as np

import savu.plugins.docstring_parser as doc
from savu.plugins.plugin_datasets import PluginDatasets


class Plugin(PluginDatasets):
    """
    The base class from which all plugins should inherit.
    :param in_datasets: Create a list of the dataset(s) to \
        process. Default: [].
    :param out_datasets: Create a list of the dataset(s) to \
        create. Default: [].
    """

    def __init__(self, name="Plugin"):
        super(Plugin, self).__init__(name)
        self.name = name
        self.parameters = {}
        self.parameters_types = {}
        self.parameters_desc = {}
        self.parameters_hide = []
        self.parameters_user = []
        self.chunk = False
        self.docstring_info = {}
        self.slice_list = None
        self.global_index = None
        self.pcount = 0
        self.exp = None
        self.check = False

    def initialise(self, params, exp, check=False):
        self.check = check
        self.exp = exp
        self._set_parameters(copy.deepcopy(params))
        self._main_setup()

    def _main_setup(self):
        """ Performs all the required plugin setup.

        It sets the experiment, then the parameters and replaces the
        in/out_dataset strings in ``self.parameters`` with the relevant data
        objects. It then creates PluginData objects for each of these datasets.
        """
        self._set_plugin_datasets()
        self._reset_process_frames_counter()
        self.setup()
        self.set_filter_padding(*(self.get_plugin_datasets()))
        self._finalise_plugin_datasets()
        self._finalise_datasets()

    def _reset_process_frames_counter(self):
        self.pcount = 0

    def get_process_frames_counter(self):
        return self.pcount

    def _set_parameters_this_instance(self, indices):
        """ Determines the parameters for this instance of the plugin, in the
        case of parameter tuning.

        :param np.ndarray indices: the index of the current value in the
            parameter tuning list.
        """
        dims = set(self.multi_params_dict.keys())
        count = 0
        for dim in dims:
            info = self.multi_params_dict[dim]
            name = info['label'].split('_param')[0]
            self.parameters[name] = info['values'][indices[count]]
            count += 1

    def set_filter_padding(self, in_data, out_data):
        """
        Should be overridden to define how much padding is required around
        each frame for each input dataset.
        """
        return {}
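    # Illustrative sketch only: a filter that needs padded frames might
    # override this method along the lines of the snippet below. The
    # 'padding' parameter and the 'pad_multi_frames' key are assumptions
    # modelled on existing filter plugins; check the PluginData padding
    # API for the keys it actually accepts.
    #
    #     def set_filter_padding(self, in_data, out_data):
    #         padding = self.parameters['padding']
    #         in_data[0].padding = {'pad_multi_frames': padding}
    #         out_data[0].padding = {'pad_multi_frames': padding}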

    def setup(self):
        """
        This method is the first to be called after the plugin has been
        created. It determines the input/output datasets and plugin-specific
        dataset information such as the pattern (e.g. sinogram/projection).
        """
        logging.error("setup needs to be implemented")
        raise NotImplementedError("setup needs to be implemented")

    def _populate_default_parameters(self):
        """
        This method should populate all the required parameters with default
        values. It is used for checking to see if new parameter values are
        appropriate.

        It makes use of the parameter information included in the class
        docstrings, such as this:

        :param error_threshold: Convergence threshold. Default: 0.001.
        """
        hidden_items = []
        user_items = []
        params = []
        not_params = []
        for clazz in inspect.getmro(self.__class__)[::-1]:
            if clazz != object:
                desc = doc.find_args(clazz, self)
                self.docstring_info['warn'] = desc['warn']
                self.docstring_info['info'] = desc['info']
                self.docstring_info['synopsis'] = desc['synopsis']
                params.extend(desc['param'])
                if desc['hide_param']:
                    hidden_items.extend(desc['hide_param'])
                if desc['user_param']:
                    user_items.extend(desc['user_param'])
                if desc['not_param']:
                    not_params.extend(desc['not_param'])
        self._add_item(params, not_params)
        user_items = [u for u in user_items if u not in not_params]
        hidden_items = [h for h in hidden_items if h not in not_params]
        user_items = list(set(user_items).difference(set(hidden_items)))
        self.parameters_hide = hidden_items
        self.parameters_user = user_items
        self.final_parameter_updates()

    def _add_item(self, item_list, not_list):
        true_list = [i for i in item_list if i['name'] not in not_list]
        for item in true_list:
            self.parameters[item['name']] = item['default']
            self.parameters_types[item['name']] = item['dtype']
            self.parameters_desc[item['name']] = item['desc']

    def delete_parameter_entry(self, param):
        if param in list(self.parameters.keys()):
            del self.parameters[param]
            del self.parameters_types[param]
            del self.parameters_desc[param]

    def initialise_parameters(self):
        self.parameters = {}
        self.parameters_types = {}
        self._populate_default_parameters()
        self.multi_params_dict = {}
        self.extra_dims = []

    def _set_parameters(self, parameters):
        """
        This method is called after the plugin has been created by the
        pipeline framework. It replaces ``self.parameters`` default values
        with those given in the input process list.

        :param dict parameters: A dictionary of the parameters for this \
            plugin, or None if no customisation is required.
        """
        self.initialise_parameters()
        # reverse sorting added on Python 3 conversion to make the behaviour
        # similar (hopefully the same) as on Python 2
        for key in parameters.keys():
            if key in self.parameters.keys():
                value = self.__convert_multi_params(parameters[key], key)
                self.parameters[key] = value
            else:
                error = ("Parameter '%s' is not valid for plugin %s. \nTry "
                         "opening and re-saving the process list in the "
                         "configurator to auto remove \nobsolete parameters."
                         % (key, self.name))
                raise ValueError(error)

    def __convert_multi_params(self, value, key):
        """ Set up parameter tuning.

        Convert parameter value to a list if it uses parameter tuning and set
        associated parameters, so the framework knows the new size of the data
        and which plugins to re-run.
        """
        dtype = self.parameters_types[key]
        if isinstance(value, str) and ';' in value:
            value = value.split(';')
            if ":" in value[0]:
                seq = value[0].split(':')
                seq = [eval(s) for s in seq]
                value = list(np.arange(seq[0], seq[1], seq[2]))
                if len(value) == 0:
                    raise RuntimeError(
                        'No values for tuned parameter "{}", '
                        'ensure start:stop:step; values are valid.'.format(key))
            if not isinstance(value[0], dtype):
                try:
                    value.remove('')
                except Exception:
                    pass
                if isinstance(value[0], str):
                    value = [ast.literal_eval(i) for i in value]
                value = list(map(dtype, value))
            label = key + '_params.' + type(value[0]).__name__
            self.multi_params_dict[len(self.multi_params_dict)] = \
                {'label': label, 'values': value}
            self.extra_dims.append(len(value))
        return value
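    # Worked examples of the tuning syntax handled above (derived directly
    # from the parsing logic in this method):
    #
    #     "0.01;0.05;0.1"  ->  [0.01, 0.05, 0.1]   three tuned values
    #     "1:10:2;"        ->  [1, 3, 5, 7, 9]     expanded via np.arange
    #
    # Each tuned parameter adds an extra dimension of the corresponding
    # length to the processed data, recorded in self.extra_dims.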

    def get_parameters(self, name):
        """ Return a plugin parameter.

        :param str name: parameter name (dictionary key)
        :returns: the associated value in ``self.parameters``
        :rtype: dict value
        """
        return self.parameters[name]

    def base_pre_process(self):
        """ This method is called after the plugin has been created by the
        pipeline framework as a pre-processing step.
        """
        pass

    def pre_process(self):
        """ This method is called immediately after base_pre_process(). """
        pass

    def base_process_frames_before(self, data):
        """ This method is called before each call to process_frames(). """
        return data

    def base_process_frames_after(self, data):
        """ This method is called directly after each call to
        process_frames() and before returning the data to file.
        """
        return data

    def plugin_process_frames(self, data):
        # Wrap process_frames() with the before/after hooks and count the
        # number of calls.
        frames = self.base_process_frames_after(self.process_frames(
            self.base_process_frames_before(data)))
        self.pcount += 1
        return frames

    def process_frames(self, data):
        """
        This method is called after the plugin has been created by the
        pipeline framework and forms the main processing step.

        :param data: A list of numpy arrays, one for each input dataset.
        :type data: list(np.array)
        """
        logging.error("process_frames needs to be implemented")
        raise NotImplementedError("process_frames needs to be implemented")
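    # Illustrative sketch only: a subclass with one input and one output
    # dataset might override process_frames() along these lines (the
    # 'scale' parameter is an invented example declared in the subclass
    # docstring):
    #
    #     def process_frames(self, data):
    #         return data[0] * self.parameters['scale']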

    def post_process(self):
        """
        This method is called after the process function in the pipeline
        framework as a post-processing step. All processes will have finished
        performing the main processing at this stage.
        """
        pass

    def base_post_process(self):
        """ This method is called immediately after post_process(). """
        pass

    def set_preview(self, data, params):
        if not params:
            return True
        preview = data.get_preview()
        orig_indices = preview.get_starts_stops_steps()
        nDims = len(orig_indices[0])
        no_preview = [[0]*nDims, data.get_shape(), [1]*nDims, [1]*nDims]

        # Set previewing params if previewing has not already been applied to
        # the dataset.
        if no_preview == orig_indices:
            data.get_preview().revert_shape = data.get_shape()
            data.get_preview().set_preview(params)
            return True
        return False

    def _clean_up(self):
        """ Perform necessary plugin clean up after the plugin has completed.
        """
        self._clone_datasets()
        self.__copy_meta_data()
        self.__set_previous_patterns()
        self.__clean_up_plugin_data()

    def __copy_meta_data(self):
        """
        Copy all metadata from input datasets to output datasets, except axis
        data that is no longer valid.
        """
        remove_keys = self.__remove_axis_data()
        in_meta_data, out_meta_data = self.get()
        copy_dict = {}
        for mData in in_meta_data:
            temp = copy.deepcopy(mData.get_dictionary())
            copy_dict.update(temp)

        for i in range(len(out_meta_data)):
            temp = copy_dict.copy()
            for key in remove_keys[i]:
                if temp.get(key, None) is not None:
                    del temp[key]
            temp.update(out_meta_data[i].get_dictionary())
            out_meta_data[i]._set_dictionary(temp)

    def __set_previous_patterns(self):
        for data in self.get_out_datasets():
            data._set_previous_pattern(
                copy.deepcopy(data._get_plugin_data().get_pattern()))

    def __remove_axis_data(self):
        """
        Returns a list of meta_data entries corresponding to axis labels that
        are not copied over to the output datasets.
        """
        in_datasets, out_datasets = self.get_datasets()
        all_in_labels = []
        for data in in_datasets:
            axis_keys = data.get_axis_label_keys()
            all_in_labels = all_in_labels + axis_keys

        remove_keys = []
        for data in out_datasets:
            axis_keys = data.get_axis_label_keys()
            remove_keys.append(set(all_in_labels).difference(set(axis_keys)))

        return remove_keys

    def __clean_up_plugin_data(self):
        """ Remove the PluginData object encapsulated in a dataset after
        plugin completion.
        """
        in_data, out_data = self.get_datasets()
        data_object_list = in_data + out_data
        for data in data_object_list:
            data._clear_plugin_data()

    def _revert_preview(self, in_data):
        """ Revert a dataset back to its original shape if previewing was
        used in a plugin to reduce the data shape, but the original data
        shape should be used thereafter. Remove previewing if it was added
        in the plugin.
        """
        for data in in_data:
            if data.get_preview().revert_shape:
                data.get_preview()._unset_preview()

    def set_global_frame_index(self, frame_idx):
        self.global_index = frame_idx

    def get_global_frame_index(self):
        """ Get the global frame index. """
        return self.global_index

    def set_current_slice_list(self, sl):
        self.slice_list = sl

    def get_current_slice_list(self):
        """ Get the slice list of the current frame being processed. """
        return self.slice_list

    def get_slice_dir_reps(self, nData):
        """ Return the periodicity of the main slice direction.

        :param int nData: The index of the dataset in the list.
        """
        slice_dir = \
            self.get_plugin_in_datasets()[nData].get_slice_directions()[0]
        sl = [s[slice_dir] for s in self.slice_list]
        reps = [i for i in range(len(sl)) if sl[i] == sl[0]]
        return np.diff(reps)[0] if len(reps) > 1 else 1

    def nInput_datasets(self):
        """
        The number of datasets required as input to the plugin.

        :returns: Number of input datasets
        """
        return 1

    def nOutput_datasets(self):
        """
        The number of datasets created by the plugin.

        :returns: Number of output datasets
        """
        return 1

    def nClone_datasets(self):
        """ The number of output datasets that have a clone - i.e. they take
        it in turns to be used as output in an iterative plugin.
        """
        return 0

    def nFrames(self):
        """ The number of frames to process during each call to
        process_frames.
        """
        return 'single'
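    # Note (assumption based on common usage in existing plugins): plugins
    # that operate on blocks of frames typically override this to return
    # 'multiple' or a fixed integer; check the framework documentation for
    # the values the runner accepts.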

    def final_parameter_updates(self):
        """ An opportunity to update the parameters after they have been set.
        """
        pass

    def get_citation_information(self):
        """
        Gets the Citation Information for a plugin.

        :returns: A populated savu.data.plugin_info.CitationInformation
        """
        return None

    def executive_summary(self):
        """ Provide a summary to the user for the result of the plugin.

        e.g.
         - Warning, the sample may have shifted during data collection
         - Filter operated normally

        :returns: A list of string summaries
        """
        return ["Nothing to Report"]