|
1
|
|
|
# Copyright 2014 Diamond Light Source Ltd. |
|
2
|
|
|
# |
|
3
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
4
|
|
|
# you may not use this file except in compliance with the License. |
|
5
|
|
|
# You may obtain a copy of the License at |
|
6
|
|
|
# |
|
7
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
8
|
|
|
# |
|
9
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
10
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
11
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
12
|
|
|
# See the License for the specific language governing permissions and |
|
13
|
|
|
# limitations under the License. |
|
14
|
|
|
|
|
15
|
|
|
""" |
|
16
|
|
|
.. module:: plugin |
|
17
|
|
|
:platform: Unix |
|
18
|
|
|
:synopsis: Base class for all plugins used by Savu |
|
19
|
|
|
|
|
20
|
|
|
.. moduleauthor:: Mark Basham <[email protected]> |
|
21
|
|
|
|
|
22
|
|
|
""" |
|
23
|
|
|
|
|
24
|
|
|
import ast |
|
25
|
|
|
import copy |
|
26
|
|
|
import logging |
|
27
|
|
|
import inspect |
|
28
|
|
|
import numpy as np |
|
29
|
|
|
|
|
30
|
|
|
import savu.plugins.docstring_parser as doc |
|
31
|
|
|
from savu.plugins.plugin_datasets import PluginDatasets |
|
32
|
|
|
|
|
33
|
|
|
|
|
34
|
|
|
class Plugin(PluginDatasets):
    """
    The base class from which all plugins should inherit.
    :param in_datasets: Create a list of the dataset(s) to \
        process. Default: [].
    :param out_datasets: Create a list of the dataset(s) to \
        create. Default: [].
    """
    # NOTE: the class docstring above is read by doc.find_args() (see
    # _populate_default_parameters) to build the default parameters, so its
    # ':param ...: ... Default: ...' layout is significant - edit with care.

    def __init__(self, name="Plugin"):
        super(Plugin, self).__init__(name)
        self.name = name
        # Parameter values, dtypes and descriptions, keyed by parameter name.
        self.parameters = {}
        self.parameters_types = {}
        self.parameters_desc = {}
        self.parameters_hide = []
        self.parameters_user = []
        # Parameter-tuning state. initialise_parameters() resets these, but
        # they are created here so the attributes always exist.
        self.multi_params_dict = {}
        self.extra_dims = []
        self.chunk = False
        self.docstring_info = {}
        self.slice_list = None
        self.global_index = None
        self.pcount = 0
        self.exp = None
        self.check = False
        self.fixed_length = True
        self.tools = self._set_plugin_tools()

    def set_parameters(self, params):
        """ Replace ``self.parameters`` wholesale with ``params``. """
        self.parameters = params

    def initialise(self, params, exp, check=False):
        """ Store the experiment and parameters, then run the plugin setup.

        :param dict params: parameter values from the process list (a deep
            copy is taken, so the caller's dictionary is not modified).
        :param exp: the experiment object.
        :param bool check: stored in ``self.check``.
        """
        self.check = check
        self.exp = exp
        self._set_parameters(copy.deepcopy(params))
        self._main_setup()

    def _main_setup(self):
        """ Performs all the required plugin setup.

        Creates the plugin datasets and resets the process_frames counter.
        It then calls the plugin-specific ``setup`` and applies any filter
        padding. Finally it finalises the plugin datasets and the datasets.
        ``self.exp`` and ``self.parameters`` are expected to be set already
        (see ``initialise``).
        """
        self._set_plugin_datasets()
        self._reset_process_frames_counter()
        self.setup()
        self.set_filter_padding(*(self.get_plugin_datasets()))
        self._finalise_plugin_datasets()
        self._finalise_datasets()

    def _reset_process_frames_counter(self):
        """ Reset the count of process_frames calls to zero. """
        self.pcount = 0

    def get_process_frames_counter(self):
        """ Return the number of completed plugin_process_frames calls. """
        return self.pcount

    def _set_parameters_this_instance(self, indices):
        """ Determines the parameters for this instance of the plugin, in the
        case of parameter tuning.

        :param np.ndarray indices: the index of the current value in each
            parameter tuning list, ordered as the (integer) keys of
            ``self.multi_params_dict``.
        """
        for count, dim in enumerate(sorted(self.multi_params_dict)):
            info = self.multi_params_dict[dim]
            # Labels are built as '<key>_params.<type>'; split on the LAST
            # '_param' so keys that themselves contain '_param' survive.
            name = info['label'].rsplit('_param', 1)[0]
            self.parameters[name] = info['values'][indices[count]]

    def set_filter_padding(self, in_data, out_data):
        """
        Should be overridden to define how wide the frame should be for each
        input data set.
        """
        return {}

    def setup(self):
        """
        This method is first to be called after the plugin has been created.
        It determines input/output datasets and plugin specific dataset
        information such as the pattern (e.g. sinogram/projection).

        :raises NotImplementedError: always, unless overridden.
        """
        logging.error("setup needs to be implemented")
        raise NotImplementedError("setup needs to be implemented")

    def _populate_default_parameters(self):
        """
        This method should populate all the required parameters with default
        values. It is used for checking to see if new parameter values are
        appropriate.

        It makes use of the parameter information included in the docstring
        of each class in the MRO (base classes first), such as this::

            :param error_threshold: Convergence threshold. Default: 0.001.
        """
        hidden_items = []
        user_items = []
        params = []
        not_params = []
        # Walk from the most basic class to the most derived, so derived
        # classes' docstrings are processed last.
        for clazz in inspect.getmro(self.__class__)[::-1]:
            if clazz is object:
                continue
            desc = doc.find_args(clazz, self)
            # Overwritten on every iteration: the most derived class wins.
            self.docstring_info['warn'] = desc['warn']
            self.docstring_info['info'] = desc['info']
            self.docstring_info['synopsis'] = desc['synopsis']
            params.extend(desc['param'])
            if desc['hide_param']:
                hidden_items.extend(desc['hide_param'])
            if desc['user_param']:
                user_items.extend(desc['user_param'])
            if desc['not_param']:
                not_params.extend(desc['not_param'])
        self._add_item(params, not_params)
        excluded = set(not_params)
        hidden_items = [h for h in hidden_items if h not in excluded]
        hidden_set = set(hidden_items)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # user parameter list does not depend on string hash randomisation.
        user_items = list(dict.fromkeys(
            u for u in user_items if u not in excluded and u not in hidden_set))
        self.parameters_hide = hidden_items
        self.parameters_user = user_items
        self.final_parameter_updates()

    def _add_item(self, item_list, not_list):
        """ Record default value, dtype and description for each parameter
        in ``item_list`` whose name is not in ``not_list``.
        """
        excluded = set(not_list)
        for item in item_list:
            if item['name'] in excluded:
                continue
            self.parameters[item['name']] = item['default']
            self.parameters_types[item['name']] = item['dtype']
            self.parameters_desc[item['name']] = item['desc']

    def delete_parameter_entry(self, param):
        """ Remove ``param`` (value, dtype and description) if present. """
        if param in self.parameters:
            del self.parameters[param]
            # pop: a parameter set via set_parameters() may have no
            # type/description entry.
            self.parameters_types.pop(param, None)
            self.parameters_desc.pop(param, None)

    def initialise_parameters(self):
        """ Reset all parameter state and repopulate it with defaults. """
        self.parameters = {}
        self.parameters_types = {}
        self.parameters_desc = {}
        self._populate_default_parameters()
        self.multi_params_dict = {}
        self.extra_dims = []

    def _set_parameters(self, parameters):
        """
        This method is called after the plugin has been created by the
        pipeline framework. It replaces ``self.parameters`` default values
        with those given in the input process list.

        :param dict parameters: A dictionary of the parameters for this
            plugin, or None if no customisation is required.
        :raises ValueError: if a key is not a valid parameter of this plugin.
        """
        self.initialise_parameters()
        if parameters is None:
            return
        # Keys are processed in the dictionary's iteration order, which fixes
        # the index each tuned parameter receives in self.multi_params_dict.
        for key, raw_value in parameters.items():
            if key not in self.parameters:
                error = ("Parameter '%s' is not valid for plugin %s. \nTry "
                         "opening and re-saving the process list in the "
                         "configurator to auto remove \nobsolete parameters."
                         % (key, self.name))
                raise ValueError(error)
            self.parameters[key] = \
                self.__convert_multi_params(raw_value, key)

    def __convert_multi_params(self, value, key):
        """ Set up parameter tuning.

        Convert parameter value to a list if it uses parameter tuning and set
        associated parameters, so the framework knows the new size of the data
        and which plugins to re-run.

        Tuned values are ';'-separated strings: either an explicit list
        ('1;2;3;') or a range whose first entry is 'start:stop:step'.

        :raises RuntimeError: if a range is malformed or yields no values.
        """
        dtype = self.parameters_types[key]
        if not (isinstance(value, str) and ';' in value):
            return value

        no_values_msg = ('No values for tuned parameter "{}", '
                         'ensure start:stop:step; values are valid.'
                         .format(key))
        value = value.split(';')
        if ":" in value[0]:
            # NOTE(review): eval() executes arbitrary text taken from the
            # process list; ast.literal_eval would be safer if only numeric
            # literals are expected here.
            seq = [eval(s) for s in value[0].split(':')]
            if len(seq) < 3:
                raise RuntimeError(no_values_msg)
            value = list(np.arange(seq[0], seq[1], seq[2]))
            if not value:
                raise RuntimeError(no_values_msg)
        if not isinstance(value[0], dtype):
            # A trailing ';' leaves one empty entry to discard; list.remove
            # raises ValueError when there is none, which is fine.
            try:
                value.remove('')
            except ValueError:
                pass
            if isinstance(value[0], str):
                value = [ast.literal_eval(i) for i in value]
            value = list(map(dtype, value))
        label = key + '_params.' + type(value[0]).__name__
        self.multi_params_dict[len(self.multi_params_dict)] = \
            {'label': label, 'values': value}
        self.extra_dims.append(len(value))
        return value

    def get_parameters(self, name):
        """ Return a plugin parameter

        :param str name: parameter name (dictionary key)
        :returns: the associated value in ``self.parameters``
        :rtype: dict value
        :raises KeyError: if ``name`` is not a parameter.
        """
        return self.parameters[name]

    def base_pre_process(self):
        """ This method is called after the plugin has been created by the
        pipeline framework as a pre-processing step.
        """
        pass

    def pre_process(self):
        """ This method is called immediately after base_pre_process(). """
        pass

    def base_process_frames_before(self, data):
        """ This method is called before each call to process frames. """
        return data

    def base_process_frames_after(self, data):
        """ This method is called directly after each call to process frames
        and before returning the data to file.
        """
        return data

    def plugin_process_frames(self, data):
        """ Run the before-hook, process_frames and the after-hook on
        ``data`` and increment the process_frames counter.
        """
        frames = self.base_process_frames_after(self.process_frames(
            self.base_process_frames_before(data)))
        self.pcount += 1
        return frames

    def process_frames(self, data):
        """
        This method is called after the plugin has been created by the
        pipeline framework and forms the main processing step.

        :param data: A list of numpy arrays for each input dataset.
        :type data: list(np.array)
        :raises NotImplementedError: always, unless overridden.
        """
        logging.error("process frames needs to be implemented")
        raise NotImplementedError("process needs to be implemented")

    def post_process(self):
        """
        This method is called after the process function in the pipeline
        framework as a post-processing step. All processes will have finished
        performing the main processing at this stage.
        """
        pass

    def base_post_process(self):
        """ This method is called immediately after post_process(). """
        pass

    def set_preview(self, data, params):
        """ Apply previewing ``params`` to ``data`` unless it is already
        previewed.

        :returns: True if ``params`` is empty or the preview was applied,
            False if the dataset already had previewing.
        """
        if not params:
            return True
        preview = data.get_preview()
        orig_indices = preview.get_starts_stops_steps()
        n_dims = len(orig_indices[0])
        # Starts/stops/steps/chunks describing "no previewing".
        no_preview = [[0]*n_dims, data.get_shape(), [1]*n_dims, [1]*n_dims]

        # Set previewing params if previewing has not already been applied to
        # the dataset.
        if no_preview == orig_indices:
            preview.revert_shape = data.get_shape()
            preview.set_preview(params)
            return True
        return False

    def _clean_up(self):
        """ Perform necessary plugin clean up after the plugin has completed.
        """
        self._clone_datasets()
        self.__copy_meta_data()
        self.__set_previous_patterns()
        self.__clean_up_plugin_data()

    def __copy_meta_data(self):
        """
        Copy all metadata from input datasets to output datasets, except axis
        data that is no longer valid. Existing output metadata takes
        precedence over copied input metadata.
        """
        remove_keys = self.__remove_axis_data()
        in_meta_data, out_meta_data = self.get()
        copy_dict = {}
        for m_data in in_meta_data:
            copy_dict.update(copy.deepcopy(m_data.get_dictionary()))

        for i, out_m_data in enumerate(out_meta_data):
            # NOTE(review): shallow copy, so nested values are shared between
            # the output datasets.
            temp = copy_dict.copy()
            for key in remove_keys[i]:
                temp.pop(key, None)
            temp.update(out_m_data.get_dictionary())
            out_m_data._set_dictionary(temp)

    def __set_previous_patterns(self):
        """ Record each output dataset's current pattern as its previous
        pattern.
        """
        for data in self.get_out_datasets():
            data._set_previous_pattern(
                copy.deepcopy(data._get_plugin_data().get_pattern()))

    def __remove_axis_data(self):
        """
        Returns one set per output dataset of the input axis-label keys that
        are not copied over to that output dataset.
        """
        in_datasets, out_datasets = self.get_datasets()
        all_in_labels = set()
        for data in in_datasets:
            all_in_labels.update(data.get_axis_label_keys())

        return [all_in_labels.difference(data.get_axis_label_keys())
                for data in out_datasets]

    def __clean_up_plugin_data(self):
        """ Remove pluginData object encapsulated in a dataset after plugin
        completion.
        """
        in_data, out_data = self.get_datasets()
        for data in in_data + out_data:
            data._clear_plugin_data()

    def _revert_preview(self, in_data):
        """ Revert dataset back to original shape if previewing was used in a
        plugin to reduce the data shape but the original data shape should be
        used thereafter. Remove previewing if it was added in the plugin.
        """
        for data in in_data:
            if data.get_preview().revert_shape:
                data.get_preview()._unset_preview()

    def set_global_frame_index(self, frame_idx):
        """ Set the global frame index. """
        self.global_index = frame_idx

    def get_global_frame_index(self):
        """ Get the global frame index. """
        return self.global_index

    def set_current_slice_list(self, sl):
        """ Set the slice list of the current frame being processed. """
        self.slice_list = sl

    def get_current_slice_list(self):
        """ Get the slice list of the current frame being processed. """
        return self.slice_list

    def get_slice_dir_reps(self, nData):
        """ Return the periodicity of the main slice direction.

        This is the distance between the first two occurrences of the first
        slice value in the current slice list, or 1 if it does not repeat.

        :param int nData: The number of the dataset in the list.
        """
        slice_dir = \
            self.get_plugin_in_datasets()[nData].get_slice_directions()[0]
        dir_slices = [s[slice_dir] for s in self.slice_list]
        reps = [i for i, s in enumerate(dir_slices) if s == dir_slices[0]]
        return np.diff(reps)[0] if len(reps) > 1 else 1

    def nInput_datasets(self):
        """
        The number of datasets required as input to the plugin

        :returns: Number of input datasets
        """
        return 1

    def nOutput_datasets(self):
        """
        The number of datasets created by the plugin

        :returns: Number of output datasets
        """
        return 1

    def nClone_datasets(self):
        """ The number of output datasets that have a clone - i.e. they take
        it in turns to be used as output in an iterative plugin.
        """
        return 0

    def nFrames(self):
        """ The number of frames to process during each call to process_frames.
        """
        return 'single'

    def final_parameter_updates(self):
        """ An opportunity to update the parameters after they have been set.
        """
        pass

    def get_citation_information(self):
        """
        Gets the Citation Information for a plugin

        :returns: A populated savu.data.plugin_info.CitationInformation, or
            None (the default) if there is none.
        """
        return None

    def executive_summary(self):
        """ Provide a summary to the user for the result of the plugin.

        e.g.
         - Warning, the sample may have shifted during data collection
         - Filter operated normally

        :returns: A list of string summaries
        """
        return ["Nothing to Report"]
|
458
|
|
|
|