# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: plugin
   :platform: Unix
   :synopsis: Base class for all plugins used by Savu

.. moduleauthor:: Mark Basham <[email protected]>

"""

import copy
import logging
import numpy as np

import savu.plugins.utils as pu
from savu.plugins.plugin_datasets import PluginDatasets
from savu.plugins.stats.statistics import Statistics


class Plugin(PluginDatasets):
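    """ The base class for all plugins used by Savu. """
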
    def __init__(self, name="Plugin"):
        super(Plugin, self).__init__(name)
        self.name = name
        self.chunk = False
        self.slice_list = None
        self.global_index = None
        self.pcount = 0
        self.exp = None
        self.check = False
        self.fixed_length = True
        self.parameters = {}
        self.tools = self._set_plugin_tools()

    def set_parameters(self, params):
        self.parameters = params

    def initialise(self, params, exp, check=False):
        self.check = check
        self.exp = exp
        self.get_plugin_tools().initialise(params)
        self._main_setup()

    def _main_setup(self):
        """ Performs all the required plugin setup.

        It sets the experiment and the parameters, and replaces the
        in/out_dataset strings in ``self.parameters`` with the relevant data
        objects. It then creates a PluginData object for each of these
        datasets.
        """
        self._set_plugin_datasets()
        self._reset_process_frames_counter()
        self.stats_obj = Statistics()
        self.setup()
        self.stats_obj.setup(self)
        self.set_filter_padding(*(self.get_plugin_datasets()))
        self._finalise_plugin_datasets()
        self._finalise_datasets()

    def _reset_process_frames_counter(self):
        self.pcount = 0

    def get_process_frames_counter(self):
        return self.pcount

    def set_filter_padding(self, in_data, out_data):
        """
        Should be overridden to define how wide the frame should be for each
        input data set.
        """
        return {}

    def setup(self):
        """
        This is the first method to be called after the plugin has been
        created. It determines the input/output datasets and plugin-specific
        dataset information, such as the pattern (e.g. sinogram/projection).
        """
        logging.error("setup needs to be implemented")
        raise NotImplementedError("setup needs to be implemented")

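    # A minimal sketch of a typical subclass ``setup`` override.  The
    # ``plugin_data_setup`` and ``create_dataset`` helpers are assumed from
    # the wider Savu framework (they are not defined in this module), and the
    # 'SINOGRAM' pattern is purely illustrative:
    #
    #     def setup(self):
    #         in_dataset, out_dataset = self.get_datasets()
    #         in_pData, out_pData = self.get_plugin_datasets()
    #         in_pData[0].plugin_data_setup('SINOGRAM', 'single')
    #         out_dataset[0].create_dataset(in_dataset[0])
    #         out_pData[0].plugin_data_setup('SINOGRAM', 'single')
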
    def get_plugin_tools(self):
        return self.tools

    def _set_plugin_tools(self):
        plugin_tools_id = self.__class__.__module__ + '_tools'
        tool_class = pu.get_tools_class(plugin_tools_id, self)
        return tool_class

    def delete_parameter_entry(self, param):
        if param in self.parameters:
            del self.parameters[param]

    def get_parameters(self, name):
        """ Return a plugin parameter.

        :param str name: parameter name (dictionary key)
        :returns: the associated value in ``self.parameters``
        :rtype: dict value
        """
        return self.parameters[name]

    def base_pre_process(self):
        """ This method is called after the plugin has been created by the
        pipeline framework as a pre-processing step.
        """
        pass

    def pre_process(self):
        """ This method is called immediately after base_pre_process(). """
        pass

    def base_process_frames_before(self, data):
        """ This method is called before each call to process_frames. """
        return data

    def base_process_frames_after(self, data):
        """ This method is called directly after each call to process_frames
        and before returning the data to file.
        """
        return data

    def plugin_process_frames(self, data):
        # Keep a copy of the input frames so that set_slice_stats can be
        # given both the processed frames and the original data.
        # (is it ok to copy every frame like this? Enough memory?)
        data_copy = data.copy()
        frames = self.base_process_frames_after(self.process_frames(
            self.base_process_frames_before(data)))

        if self.stats_obj.calc_stats and self.stats_obj._stats_flag:
            self.stats_obj.set_slice_stats(frames, data_copy)
        self.pcount += 1
        return frames

    def process_frames(self, data):
        """
        This method is called after the plugin has been created by the
        pipeline framework and forms the main processing step.

        :param data: A list of numpy arrays, one for each input dataset.
        :type data: list(np.array)
        """
        logging.error("process_frames needs to be implemented")
        raise NotImplementedError("process_frames needs to be implemented")

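    # A minimal sketch of a ``process_frames`` override that scales the first
    # (and only) input dataset; the scaling factor is purely illustrative:
    #
    #     def process_frames(self, data):
    #         # data is a list of numpy arrays, one per input dataset
    #         return data[0] * 2.0
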
    def post_process(self):
        """
        This method is called after the main processing (the process_frames
        calls) in the pipeline framework, as a post-processing step. All
        processes will have finished performing the main processing at this
        stage.
        """
        pass

    def base_post_process(self):
        """ This method is called immediately after post_process(). """
        if self.stats_obj.calc_stats and self.stats_obj._stats_flag:
            if not self.stats_obj._already_called:
                self.stats_obj.set_volume_stats()
        self.stats_obj._already_called = False
        pass

    def set_preview(self, data, params):
        if not params:
            return True
        preview = data.get_preview()
        orig_indices = preview.get_starts_stops_steps()
        nDims = len(orig_indices[0])
        no_preview = [[0]*nDims, data.get_shape(), [1]*nDims, [1]*nDims]

        # Set previewing params if previewing has not already been applied to
        # the dataset.
        if no_preview == orig_indices:
            data.get_preview().revert_shape = data.get_shape()
            data.get_preview().set_preview(params)
            return True
        return False

    def _clean_up(self):
        """ Perform necessary plugin clean up after the plugin has completed.
        """
        self._clone_datasets()
        self.__copy_meta_data()
        self.__set_previous_patterns()
        self.__clean_up_plugin_data()

    def __copy_meta_data(self):
        """
        Copy all metadata from input datasets to output datasets, except axis
        data and statistics that are no longer valid.
        """
        remove_keys = self.__remove_axis_data()
        for i in range(len(remove_keys)):
            remove_keys[i].add("stats")
        in_meta_data, out_meta_data = self.get()
        copy_dict = {}
        for mData in in_meta_data:
            temp = copy.deepcopy(mData.get_dictionary())
            copy_dict.update(temp)

        for i in range(len(out_meta_data)):
            temp = copy_dict.copy()
            for key in remove_keys[i]:
                if temp.get(key, None) is not None:
                    del temp[key]
            temp.update(out_meta_data[i].get_dictionary())
            out_meta_data[i]._set_dictionary(temp)

    def __set_previous_patterns(self):
        for data in self.get_out_datasets():
            data._set_previous_pattern(
                copy.deepcopy(data._get_plugin_data().get_pattern()))

    def __remove_axis_data(self):
        """
        Returns a list of meta_data entries corresponding to axis labels that
        are not copied over to the output datasets.
        """
        in_datasets, out_datasets = self.get_datasets()
        all_in_labels = []
        for data in in_datasets:
            axis_keys = data.get_axis_label_keys()
            all_in_labels = all_in_labels + axis_keys

        remove_keys = []
        for data in out_datasets:
            axis_keys = data.get_axis_label_keys()
            remove_keys.append(set(all_in_labels).difference(set(axis_keys)))

        return remove_keys

    def __clean_up_plugin_data(self):
        """ Remove the PluginData object encapsulated in a dataset after
        plugin completion.
        """
        in_data, out_data = self.get_datasets()
        data_object_list = in_data + out_data
        for data in data_object_list:
            data._clear_plugin_data()

    def _revert_preview(self, in_data):
        """ Revert a dataset back to its original shape if previewing was
        used in a plugin to reduce the data shape, but the original data
        shape should be used thereafter. Remove previewing if it was added in
        the plugin.
        """
        for data in in_data:
            if data.get_preview().revert_shape:
                data.get_preview()._unset_preview()

    def set_global_frame_index(self, frame_idx):
        self.global_index = frame_idx

    def get_global_frame_index(self):
        """ Get the global frame index. """
        return self.global_index

    def set_current_slice_list(self, sl):
        self.slice_list = sl

    def get_current_slice_list(self):
        """ Get the slice list of the current frame being processed. """
        return self.slice_list

    def get_slice_dir_reps(self, nData):
        """ Return the periodicity of the main slice direction, i.e. the
        number of frames before the first slice-dimension entry in the
        current slice list repeats.

        :param int nData: The number of the dataset in the list.
        """
        slice_dir = \
            self.get_plugin_in_datasets()[nData].get_slice_directions()[0]
        sl = [sl[slice_dir] for sl in self.slice_list]
        reps = [i for i in range(len(sl)) if sl[i] == sl[0]]
        return np.diff(reps)[0] if len(reps) > 1 else 1

    def nInput_datasets(self):
        """
        The number of datasets required as input to the plugin.

        :returns: Number of input datasets
        """
        return 1

    def nOutput_datasets(self):
        """
        The number of datasets created by the plugin.

        :returns: Number of output datasets
        """
        return 1

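    # For example, a subclass that consumes one dataset and creates two would
    # override the counts above (an illustrative sketch only):
    #
    #     def nInput_datasets(self):
    #         return 1
    #
    #     def nOutput_datasets(self):
    #         return 2
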
    def nClone_datasets(self):
        """ The number of output datasets that have a clone, i.e. they take
        it in turns to be used as output in an iterative plugin.
        """
        return 0

    def nFrames(self):
        """ The number of frames to process during each call to
        process_frames.
        """
        return 'single'

    def final_parameter_updates(self):
        """ An opportunity to update the parameters after they have been set.
        """
        pass

    def executive_summary(self):
        """ Provide a summary to the user of the result of the plugin.

        e.g.
         - Warning, the sample may have shifted during data collection
         - Filter operated normally

        :returns: A list of string summaries
        """
        return ["Nothing to Report"]