# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: plugin_data
   :platform: Unix
   :synopsis: Contains the PluginData class. Each Data set used in a plugin \
       has a PluginData object encapsulated within it, for the duration of a \
       plugin run.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""
import copy
import logging
import sys

try:
    from math import gcd
except ImportError:  # Python 2: math.gcd does not exist
    from fractions import gcd

import h5py
import numpy as np

from savu.data.meta_data import MetaData
from savu.data.data_structures.data_add_ons import Padding


class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin. An instance of the class is
    encapsulated inside the Data object during the plugin run.
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        # register this instance with its Data object
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None
        self._increase_rank = 0
        # Assigned in plugin_data_setup(); defined here so that methods that
        # compare against it before setup do not raise AttributeError.
        self.max_frames = None

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        This is the product of the lengths of the current pattern's slice
        dimensions.

        :returns: Number of frames
        :rtype: int
        """
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        dshape = self.data_obj.get_shape()
        temp = 1
        for tslice in pattern["slice_dims"]:
            temp *= dshape[tslice]
        return temp

    def __set_pattern(self, name, first_sdim=None):
        """ Set the pattern related information in the meta data dict.

        :param str name: The pattern name (a key of the data patterns dict).
        :keyword first_sdim: Axis label of the slice dimension to move to the
            front of the pattern's slice dimensions.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions(first_sdim=first_sdim)

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the pattern name has not been set.
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The shape holds the core dimensions, in ascending data-dimension
        order, plus the first slice dimension (of length max_frames_process)
        when more than one frame is processed or squeezing is disabled.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        # Sort explicitly rather than relying on set iteration order.
        dirs = sorted(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = [dshape[core] for core in sorted(set(core_dir))]
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        """ Return the full-rank shape of the data transferred each time.

        Core dimensions keep their full length. Each slice dimension takes
        the corresponding entry of ``slice_size``. ``slice_size`` is padded
        with 1s for any extra dimensions added for parameter tuning.

        :param list(int) slice_size: Number of frames in each slice dimension.
        :returns: The transfer shape.
        :rtype: tuple
        """
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1]*(len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None]*len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        for i, dim in enumerate(slice_dir):
            shape[dim] = slice_size[i]
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transferred in each slice dimension
        given mft (max frames transfer).

        Slice dimensions are filled in order. Each takes as many frames as it
        can (up to its length), and the remaining frame count is divided by
        the number taken before moving on to the next dimension. The result
        is also stored in the meta data as ``size_list``.

        :param int mft: The max frames transfer.
        :returns: Number of frames in each slice dimension.
        :rtype: list(int)
        """
        dshape = list(self.data_obj.get_shape())

        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            for d in self.meta_data.get('fixed_dimensions'):
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1]*len(dshape)
        i = 0

        while mft > 1 and i < len(size_list):
            size_list[i] = min(dshape[i], mft)
            # Divide by the frames taken from *this* dimension only: mft has
            # already been reduced by the earlier dimensions.
            mft //= max(size_list[i], 1)
            i += 1

        # case of fixed integer max_frames, where max_frames > nSlices
        if mft > 1 and size_list:
            size_list[0] *= mft

        self.meta_data.set('size_list', size_list)
        return size_list

    def set_bytes_per_frame(self):
        """ Calculate the shape and size in bytes of a single frame. A frame
        consists of the core dimensions of the current pattern.

        :returns: The frame shape and the number of bytes per frame.
        :rtype: tuple(list(int), int)
        """
        nBytes = self.data_obj.get_itemsize()
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape)*nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        # NOTE(review): currently identical to get_shape(); confirm whether
        # padding is expected to be applied to self.shape elsewhere.
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions.
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning).

        Falls back to the data object's shape if no pre-tuning shape is set.
        """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit (via ``sys.exit``) if the index or dimension counts are
        inconsistent.

        Exits if ``indices`` does not have one entry per slice dimension.
        Also exits if the core and slice dimensions together do not account
        for ``nDims`` dimensions.
        """
        # Compare lengths by value (!=), not by object identity (is not).
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self, first_sdim=None):
        """ Set the slice dimensions in the pluginData meta data dictionary.

        If ``first_sdim`` (an axis label) is given, the pattern's slice_dims
        are reordered so that dimension is first. The pattern dict held by
        the data object is updated in place.
        """
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        slice_dims = pattern['slice_dims']

        if first_sdim:
            slice_dims = list(slice_dims)
            first_sdim = \
                self.data_obj.get_data_dimension_by_axis_label(first_sdim)
            slice_dims.insert(0, slice_dims.pop(slice_dims.index(first_sdim)))
            pattern['slice_dims'] = tuple(slice_dims)

        self.meta_data.set('slice_dims', tuple(slice_dims))

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        # Plugin dimensions are in ascending data-dimension order.
        return sorted(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.

        :param str label: The axis label.
        :keyword bool contains: Passed through to the data object's lookup.
        :returns: Position of the labelled dimension within the plugin data.
        :rtype: int
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = tuple(self.data_obj.get_core_dimensions())
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            # Add the *data* dimension of the first slice dimension. Its
            # plugin position must not be used here, since plugin_dims holds
            # data dimensions.
            plugin_dims += (self.data_obj.get_slice_dimensions()[0],)
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):  # should this function be deleted?
        """
        Reorder the slice dimensions. The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dims``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order. Slice dimensions that are
        not mentioned keep their relative order after those that are.

        :param tuple(int) order: Indices into the current slice dimensions.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # Update the stored pattern itself (not a temporary wrapper dict) and
        # keep the meta data in sync, as __set_slice_dimensions does.
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        pattern['slice_dims'] = new_slice_dirs
        self.meta_data.set('slice_dims', new_slice_dirs)

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0],)
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) values: Index of fixed directions
        :raises Exception: if a direction in ``dims`` is not a slice
            direction.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        # fixed directions have length 1 in the data object's shape
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.

        Plugin dimension positions (core dimensions followed by the first
        slice dimension) beyond the rank of the plugin data are filled with
        ``slice(None)``.
        """
        nDims = len(self.get_shape())
        # Concatenate the positions. Using ``+`` on the numpy array returned
        # by get_core_dimensions() would broadcast-add instead.
        all_dims = np.append(self.get_core_dimensions(),
                             self.get_slice_dimension())
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(int(i), slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_transfer_setup sets ``self._plugin.chunk
        # = True``, and ``True > 1`` is False. This branch is therefore not
        # entered for that value; confirm whether a truthiness test is
        # intended.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.data_obj.get_slice_dimensions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_rank_inc(self, n):
        """ Increase the rank of the array passed to the plugin by n.

        :param int n: Rank increment.
        """
        self._increase_rank = n

    def _get_rank_inc(self):
        """ Return the increased rank value

        :returns: Rank increment
        :rtype: int
        """
        return self._increase_rank

    def _set_meta_data(self):
        """ Calculate and store the frame distribution information in the
        meta data. The keys set are shape, sdir, total_frames, mpi_procs,
        frames_per_process, bytes_per_frame, bytes_per_process and
        frame_shape.
        """
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        # Ignore the trailing slice dimensions added for parameter tuning.
        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        if 'fix_total_frames' in self.meta_data.get_dictionary():
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # Frames per process is based on the first slice dimension only
        # (rather than ceil(frames/n_procs)). HDF5 performance is slow when
        # multiple dimensions are not sliced concurrently.
        f_per_p = np.ceil(shape[sdir[0]]/n_procs)
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f*f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        """ Log the max frames values, store max_frames_process and
        (optionally) check the frame distribution across processes. """
        # Lazy %-args: only formatted if the debug message is emitted.
        logging.debug("Setting max frames transfer for plugin %s to %d",
                      self._plugin, mft)
        logging.debug("Setting max frames process for plugin %s to %d",
                      self._plugin, mfp)
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)

    def __check_distribution(self, mft):
        """ Warn if transfers are unevenly distributed across processes.

        The warning is issued when the fractional part of
        (total_frames/mft)/mpi_procs is non-zero and below the threshold.
        """
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes/mft)/float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning('UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                            'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        """ Convert a padding dict into a Padding object.

        Each key is a Padding method name, called with its associated value.
        """
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in list(self.pad_dict.keys()):
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None, slice_axis=None,
                          getall=None, fixed_length=True):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time. Choose from
            'single', 'multiple', 'fixed_multiple' or an integer (an integer
            should only ever be passed in exceptional circumstances). A list
            of two values is unpacked as [nFrames, frame_limit].
        :keyword split: Stored on the instance as ``split``.
        :keyword str slice_axis: An axis label associated with the fastest
            changing (first) slice dimension.
        :keyword list[pattern, axis_label] getall: A list of two values. If
            the requested pattern doesn't exist then use all of "axis_label"
            dimension of "pattern" as this is equivalent to one slice of the
            original pattern.
        :keyword fixed_length: Data passed to the plugin is automatically
            padded to ensure all plugin data has the same dimensions. Set this
            value to False to turn this off.
        """
        if pattern not in self.data_obj.get_data_patterns() and getall:
            pattern, nFrames = self.__set_getall_pattern(getall, nFrames)

        # slice_axis is first slice dimension
        self.__set_pattern(pattern, first_sdim=slice_axis)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split
        if not fixed_length:
            self._plugin.fixed_length = fixed_length

    def __set_getall_pattern(self, getall, nFrames):
        """ Set framework changes required to get all of a pattern of lower
        rank.

        :returns: The pattern name and the (possibly updated) nFrames.
        """
        pattern, slice_axis = getall
        dim = self.data_obj.get_data_dimension_by_axis_label(slice_axis)
        # ensure data remains the same shape when 'getall' dim has length 1
        self._set_no_squeeze()
        if nFrames == 'multiple' or (isinstance(nFrames, int) and nFrames > 1):
            self._set_rank_inc(1)
        nFrames = self.data_obj.get_shape()[dim]
        return pattern, nFrames

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.

        By default max frames transfer/process are calculated via the
        transport data. If ``calc`` (another PluginData instance) is given,
        the number of process_frames runs per transfer is taken from it. If
        ``copy`` (another PluginData instance) is given, its max frames
        transfer/process values are copied. ``calc`` takes precedence over
        ``copy``.

        :keyword PluginData copy: Instance to copy the values from.
        :keyword PluginData calc: Instance to calculate the values from.
        """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')
            max_mfp = calc.meta_data.get('max_frames_process')
            max_nProc = int(np.ceil(max_mft/float(max_mfp)))
            nProc = max_nProc
            # NOTE(review): if self.max_frames is 'multiple', mfp is the
            # string 'multiple' and mft becomes a repeated string. Confirm
            # that calc is only used with 'single' or integer max_frames.
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc*mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        # flag the plugin if the preview chunk size along the first slice
        # dimension is not a multiple of the frames transferred
        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        """ Calculate max frames transfer and process via the transport data.

        :returns: max frames transfer, transfer shape (None if max frames
            transfer is falsy) and max frames process.
        :rtype: tuple
        """
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        # avoid an unbound local when no transfer size was calculated
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        """ Store the max frames values and transfer shape in the meta data.
        """
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        """ Return a dict of item size, frame shape, total frames and total
        bytes (item size * frame size * total frames). """
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes*np.prod(frame_shape)*total_frames

        params = {'nBytes': nBytes, 'frame_shape': frame_shape,
                  'total_frames': total_frames, 'transfer_bytes': tbytes}
        return params

    def __perform_checks(self, nFrames):
        """ Raise an Exception unless nFrames is 'single', 'multiple' or an
        integer (any Python or numpy integer type, excluding booleans). """
        options = ['single', 'multiple']
        # Accept any integer type (e.g. numpy int32), not only int64, and
        # keep rejecting booleans.
        is_int = isinstance(nFrames, (int, np.integer)) and \
            not isinstance(nFrames, (bool, np.bool_))
        if not is_int and nFrames not in options:
            e_str = ("The value of nFrames is not recognised. Please choose "
                     + "from 'single' and 'multiple' (or an integer in exceptional "
                     + "circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.

        Indices beyond the total number of frames are clamped to the last
        frame.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count]*mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index