PluginData._get_rank_inc()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: plugin_data
17
   :platform: Unix
18
   :synopsis: Contains the PluginData class. Each Data set used in a plugin \
19
       has a PluginData object encapsulated within it, for the duration of a \
20
       plugin run.
21
22
.. moduleauthor:: Nicola Wadeson <[email protected]>
23
24
"""
25
import sys
import copy
import logging

import h5py
import numpy as np

# fractions.gcd was removed in Python 3.9; math.gcd exists from Python 3.5.
try:
    from math import gcd
except ImportError:  # older interpreters without math.gcd
    from fractions import gcd

from savu.data.meta_data import MetaData
from savu.data.data_structures.data_add_ons import Padding
34
35
36
class PluginData(object):
37
    """ The PluginData class contains plugin specific information about a Data
38
    object for the duration of a plugin.  An instance of the class is
39
    encapsulated inside the Data object during the plugin run
40
    """
41
42
    def __init__(self, data_obj, plugin=None):
        """ Create the plugin-specific view of ``data_obj``.

        :param data_obj: The Data object this instance belongs to.
        :param plugin: The plugin using the data during this run (optional).
        """
        self.data_obj = data_obj
        self._preview = None
        # register with the Data object before the remaining attributes are
        # initialised (same ordering as before)
        data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = self.pad_dict = None
        self.shape = self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = self._frame_limit = None
        self._increase_rank = 0
61
62
    def _get_preview(self):
63
        return self._preview
64
65
    def get_total_frames(self):
66
        """ Get the total number of frames to process (all MPI processes).
67
68
        :returns: Number of frames
69
        :rtype: int
70
        """
71
        temp = 1
72
        slice_dir = \
73
            self.data_obj.get_data_patterns()[
74
                self.get_pattern_name()]["slice_dims"]
75
        for tslice in slice_dir:
76
            temp *= self.data_obj.get_shape()[tslice]
77
        return temp
78
79
    def __set_pattern(self, name, first_sdim=None):
80
        """ Set the pattern related information in the meta data dict.
81
        """
82
        pattern = self.data_obj.get_data_patterns()[name]
83
        self.meta_data.set("name", name)
84
        self.meta_data.set("core_dims", pattern['core_dims'])
85
        self.__set_slice_dimensions(first_sdim=first_sdim)
86
87
    def get_pattern_name(self):
88
        """ Get the pattern name.
89
90
        :returns: the pattern name
91
        :rtype: str
92
        """
93
        try:
94
            name = self.meta_data.get("name")
95
            return name
96
        except KeyError:
97
            raise Exception("The pattern name has not been set.")
98
99
    def get_pattern(self):
100
        """ Get the current pattern.
101
102
        :returns: dict of the pattern name against the pattern.
103
        :rtype: dict
104
        """
105
        pattern_name = self.get_pattern_name()
106
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}
107
108
    def __set_shape(self):
109
        """ Set the shape of the plugin data processing chunk.
110
        """
111
        core_dir = self.data_obj.get_core_dimensions()
112
        slice_dir = self.data_obj.get_slice_dimensions()
113
        dirs = list(set(core_dir + (slice_dir[0],)))
114
        slice_idx = dirs.index(slice_dir[0])
115
        dshape = self.data_obj.get_shape()
116
        shape = []
117
        for core in set(core_dir):
118
            shape.append(dshape[core])
119
        self.__set_core_shape(tuple(shape))
120
121
        mfp = self._get_max_frames_process()
122
        if mfp > 1 or self._get_no_squeeze():
123
            shape.insert(slice_idx, mfp)
124
        self.shape = tuple(shape)
125
126
    def _set_shape_transfer(self, slice_size):
127
        dshape = self.data_obj.get_shape()
128
        shape_before_tuning = self._get_shape_before_tuning()
129
        add = [1]*(len(dshape) - len(shape_before_tuning))
130
        slice_size = slice_size + add
131
132
        core_dir = self.data_obj.get_core_dimensions()
133
        slice_dir = self.data_obj.get_slice_dimensions()
134
        shape = [None]*len(dshape)
135
        for dim in core_dir:
136
            shape[dim] = dshape[dim]
137
        i = 0
138
        for dim in slice_dir:
139
            shape[dim] = slice_size[i]
140
            i += 1
141
        return tuple(shape)
142
143
    def __get_slice_size(self, mft):
144
        """ Calculate the number of frames transfer in each dimension given
145
            mft. """
146
        dshape = list(self.data_obj.get_shape())
147
148
        if 'fixed_dimensions' in list(self.meta_data.get_dictionary().keys()):
149
            fixed_dims = self.meta_data.get('fixed_dimensions')
150
            for d in fixed_dims:
151
                dshape[d] = 1
152
153
        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
154
        size_list = [1]*len(dshape)
155
        i = 0
156
157
        while(mft > 1 and i < len(size_list)):
158
            size_list[i] = min(dshape[i], mft)
159
            mft //= np.prod(size_list) if np.prod(size_list) > 1 else 1
160
            i += 1
161
            
162
        # case of fixed integer max_frames, where max_frames > nSlices
163
        if mft > 1:
164
            size_list[0] *= mft
165
166
        self.meta_data.set('size_list', size_list)
167
        return size_list
168
169
    def set_bytes_per_frame(self):
170
        """ Return the size of a single frame in bytes. """
171
        nBytes = self.data_obj.get_itemsize()
172
        dims = list(self.get_pattern().values())[0]['core_dims']
173
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
174
        b_per_f = np.prod(frame_shape)*nBytes
175
        return frame_shape, b_per_f
176
177
    def get_shape(self):
178
        """ Get the shape of the data (without padding) that is passed to the
179
        plugin process_frames method.
180
        """
181
        return self.shape
182
183
    def _set_padded_shape(self):
184
        pass
185
186
    def get_padded_shape(self):
187
        """ Get the shape of the data (with padding) that is passed to the
188
        plugin process_frames method.
189
        """
190
        return self.shape
191
192
    def get_shape_transfer(self):
193
        """ Get the shape of the plugin data to be transferred each time.
194
        """
195
        return self.meta_data.get('transfer_shape')
196
197
    def __set_core_shape(self, shape):
198
        """ Set the core shape to hold only the shape of the core dimensions
199
        """
200
        self.core_shape = shape
201
202
    def get_core_shape(self):
203
        """ Get the shape of the core dimensions only.
204
205
        :returns: shape of core dimensions
206
        :rtype: tuple
207
        """
208
        return self.core_shape
209
210
    def _set_shape_before_tuning(self, shape):
211
        """ Set the shape of the full dataset used during each run of the \
212
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
213
        self.pre_tuning_shape = shape
214
215
    def _get_shape_before_tuning(self):
216
        """ Return the shape of the full dataset used during each run of the \
217
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
218
        return self.pre_tuning_shape if self.pre_tuning_shape else\
219
            self.data_obj.get_shape()
220
221
    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
222
        if len(indices) is not len(slice_dir):
223
            sys.exit("Incorrect number of indices specified when accessing "
224
                     "data.")
225
226
        if (len(core_dir)+len(slice_dir)) is not nDims:
227
            sys.exit("Incorrect number of data dimensions specified.")
228
229
    def __set_slice_dimensions(self, first_sdim=None):
230
        """ Set the slice dimensions in the pluginData meta data dictionary.\
231
        Reorder pattern slice_dims to ensure first_sdim is at the front.
232
        """
233
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
234
        slice_dims = pattern['slice_dims']
235
236
        if first_sdim:
237
            slice_dims = list(slice_dims)
238
            first_sdim = \
239
                self.data_obj.get_data_dimension_by_axis_label(first_sdim)
240
            slice_dims.insert(0, slice_dims.pop(slice_dims.index(first_sdim)))
241
            pattern['slice_dims'] = tuple(slice_dims)
242
243
        self.meta_data.set('slice_dims', tuple(slice_dims))
244
245
    def get_slice_dimension(self):
246
        """
247
        Return the position of the slice dimension in relation to the data
248
        handed to the plugin.
249
        """
250
        core_dirs = self.data_obj.get_core_dimensions()
251
        slice_dir = self.data_obj.get_slice_dimensions()[0]
252
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)
253
254
    def get_data_dimension_by_axis_label(self, label, contains=False):
255
        """
256
        Return the dimension of the data in the plugin that has the specified
257
        axis label.
258
        """
259
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
260
                label, contains=contains)
261
        plugin_dims = self.data_obj.get_core_dimensions()
262
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
263
            plugin_dims += (self.get_slice_dimension(),)
264
        return list(set(plugin_dims)).index(label_dim)
265
266
    def set_slicing_order(self, order):  # should this function be deleted?
267
        """
268
        Reorder the slice dimensions.  The fastest changing slice dimension
269
        will always be the first one stated in the pattern key ``slice_dir``.
270
        The input param is a tuple stating the desired order of slicing
271
        dimensions relative to the current order.
272
        """
273
        slice_dirs = self.data_obj.get_slice_dimensions()
274
        if len(slice_dirs) < len(order):
275
            raise Exception("Incorrect number of dimensions specifed.")
276
        ordered = [slice_dirs[o] for o in order]
277
        remaining = [s for s in slice_dirs if s not in ordered]
278
        new_slice_dirs = tuple(ordered + remaining)
279
        self.get_pattern()['slice_dir'] = new_slice_dirs
280
281
    def get_core_dimensions(self):
282
        """
283
        Return the position of the core dimensions in relation to the data
284
        handed to the plugin.
285
        """
286
        core_dims = self.data_obj.get_core_dimensions()
287
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0],)
288
        plugin_dims = np.sort(core_dims + first_slice_dim)
289
        return np.searchsorted(plugin_dims, np.sort(core_dims))
290
291
    def set_fixed_dimensions(self, dims, values):
292
        """ Fix a data direction to the index in values list.
293
294
        :param list(int) dims: Directions to fix
295
        :param list(int) value: Index of fixed directions
296
        """
297
        slice_dirs = self.data_obj.get_slice_dimensions()
298
        if set(dims).difference(set(slice_dirs)):
299
            raise Exception("You are trying to fix a direction that is not"
300
                            " a slicing direction")
301
        self.meta_data.set("fixed_dimensions", dims)
302
        self.meta_data.set("fixed_dimensions_values", values)
303
        self.__set_slice_dimensions()
304
        shape = list(self.data_obj.get_shape())
305
        for dim in dims:
306
            shape[dim] = 1
307
        self.data_obj.set_shape(tuple(shape))
308
        #self.__set_shape()
309
310
    def _get_fixed_dimensions(self):
311
        """ Get the fixed data directions and their indices
312
313
        :returns: Fixed directions and their associated values
314
        :rtype: list(list(int), list(int))
315
        """
316
        fixed = []
317
        values = []
318
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
319
            fixed = self.meta_data.get("fixed_dimensions")
320
            values = self.meta_data.get("fixed_dimensions_values")
321
        return [fixed, values]
322
323
    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.

        :param plist: slice list for the plugin data.
        :returns: ``plist`` as a tuple, with ``slice(None)`` inserted at each
            position listed in ``extra_dims``.
        :rtype: tuple
        """
        nDims = len(self.get_shape())
        # NOTE(review): get_core_dimensions() returns a numpy array and
        # get_slice_dimension() an int, so '+' adds the int to every element
        # instead of concatenating dimension lists.  all_dims therefore has
        # len(core dims) entries, which does not exceed nDims for a shape
        # built by __set_shape, so extra_dims looks always empty and plist is
        # returned unchanged.  Confirm the intended dimension list.
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)
334
335
    def _get_max_frames_process(self):
336
        """ Get the number of frames to process for each run of process_frames.
337
338
        If the number of frames is not divisible by the previewing ``chunk``
339
        value then amend the number of frames to gcd(frames, chunk)
340
341
        :returns: Number of frames to process
342
        :rtype: int
343
        """
344
        if self._plugin and self._plugin.chunk > 1:
345
            frame_chunk = self.meta_data.get("max_frames_process")
346
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
347
                key='chunks')[self.get_slice_directions()[0]]
348
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
349
        return self.meta_data.get("max_frames_process")
350
351
    def _get_max_frames_transfer(self):
352
        """ Get the number of frames to transfer for each run of
353
        process_frames. """
354
        return self.meta_data.get('max_frames_transfer')
355
356
    def _set_no_squeeze(self):
357
        self.no_squeeze = True
358
359
    def _get_no_squeeze(self):
360
        return self.no_squeeze
361
    
362
    def _set_rank_inc(self, n):
363
        """ Increase the rank of the array passed to the plugin by n.
364
        
365
        :param int n: Rank increment.
366
        """
367
        self._increase_rank = n
368
    
369
    def _get_rank_inc(self):
370
        """ Return the increased rank value
371
        
372
        :returns: Rank increment
373
        :rtype: int
374
        """
375
        return self._increase_rank
376
377
    def _set_meta_data(self):
378
        fixed, _ = self._get_fixed_dimensions()
379
        sdir = \
380
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
381
        shape = self.data_obj.get_shape()
382
        shape_before_tuning = self._get_shape_before_tuning()
383
384
        diff = len(shape) - len(shape_before_tuning)
385
        if diff:
386
            shape = shape_before_tuning
387
            sdir = sdir[:-diff]
388
389
        if 'fix_total_frames' in list(self.meta_data.get_dictionary().keys()):
390
            frames = self.meta_data.get('fix_total_frames')
391
        else:
392
            frames = np.prod([shape[d] for d in sdir])
393
394
        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
395
        processes = self.data_obj.exp.meta_data.get('processes')
396
397
        if 'GpuPlugin' in base_names:
398
            n_procs = len([n for n in processes if 'GPU' in n])
399
        else:
400
            n_procs = len(processes)
401
402
        # Fixing f_per_p to be just the first slice dimension for now due to
403
        # slow performance from HDF5 when not slicing multiple dimensions
404
        # concurrently
405
        #f_per_p = np.ceil(frames/n_procs)
406
        f_per_p = np.ceil(shape[sdir[0]]/n_procs)
407
        self.meta_data.set('shape', shape)
408
        self.meta_data.set('sdir', sdir)
409
        self.meta_data.set('total_frames', frames)
410
        self.meta_data.set('mpi_procs', n_procs)
411
        self.meta_data.set('frames_per_process', f_per_p)
412
        frame_shape, b_per_f = self.set_bytes_per_frame()
413
        self.meta_data.set('bytes_per_frame', b_per_f)
414
        self.meta_data.set('bytes_per_process', b_per_f*f_per_p)
415
        self.meta_data.set('frame_shape', frame_shape)
416
417
    def __log_max_frames(self, mft, mfp, check=True):
418
        logging.debug("Setting max frames transfer for plugin %s to %d" %
419
                      (self._plugin, mft))
420
        logging.debug("Setting max frames process for plugin %s to %d" %
421
                      (self._plugin, mfp))
422
        self.meta_data.set('max_frames_process', mfp)
423
        if check:
424
            self.__check_distribution(mft)
425
        # (((total_frames/mft)/mpi_procs) % 1)
426
427
    def __check_distribution(self, mft):
428
        warn_threshold = 0.85
429
        nprocs = self.meta_data.get('mpi_procs')
430
        nframes = self.meta_data.get('total_frames')
431
        temp = (((nframes/mft)/float(nprocs)) % 1)
432
        if temp != 0.0 and temp < warn_threshold:
433
            shape = self.meta_data.get('shape')
434
            sdir = self.meta_data.get('sdir')
435
            logging.warning('UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
436
                         'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)
437
438
    def _set_padding_dict(self):
439
        if self.padding and not isinstance(self.padding, Padding):
440
            self.pad_dict = copy.deepcopy(self.padding)
441
            self.padding = Padding(self)
442
            for key in list(self.pad_dict.keys()):
443
                getattr(self.padding, key)(self.pad_dict[key])
444
445
    def plugin_data_setup(self, pattern, nFrames, split=None, slice_axis=None,
446
                          getall=None, fixed_length=True):
447
        """ Setup the PluginData object.
448
449
        :param str pattern: A pattern name
450
        :param int nFrames: How many frames to process at a time.  Choose from
451
         'single', 'multiple', 'fixed_multiple' or an integer (an integer
452
         should only ever be passed in exceptional circumstances)
453
        :keyword str slice_axis: An axis label associated with the fastest
454
         changing (first) slice dimension.
455
        :keyword list[pattern, axis_label] getall: A list of two values.  If
456
         the requested pattern doesn't exist then use all of "axis_label"
457
         dimension of "pattern" as this is equivalent to one slice of the
458
         original pattern.
459
        :keyword fixed_length: Data passed to the plugin is automatically
460
         padded to ensure all plugin data has the same dimensions. Set this
461
         value to False to turn this off.
462
463
        """
464
465
        if pattern not in self.data_obj.get_data_patterns() and getall:
466
            pattern, nFrames = self.__set_getall_pattern(getall, nFrames)
467
468
        # slice_axis is first slice dimension
469
        self.__set_pattern(pattern, first_sdim=slice_axis)
470
        if isinstance(nFrames, list):
471
            nFrames, self._frame_limit = nFrames
472
        self.max_frames = nFrames
473
        self.split = split
474
        if not fixed_length:
475
            self._plugin.fixed_length = fixed_length
476
477
    def __set_getall_pattern(self, getall, nFrames):
478
        """ Set framework changes required to get all of a pattern of lower
479
        rank.
480
        """
481
        pattern, slice_axis = getall
482
        dim = self.data_obj.get_data_dimension_by_axis_label(slice_axis)
483
        # ensure data remains the same shape when 'getall' dim has length 1
484
        self._set_no_squeeze()
485
        if nFrames == 'multiple' or (isinstance(nFrames, int) and nFrames > 1):
486
            self._set_rank_inc(1)
487
        nFrames = self.data_obj.get_shape()[dim]
488
        return pattern, nFrames
489
490
    def plugin_data_transfer_setup(self, copy=None, calc=None):
491
        """ Set up the plugin data transfer frame parameters.
492
        If copy=pData (another PluginData instance) then copy """
493
        chunks = \
494
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
495
        if not copy and not calc:
496
            mft, mft_shape, mfp = self._calculate_max_frames()
497
        elif calc:
498
            max_mft = calc.meta_data.get('max_frames_transfer')
499
            max_mfp = calc.meta_data.get('max_frames_process')
500
            max_nProc = int(np.ceil(max_mft/float(max_mfp)))
501
            nProc = max_nProc
502
            mfp = 1 if self.max_frames == 'single' else self.max_frames
503
            mft = nProc*mfp
504
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
505
        elif copy:
506
            mft = copy._get_max_frames_transfer()
507
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
508
            mfp = copy._get_max_frames_process()
509
510
        self.__set_max_frames(mft, mft_shape, mfp)
0 ignored issues
show
introduced by
The variable mfp does not seem to be defined for all execution paths.
Loading history...
introduced by
The variable mft does not seem to be defined for all execution paths.
Loading history...
introduced by
The variable mft_shape does not seem to be defined for all execution paths.
Loading history...
511
512
        if self._plugin and mft \
513
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
514
            self._plugin.chunk = True
515
        self.__set_shape()
516
517
    def _calculate_max_frames(self):
518
        nFrames = self.max_frames
519
        self.__perform_checks(nFrames)
520
        td = self.data_obj._get_transport_data()
521
        mft, size_list = td._calc_max_frames_transfer(nFrames)
522
        self.meta_data.set('size_list', size_list)
523
        mfp = td._calc_max_frames_process(nFrames)
524
        if mft:
525
            mft_shape = self._set_shape_transfer(list(size_list))
526
        return mft, mft_shape, mfp
0 ignored issues
show
introduced by
The variable mft_shape does not seem to be defined in case mft on line 524 is False. Are you sure this can never be the case?
Loading history...
527
528
    def __set_max_frames(self, mft, mft_shape, mfp):
529
        self.meta_data.set('max_frames_transfer', mft)
530
        self.meta_data.set('transfer_shape', mft_shape)
531
        self.meta_data.set('max_frames_process', mfp)
532
        self.__log_max_frames(mft, mfp)
533
        # Retain the shape if the first slice dimension has length 1
534
        if mfp == 1 and self.max_frames == 'multiple':
535
            self._set_no_squeeze()
536
537
    def _get_plugin_data_size_params(self):
538
        nBytes = self.data_obj.get_itemsize()
539
        frame_shape = self.meta_data.get('frame_shape')
540
        total_frames = self.meta_data.get('total_frames')
541
        tbytes = nBytes*np.prod(frame_shape)*total_frames
542
543
        params = {'nBytes': nBytes, 'frame_shape': frame_shape,
544
                  'total_frames': total_frames, 'transfer_bytes': tbytes}
545
        return params
546
547
    def __perform_checks(self, nFrames):
548
        options = ['single', 'multiple']
549
        if not np.issubdtype(type(nFrames), np.int64) and nFrames not in options:
550
            e_str = ("The value of nFrames is not recognised.  Please choose "
551
            + "from 'single' and 'multiple' (or an integer in exceptional "
552
            + "circumstances).")
553
            raise Exception(e_str)
554
555
    def get_frame_limit(self):
556
        return self._frame_limit
557
558
    def get_current_frame_idx(self):
559
        """ Returns the index of the frames currently being processed.
560
        """
561
        global_index = self._plugin.get_global_frame_index()
562
        count = self._plugin.get_process_frames_counter()
563
        mfp = self.meta_data.get('max_frames_process')
564
        start = global_index[count]*mfp
565
        index = np.arange(start, start + mfp)
566
        nFrames = self.get_total_frames()
567
        index[index >= nFrames] = nFrames - 1
568
        return index
569