1
|
|
|
# Copyright 2014 Diamond Light Source Ltd. |
2
|
|
|
# |
3
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
4
|
|
|
# you may not use this file except in compliance with the License. |
5
|
|
|
# You may obtain a copy of the License at |
6
|
|
|
# |
7
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
8
|
|
|
# |
9
|
|
|
# Unless required by applicable law or agreed to in writing, software |
10
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
11
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12
|
|
|
# See the License for the specific language governing permissions and |
13
|
|
|
# limitations under the License. |
14
|
|
|
|
15
|
|
|
""" |
16
|
|
|
.. module:: plugin |
17
|
|
|
:platform: Unix |
18
|
|
|
:synopsis: Base class for all plugins used by Savu |
19
|
|
|
|
20
|
|
|
.. moduleauthor:: Mark Basham <[email protected]> |
21
|
|
|
|
22
|
|
|
""" |
23
|
|
|
|
24
|
|
|
import copy |
25
|
|
|
import logging |
26
|
|
|
import numpy as np |
27
|
|
|
|
28
|
|
|
import savu.plugins.utils as pu |
29
|
|
|
from savu.plugins.plugin_datasets import PluginDatasets |
30
|
|
|
from savu.plugins.stats.statistics import Statistics |
31
|
|
|
|
32
|
|
|
|
33
|
|
|
class Plugin(PluginDatasets):
    """Base class for all plugins used by Savu.

    Subclasses must implement :meth:`setup` and :meth:`process_frames`
    (the defaults raise ``NotImplementedError``); the remaining hooks have
    no-op or trivial default implementations that may be overridden.
    """

    def __init__(self, name="Plugin"):
        super(Plugin, self).__init__(name)
        self.name = name
        self.chunk = False
        # Slice list of the frame currently being processed.
        self.slice_list = None
        # Global index of the frame currently being processed.
        self.global_index = None
        # Number of plugin_process_frames calls since the last reset.
        self.pcount = 0
        # Experiment object, assigned in initialise().
        self.exp = None
        self.check = False
        self.fixed_length = True
        self.parameters = {}
        self.tools = self._set_plugin_tools()

    def set_parameters(self, params):
        """Replace the plugin parameter dictionary with ``params``."""
        self.parameters = params

    def initialise(self, params, exp, check=False):
        """Attach the plugin to an experiment and run the full plugin setup.

        :param params: parameter values handed to the plugin tools'
            ``initialise`` method.
        :param exp: the experiment object, stored on ``self.exp``.
        :param bool check: stored on ``self.check``.
        """
        self.check = check
        self.exp = exp
        self.get_plugin_tools().initialise(params)
        self._main_setup()

    def _main_setup(self):
        """Perform all the required plugin setup.

        In order this:

        1. sets up the plugin datasets;
        2. resets the process-frames counter;
        3. creates a fresh ``Statistics`` object;
        4. calls the plugin-specific :meth:`setup`;
        5. sets up the statistics object;
        6. calls :meth:`set_filter_padding` with the plugin datasets;
        7. finalises the plugin datasets and then the datasets.
        """
        self._set_plugin_datasets()
        self._reset_process_frames_counter()
        self.stats_obj = Statistics()
        self.setup()
        self.stats_obj.setup(self)
        self.set_filter_padding(*(self.get_plugin_datasets()))
        self._finalise_plugin_datasets()
        self._finalise_datasets()

    def _reset_process_frames_counter(self):
        """Reset the count of processed-frame calls to zero."""
        self.pcount = 0

    def get_process_frames_counter(self):
        """Return the number of plugin_process_frames calls since the last
        counter reset."""
        return self.pcount

    def set_filter_padding(self, in_data, out_data):
        """
        Should be overridden to define how wide the frame should be for each
        input dataset.

        Called from :meth:`_main_setup` with the plugin datasets unpacked as
        arguments; the return value is not used there. The default does
        nothing and returns an empty dict.
        """
        return {}

    def setup(self):
        """
        This method is first to be called after the plugin has been created.
        It determines input/output datasets and plugin specific dataset
        information such as the pattern (e.g. sinogram/projection).

        :raises NotImplementedError: always, in this base class.
        """
        logging.error("setup needs to be implemented")
        raise NotImplementedError("setup needs to be implemented")

    def get_plugin_tools(self):
        """Return the plugin tools object created at construction time."""
        return self.tools

    def _set_plugin_tools(self):
        """Look up the tools for this plugin.

        The tools are found via ``savu.plugins.utils.get_tools_class``, using
        the plugin's module name with a ``_tools`` suffix as the identifier.
        """
        plugin_tools_id = self.__class__.__module__ + '_tools'
        # NOTE(review): despite the name, the result is used as an object
        # with an ``initialise`` method (see initialise()); confirm whether
        # get_tools_class returns an instance.
        tool_class = pu.get_tools_class(plugin_tools_id, self)
        return tool_class

    def delete_parameter_entry(self, param):
        """Remove ``param`` from the plugin parameters if it is present.

        Missing keys are ignored.
        """
        self.parameters.pop(param, None)

    def get_parameters(self, name):
        """ Return a plugin parameter

        :param str name: parameter name (dictionary key)
        :returns: the associated value in ``self.parameters``
        :rtype: dict value
        :raises KeyError: if ``name`` is not a parameter key.
        """
        return self.parameters[name]

    def base_pre_process(self):
        """ This method is called after the plugin has been created by the
        pipeline framework as a pre-processing step.
        """
        pass

    def pre_process(self):
        """ This method is called immediately after base_pre_process(). """
        pass

    def base_process_frames_before(self, data):
        """ This method is called before each call to process_frames.

        The default returns ``data`` unchanged.
        """
        return data

    def base_process_frames_after(self, data):
        """ This method is called directly after each call to process_frames
        and before returning the data to file.

        The default returns ``data`` unchanged.
        """
        return data

    def plugin_process_frames(self, data):
        """Run one processing step, wrapped by the base before/after hooks.

        ``data`` is passed through :meth:`base_process_frames_before`, then
        :meth:`process_frames`, then :meth:`base_process_frames_after`.
        When statistics are enabled (``calc_stats`` and ``_stats_flag`` on
        the statistics object), the processed frames and a copy of the input
        taken before processing are passed to ``set_slice_stats``. The
        process-frames counter is incremented on every call.

        :param data: the input frames for this call.
        :returns: the processed frames.
        """
        # NOTE(review): the copy is taken for every call, even when
        # statistics are disabled. If ``data`` is a list of arrays,
        # ``copy()`` is shallow and does not duplicate the arrays themselves.
        data_copy = data.copy()
        frames = self.base_process_frames_after(self.process_frames(
            self.base_process_frames_before(data)))

        if self.stats_obj.calc_stats and self.stats_obj._stats_flag:
            self.stats_obj.set_slice_stats(frames, data_copy)
        self.pcount += 1
        return frames

    def process_frames(self, data):
        """
        This method is called after the plugin has been created by the
        pipeline framework and forms the main processing step.

        :param data: A list of numpy arrays for each input dataset.
        :type data: list(np.array)
        :raises NotImplementedError: always, in this base class.
        """
        logging.error("process frames needs to be implemented")
        raise NotImplementedError("process_frames needs to be implemented")

    def post_process(self):
        """
        This method is called after the process function in the pipeline
        framework as a post-processing step. All processes will have finished
        performing the main processing at this stage.
        """
        pass

    def base_post_process(self):
        """ This method is called immediately after post_process().

        When statistics are enabled, volume statistics are computed unless
        ``_already_called`` is set on the statistics object. The flag is
        then cleared.
        """
        stats = self.stats_obj
        if stats.calc_stats and stats._stats_flag:
            if not stats._already_called:
                stats.set_volume_stats()
            stats._already_called = False

    def set_preview(self, data, params):
        """Apply previewing ``params`` to ``data`` if it is not yet previewed.

        The dataset counts as un-previewed when its preview's
        starts/stops/steps entries equal the defaults (zeros, the full data
        shape, ones, ones). In that case the current shape is recorded as
        ``revert_shape`` before the preview is applied.

        :returns: True if ``params`` is empty or the preview was applied;
            False if the dataset was already previewed.
        """
        if not params:
            return True
        preview = data.get_preview()
        orig_indices = preview.get_starts_stops_steps()
        nDims = len(orig_indices[0])
        no_preview = [[0]*nDims, data.get_shape(), [1]*nDims, [1]*nDims]

        # Set previewing params if previewing has not already been applied to
        # the dataset.
        if no_preview == orig_indices:
            data.get_preview().revert_shape = data.get_shape()
            data.get_preview().set_preview(params)
            return True
        return False

    def _clean_up(self):
        """ Perform necessary plugin clean up after the plugin has completed.

        Clones datasets, copies metadata to the outputs, records the output
        patterns as previous patterns and clears the per-plugin data objects.
        """
        self._clone_datasets()
        self.__copy_meta_data()
        self.__set_previous_patterns()
        self.__clean_up_plugin_data()

    def __copy_meta_data(self):
        """
        Copy all metadata from input datasets to output datasets, except axis
        data and statistics that are no longer valid.

        Entries from later input datasets override earlier ones with the same
        key. Each output dataset's own existing metadata takes precedence
        over the copied entries. Every output receives an independent deep
        copy, so mutable metadata values are not shared between outputs.
        """
        remove_keys = self.__remove_axis_data()
        in_meta_data, out_meta_data = self.get()

        merged = {}
        for m_data in in_meta_data:
            merged.update(m_data.get_dictionary())

        for i, out_meta in enumerate(out_meta_data):
            drop = remove_keys[i]
            drop.add("stats")
            # Deep copy per output so outputs never alias each other's (or
            # the inputs') mutable metadata values.
            temp = copy.deepcopy(
                {key: val for key, val in merged.items() if key not in drop})
            temp.update(out_meta.get_dictionary())
            out_meta._set_dictionary(temp)

    def __set_previous_patterns(self):
        """Store a deep copy of each output dataset's current plugin pattern
        as its previous pattern."""
        for data in self.get_out_datasets():
            data._set_previous_pattern(
                copy.deepcopy(data._get_plugin_data().get_pattern()))

    def __remove_axis_data(self):
        """
        Returns a list of sets of metadata keys that correspond to axis
        labels and are not copied over to the output datasets.

        There is one set per output dataset. Each set holds the input axis
        label keys that are absent from that output's axis label keys. Each
        set is a new object, so callers may mutate it.
        """
        in_datasets, out_datasets = self.get_datasets()
        in_labels = set()
        for data in in_datasets:
            in_labels.update(data.get_axis_label_keys())

        return [in_labels.difference(data.get_axis_label_keys())
                for data in out_datasets]

    def __clean_up_plugin_data(self):
        """ Remove pluginData object encapsulated in a dataset after plugin
        completion.
        """
        in_data, out_data = self.get_datasets()
        data_object_list = in_data + out_data
        for data in data_object_list:
            data._clear_plugin_data()

    def _revert_preview(self, in_data):
        """ Revert dataset back to original shape if previewing was used in a
        plugin to reduce the data shape but the original data shape should be
        used thereafter. Remove previewing if it was added in the plugin.
        """
        for data in in_data:
            if data.get_preview().revert_shape:
                data.get_preview()._unset_preview()

    def set_global_frame_index(self, frame_idx):
        """ Set the global frame index. """
        self.global_index = frame_idx

    def get_global_frame_index(self):
        """ Get the global frame index. """
        return self.global_index

    def set_current_slice_list(self, sl):
        """ Set the slice list of the current frame being processed. """
        self.slice_list = sl

    def get_current_slice_list(self):
        """ Get the slice list of the current frame being processed. """
        return self.slice_list

    def get_slice_dir_reps(self, nData):
        """ Return the periodicity of the main slice direction.

        Looks at the entries of the current slice list along the first
        slice direction of plugin input dataset ``nData``. Returns the
        distance between the first two entries equal to the first entry, or
        1 if there are fewer than two such entries.

        :param int nData: The number of the dataset in the list.
        """
        slice_dir = \
            self.get_plugin_in_datasets()[nData].get_slice_directions()[0]
        dir_slices = [entry[slice_dir] for entry in self.slice_list]
        reps = [i for i, s in enumerate(dir_slices) if s == dir_slices[0]]
        return np.diff(reps)[0] if len(reps) > 1 else 1

    def nInput_datasets(self):
        """
        The number of datasets required as input to the plugin

        :returns: Number of input datasets
        """
        return 1

    def nOutput_datasets(self):
        """
        The number of datasets created by the plugin

        :returns: Number of output datasets
        """
        return 1

    def nClone_datasets(self):
        """ The number of output datasets that have a clone - i.e. they take
        it in turns to be used as output in an iterative plugin.
        """
        return 0

    def nFrames(self):
        """ The number of frames to process during each call to process_frames.
        """
        return 'single'

    def final_parameter_updates(self):
        """ An opportunity to update the parameters after they have been set.
        """
        pass

    def executive_summary(self):
        """ Provide a summary to the user for the result of the plugin.

        e.g.
         - Warning, the sample may have shifted during data collection
         - Filter operated normally

        :returns: A list of string summaries
        """
        return ["Nothing to Report"]
333
|
|
|
|