GlobalData._get_dict()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: slice_lists
17
   :platform: Unix
18
   :synopsis: Contains classes for creating global and local slice lists
19
20
.. moduleauthor:: Nicola Wadeson <[email protected]>
21
22
"""
23
24
import numpy as np
25
26
27
class SliceLists(object):
    """
    Base class creating the global and local slice lists used to transfer
    the data.

    The methods here read ``self.data`` and call ``self._banked_list``;
    neither is defined in this class, so both must be supplied by the
    subclass.
    """

    def __init__(self, name='SliceLists'):
        """ Initialise the padding flag and transfer data placeholder.

        :param str name: accepted for interface compatibility; not used here.
        """
        super(SliceLists, self).__init__()
        self.pad = False
        self.transfer_data = None

    def _get_process_data(self):
        """ Return ``self.process_data``. """
        # NOTE(review): ``process_data`` is not assigned anywhere in this
        # class - presumably set by an external owner; confirm.
        return self.process_data

    def _single_slice_list(self, nSlices, nDims, core_slice, core_dirs,
                           slice_dirs, fix, index):
        """ Build one tuple of slice objects per frame.

        :param int nSlices: number of entries to create.
        :param int nDims: number of dimensions in each entry.
        :param core_slice: array of slice objects, one per core dimension.
        :param core_dirs: array of core dimension indices.
        :param slice_dirs: sequence of slice dimension indices.
        :param fix: pair ``(fix_dirs, values)`` of fixed dimensions and the
            index each one is fixed at.
        :param index: 2D array; ``index[s, i]`` is the position of frame
            ``i`` along the ``s``-th slice dimension.
        :returns: list of ``nSlices`` tuples, each holding ``nDims`` slices.
        """
        fix_dirs, value = fix
        slice_list = []
        for i in range(nSlices):
            # an object array allows several dimensions to be set at once
            getitem = np.array([slice(None)]*nDims)
            getitem[core_dirs] = core_slice[np.arange(len(core_dirs))]
            for fdir, fval in zip(fix_dirs, value):
                getitem[fdir] = slice(fval, fval + 1, 1)
            for sdir, sdim in enumerate(slice_dirs):
                getitem[sdim] = slice(index[sdir, i], index[sdir, i] + 1, 1)
            slice_list.append(tuple(getitem))
        return slice_list

    def _get_slice_dirs_index(self, slice_dirs, shape, value, extra=None,
                              calc=None):
        """
        returns a list of arrays for each slice dimension, where each array
        gives the indices for that slice dimension.

        :param slice_dirs: sequence of slice dimension indices.
        :param shape: shape of the data being sliced.
        :param str value: Python expression, evaluated once per slice
            dimension, giving the index values for that dimension.  It is
            evaluated in this method's scope, so it may refer to ``self``,
            ``slice_dirs``, ``shape`` and the loop counter ``i``.
        :param extra: unused.
        :param calc: unused.
        :returns: array with one row of integer indices per slice dimension
            (an empty array when there are no slice dimensions).
        """
        # create the indexing array
        chunk, length, repeat = self._chunk_length_repeat(slice_dirs, shape)
        idx_list = []
        for i in range(len(slice_dirs)):
            c = chunk[i]
            r = repeat[i]
            # SECURITY NOTE: eval of a caller supplied string.  The callers
            # in this file pass fixed literals; never pass external input.
            values = eval(value)
            # repeat each value c times, then the whole sequence r times
            idx = np.ravel(np.kron(values, np.ones((r, c))))
            idx_list.append(idx.astype(int))
        return np.array(idx_list)

    def _chunk_length_repeat(self, slice_dirs, shape):
        """
        For each slice dimension, determine 3 values relevant to the slicing.

        :returns: chunk, length, repeat
            chunk: how many repeats of the same index value before an increment
            length: the slice dimension length (sequence length)
            repeat: how many times does the sequence of chunked numbers repeat
        :rtype: [int, int, int]
        """
        sshape = self.__get_shape_of_slice_dirs(slice_dirs, shape)
        if not slice_dirs:
            return [1], [1], [1]

        chunk = []
        length = []
        repeat = []
        for dim in range(len(slice_dirs)):
            chunk.append(int(np.prod(sshape[0:dim])))
            length.append(sshape[dim])
            repeat.append(int(np.prod(sshape[dim+1:])))
        return chunk, length, repeat

    def __get_shape_of_slice_dirs(self, slice_dirs, shape):
        """ Return the length of each slice dimension.

        When any slice dimension has the placeholder length ``'var'``, every
        string entry in ``shape`` is replaced by the length of the matching
        axis label entry before the slice dimension lengths are read.
        """
        sshape = [shape[sslice] for sslice in slice_dirs]
        if 'var' in sshape:
            shape = list(shape)
            for index, value in enumerate(shape):
                if isinstance(value, str):
                    # NOTE(review): reads ``self.data_info`` whereas the rest
                    # of the class uses ``self.data.data_info`` - confirm
                    # which attribute is intended.
                    shape[index] = \
                        len(self.data_info.get('axis_labels')[index])
            shape = tuple(shape)
            sshape = [shape[sslice] for sslice in slice_dirs]
        return sshape

    def _get_core_slices(self, core_dirs):
        """ Create the slice object for each core dimension from the preview.

        :param core_dirs: iterable of core dimension indices.
        :returns: array of slice objects, one per core dimension.
        :raises Exception: if a chunked core dimension spans more than one
            index, or if the chunk would start at a negative index.
        """
        core_slice = []
        starts, stops, steps, chunks = \
            self.data.get_preview().get_starts_stops_steps()

        for c in core_dirs:
            if not chunks[c] > 1:
                core_slice.append(slice(starts[c], stops[c], steps[c]))
                continue

            if stops[c] - starts[c] != 1:
                raise Exception("The core dimension does not support "
                                "multiple chunks.")

            # place the chunk around the single selected index
            half = int(chunks[c] / 2.0)
            start = starts[c] - half
            if start < 0:
                raise Exception('Cannot have a negative value in the '
                                'slice list.')
            stop = starts[c] + (chunks[c] - half)
            core_slice.append(slice(start, stop, 1))
        return np.array(core_slice)

    def _group_dimension(self, sl, dim, step):
        """ Merge a run of entries into one entry spanning dimension ``dim``.

        The result is the first entry of ``sl`` with dimension ``dim``
        replaced by a slice running from the first entry's start to the last
        entry's stop, with the given step.
        """
        merged = list(sl[0])
        merged[dim] = slice(sl[0][dim].start, sl[-1][dim].stop, step)
        return tuple(merged)

    def _split_list(self, the_list, size):
        """ Split ``the_list`` into consecutive pieces of at most ``size``. """
        return [the_list[x:x+size] for x in range(0, len(the_list), size)]

    # This method only works if the split dimensions in the slice list contain
    # slice objects
    def __split_frames(self, slice_list, split_list):
        """ Replace every slice list entry with several smaller entries.

        :param slice_list: sequence of tuples of slice objects.
        :param split_list: strings of the form ``'<dim>.<length>'`` giving
            the dimension to split and the length of each piece.
        :returns: tuple of tuples of slice objects.
        """
        split = [list(map(int, a.split('.'))) for a in split_list]
        dims = [s[0] for s in split]
        length = [s[1] for s in split]
        replace = self.__get_split_frame_entries(slice_list, dims, length)
        nEntries = len(replace[0])

        # now replace each slice list entry with multiple entries
        array_list = []
        for sl in slice_list:
            new_list = np.array([sl for _ in range(nEntries)])
            for i, d in enumerate(dims):
                new_list[:, d] = replace[i]
            array_list += [tuple(a) for a in new_list]

        return tuple(array_list)

    def __get_split_frame_entries(self, slice_list, dims, length):
        """ Create the replacement slices for each split dimension.

        The pieces are derived from the first slice list entry only, then
        expanded so that position ``k`` in every returned list belongs to the
        same combination of pieces.

        :returns: one list of slice objects per entry in ``dims``.
        """
        # Fix: this previously read ``self.get_shape`` (never called, and not
        # an attribute of any class in this module).
        shape = self.data.get_shape()
        replace = []
        seq_len = []

        # get the new entries
        for d, l in zip(dims, length):
            sl = slice_list[0][d]
            start = 0 if sl.start is None else sl.start
            stop = shape[d] if sl.stop is None else sl.stop
            inc = l*sl.step if sl.step else l
            temp_list = [slice(a, a+inc) for a in np.arange(start, stop, inc)]
            # trim the final piece so it does not run past the end
            if temp_list[-1].stop > stop:
                temp = temp_list[-1]
                temp_list[-1] = slice(temp.start, stop, temp.step)
            replace.append(temp_list)
            seq_len.append(len(temp_list))

        # calculate the permutations
        length = np.array(seq_len)
        chunk = [int(np.prod(length[0:dim])) for dim in range(len(dims))]
        repeat = [int(np.prod(length[dim+1:])) for dim in range(len(dims))]
        full_replace = []
        for d in range(len(dims)):
            temp = [[replace[d][i]]*chunk[d] for x in range(repeat[d]) for i
                    in range(len(replace[d]))]
            full_replace.append([t for sub in temp for t in sub])
        return full_replace

    def _get_frames_per_process(self, slice_list):
        """ Select the part of the slice list belonging to this process.

        The frame indices are divided into ``len(processes)`` nearly equal
        parts and the part at position ``process`` is selected.

        :returns: ``(slice_list, frames)`` - the selected entries and their
            indices into the full list.  Both are empty when this process
            has no frames.
        """
        processes = self.data.exp.meta_data.get("processes")
        process = self.data.exp.meta_data.get("process")
        frame_idx = np.arange(len(slice_list))
        # Fix: defined up front so it is always bound at the return, even
        # when the lookup of this process's share raises IndexError.
        frames = np.array([], dtype=int)
        try:
            frames = np.array_split(frame_idx, len(processes))[process]
            slice_list = slice_list[frames[0]:frames[-1]+1]
        except IndexError:
            slice_list = []
        return slice_list, frames

    def _pad_slice_list(self, slice_list, inc_start_str: str,
                        inc_stop_str: str):
        """ Amend the slice lists to include padding.  Includes variations for
        transfer and process slice lists.

        :param slice_list: sequence of tuples of slice objects.  A list is
            modified in place; any other sequence is copied to a list first.
        :param str inc_start_str: Python expression giving the amount added
            to each slice start.  It is evaluated with ``value`` bound to the
            padding entry of the dimension being processed.
        :param str inc_stop_str: as above, for the slice stop.
        :returns: the padded slice list (the input object itself when the
            plugin data has no padding).
        """
        pData = self.data._get_plugin_data()
        if not pData.padding:
            return slice_list

        # Entries are assigned by index below, which a tuple (as returned by
        # the frame splitting) does not support.
        if not isinstance(slice_list, list):
            slice_list = list(slice_list)

        pad_dict = pData.padding._get_padding_directions()

        shape = self.data.get_shape()
        for ddir, value in pad_dict.items():
            # SECURITY NOTE: eval of caller supplied strings.  The callers in
            # this file pass fixed literals; never pass external input.
            inc_start = eval(inc_start_str)
            inc_stop = eval(inc_stop_str)
            for i, entry in enumerate(slice_list):
                entry = list(entry)
                sl = entry[ddir]
                if sl.start is None:
                    # an open slice covers the whole dimension
                    sl = slice(0, shape[ddir], 1)
                entry[ddir] = \
                    slice(sl.start + inc_start, sl.stop + inc_stop, sl.step)
                slice_list[i] = tuple(entry)
        return slice_list

    def _fix_list_length(self, sl, pad):
        """ Extend the stop of every slice in ``sl`` by ``steps[i]*pad[i]``.

        :param sl: sequence of slice objects (one slice list entry).
        :param pad: per dimension multiplier applied to the step.
        :returns: tuple of the extended slices.
        """
        steps = self.data.data_info.get("steps")
        return tuple(slice(s.start, s.stop + steps[i]*pad[i], s.step)
                     for i, s in enumerate(sl))

    def _get_local_single_slice_list(self, shape):
        """ Create the single frame slice list for data of the given shape.

        Core dimensions are left open (``slice(None)``) and no dimensions
        are fixed.
        """
        slice_dirs = self.data.get_slice_dimensions()
        core_dirs = np.array(self.data.get_core_dimensions())
        fix = [[]]*2
        core_slice = np.array([slice(None)]*len(core_dirs))
        shape = tuple([shape[i] for i in range(len(shape))])
        values = 'np.arange(shape[slice_dirs[i]])'
        index = self._get_slice_dirs_index(slice_dirs, shape, values)
        # there may be no slice dirs
        if not index.size:
            index = np.array([[0]])
        nSlices = index.shape[1]
        nDims = len(shape)

        return self._single_slice_list(
            nSlices, nDims, core_slice, core_dirs, slice_dirs, fix, index)

    def _group_slice_list_in_one_dimension(self, slice_list, max_frames,
                                           group_dim, pad=False):
        """ Group the slice list in one dimension, stopping at \
        boundaries - prepare a slice list for multi-frame plugin processing.

        :returns: the grouped list, or ``slice_list`` unchanged when
            ``group_dim`` is None.
        """
        if group_dim is None:
            return slice_list

        banked = self._banked_list(slice_list, max_frames, pad=pad)
        grouped = []
        for group in banked:
            for sub in self._split_list(group, max_frames):
                grouped.append(self._group_dimension(sub, group_dim, 1))
        return grouped

    def _group_slice_list_in_multiple_dimensions(self, slice_list, max_frames,
                                                 group_dim, pad=False):
        """ Group the slice list in multiple dimensions - prepare a slice list\
        for file transfer.

        :param group_dim: iterable of the dimensions to group, or None to
            return ``slice_list`` unchanged.
        """
        if group_dim is None:
            return slice_list

        steps = self.data.get_preview().get_starts_stops_steps('steps')
        sub_groups = self._banked_list(slice_list, max_frames, pad=pad)

        grouped = []
        for sub in sub_groups:
            temp = list(sub[0])
            for dim in group_dim:
                temp[dim] = self._group_dimension(sub, dim, steps[dim])[dim]
            grouped.append(tuple(temp))

        return grouped
279
280
281
class LocalData(SliceLists):
    """ The LocalData class organises the slicing of transferred data to \
    give the shape requested by a plugin for each run of 'process_frames'.
    """

    def __init__(self, dtype, transport_data):
        """
        :param str dtype: 'in' selects the input dictionary in ``_get_dict``;
            any other value selects the output dictionary.
        :param transport_data: object exposing the data being sliced as its
            ``data`` attribute.
        """
        self.dtype = dtype  # in or out ProcessData object
        self.td = transport_data
        self.data = transport_data.data
        self.pData = self.data._get_plugin_data()
        self.shape = self.data.get_shape()
        self.sdir = None  # first slice dimension; set by _get_slice_list

    def _get_dict(self):
        """ Return the slice list dictionary matching ``self.dtype``. """
        if self.dtype == 'in':
            return self._get_dict_in()
        return self._get_dict_out()

    def _get_dict_in(self):
        """ Return ``{'process': <padded process slice list>}``.

        Each slice start is left unchanged and each stop is extended by the
        sum of the padding values for that dimension.
        """
        sl = self._get_slice_list()
        sl = self._pad_slice_list(sl, '0', 'sum(value.values())')
        return {'process': sl}

    def _get_dict_out(self):
        """ Return the 'process' slice list together with an 'unpad' list
        holding one entry per process slice list entry.
        """
        sl_dict = {}
        sl_dict['process'] = self._get_slice_list()
        sl_dict['unpad'] = self.__get_unpad_slice_list(len(sl_dict['process']))
        return sl_dict

    def _get_slice_list(self):
        """ Splits a file transfer slice list into a list of (padded) slices
        required for each loop of process_frames.
        """
        slice_dirs = self.data.get_slice_dimensions()
        # only the first slice dimension (if any) is grouped
        self.sdir = slice_dirs[0] if len(slice_dirs) > 0 else None

        pData = self.pData
        mf_process = pData.meta_data.get('max_frames_process')
        shape = pData.get_shape_transfer()
        process_ssl = self._get_local_single_slice_list(shape)

        return self._group_slice_list_in_one_dimension(
            process_ssl, mf_process, self.sdir)

    def _banked_list(self, slice_list, max_frames, pad=False):
        """ Split the slice list into consecutive banks of at most
        ``max_frames`` entries.

        :param pad: unused; kept for a signature matching GlobalData.
        """
        # The shape, slice dimensions and chunk/length/repeat values were
        # previously computed here but never used, so they were removed.
        return self._split_list(slice_list, max_frames)

    def __get_unpad_slice_list(self, reps):
        """ Create the slices that remove the padding from processed data.

        :param int reps: number of (identical) entries to return.
        :returns: tuple of ``reps`` tuples of slice objects.
        """
        # setting process slice list unpad here - not currently working for
        # 4D data
        sl = [slice(None)]*len(self.pData.get_shape_transfer())
        if self.pData.padding:
            pad_dict = self.pData.padding._get_padding_directions()
            for ddir, value in pad_dict.items():
                after = value['after']
                # Fix: ``-0`` is ``0``, so with no padding after the data
                # ``slice(before, -after)`` selected nothing.  Use an open
                # stop in that case.
                sl[ddir] = slice(value['before'], -after if after else None)
        return tuple([tuple(sl)]*reps)
342
343
344
class GlobalData(SliceLists):
    """
    The GlobalData class organises the movement and slicing of the data from
    file.
    """

    def __init__(self, dtype, transport):
        """
        :param str dtype: 'in' selects the input dictionary in ``_get_dict``;
            any other value selects the output dictionary.
        :param transport: object exposing the data being sliced as its
            ``data`` attribute; its ``pad`` attribute and
            ``_get_slice_dir_matrix`` method are also used.
        """
        self.dtype = dtype  # in or out TransferData object
        self.trans = transport
        self.data = transport.data
        self.pData = self.data._get_plugin_data()
        self.shape = self.data.get_shape()

    def _get_dict(self, pad):
        """ Return the slice list dictionary matching ``self.dtype``.

        :param pad: passed on to ``_get_dict_in``; ignored for output data.
        """
        if self.dtype == 'in':
            return self._get_dict_in(pad)
        return self._get_dict_out()

    def _get_dict_in(self, pad):
        """ Create the input slice list dictionary for this process.

        :returns: dict with keys 'current', 'frames' and 'transfer'.  The
            'transfer' list is padded when ``self.trans.pad`` is set.
        """
        sl_dict = {}
        sl, current = \
            self._get_slice_list(self.shape, current_sl=True, pad=pad)

        sl_dict['current'], _ = self._get_frames_per_process(current)
        sl, sl_dict['frames'] = self._get_frames_per_process(sl)
        if self.trans.pad:
            sl = self._pad_slice_list(
                sl, "-value['before']", "value['after']")
        sl_dict['transfer'] = sl
        return sl_dict

    def _get_dict_out(self):
        """ Return ``{'transfer': <slice list for this process>}``. """
        sl, _ = self._get_slice_list(self.shape)
        transfer, _ = self._get_frames_per_process(sl)
        return {'transfer': transfer}

    def _banked_list(self, slice_list, max_frames, pad=False):
        """ Split the slice list into banks of at most ``max_frames``
        entries, never letting a bank cross a boundary of the split length
        returned by ``_get_split_length``.

        :param pad: falsy, or a per dimension sequence used to extend the
            last entry of each boundary group via ``_fix_list_length``.
        :returns: list of banks (each a piece of ``slice_list``).
        """
        shape = self.data.get_shape()
        slice_dirs = self.data.get_slice_dimensions()
        sdir_shape = [shape[i] for i in slice_dirs]
        split, _ = self._get_split_length(max_frames, sdir_shape)
        # split at the boundaries
        split_list = self._split_list(slice_list, split)

        banked = []
        for s in split_list:
            # split at max_frames
            b = self._split_list(s, max_frames)
            banked.extend(b)
            if pad and any(pad):
                # b[-1] is the same list object that was just added to
                # banked, so this also updates the returned result
                b[-1][-1] = self._fix_list_length(b[-1][-1], pad)

        return banked

    def _get_split_length(self, max_frames, sdir_shape):
        """ Find the boundary length at which the slice list is split.

        The lengths of the slice dimensions are multiplied together, one
        dimension at a time, until the product reaches ``max_frames`` (or the
        dimensions run out).

        :param max_frames: maximum number of frames in a bank.
        :param sdir_shape: length of each slice dimension.
        :returns: ``(product, dimension)`` - the product reached and the
            slice dimension at which the accumulation stopped.  Returns
            ``(1, None)`` when there are no slice dimensions.
        """
        # Fix: with no slice dimensions the loop below never runs, which
        # previously left ``prod`` undefined at the return.
        if len(sdir_shape) == 0:
            return 1, None

        nDims = 0
        while nDims < len(sdir_shape):
            nDims += 1
            prod = np.prod([sdir_shape[i] for i in range(nDims)])
            if prod/float(max_frames) >= 1:
                break
        sdir = self.data.get_slice_dimensions()
        return prod, sdir[nDims-1]

    def _get_padded_shape(self, orig_shape):
        """
        Get the (fake) shape of the data if it was exactly divisible by mft.

        :returns: list with, per dimension, the amount to add to the
            original length.
        """
        trans_shape = self.pData.meta_data.get("transfer_shape")
        pad = []
        for i, shape in enumerate(orig_shape):
            # amount needed to reach the next multiple of the transfer length
            mod = (trans_shape[i] - shape % trans_shape[i]) % trans_shape[i]
            diff = trans_shape[i] - shape
            pad.append(max(diff, mod))
        return pad

    def _get_global_single_slice_list(self, shape):
        """ Create the single frame slice list for the whole dataset, using
        the core slices from the preview and the plugin's fixed dimensions.
        """
        slice_dirs = self.data.get_slice_dimensions()
        core_dirs = np.array(self.data.get_core_dimensions())
        fix = self.data._get_plugin_data()._get_fixed_dimensions()
        core_slice = self._get_core_slices(core_dirs)
        values = 'self._get_slice_dir_index(slice_dirs[i])'
        index = self._get_slice_dirs_index(slice_dirs, shape, values)
        nSlices = index.shape[1] if index.size else len(fix[0])
        nDims = len(shape)
        return self._single_slice_list(
            nSlices, nDims, core_slice, core_dirs, slice_dirs, fix, index)

    def _get_slice_dir_index(self, dim, boolean=False):
        """ Return the index values to visit along slice dimension ``dim``.

        For a chunked dimension the values come from the transport's slice
        direction matrix; for a fixed dimension the fixed value is returned;
        otherwise the preview's start/stop/step range is used.
        """
        starts, stops, steps, chunks = \
            self.data.get_preview().get_starts_stops_steps()
        if chunks[dim] > 1:
            dir_idx = np.ravel(np.transpose(
                self.trans._get_slice_dir_matrix(dim)))
            if boolean:
                # NOTE(review): no ``__get_bool_slice_dir_index`` method is
                # defined in the visible part of this class - confirm it
                # exists before calling with boolean=True.
                return self.__get_bool_slice_dir_index(dim, dir_idx)
            return dir_idx

        fix_dirs, value = \
            self.data._get_plugin_data()._get_fixed_dimensions()
        if dim in fix_dirs:
            return value[fix_dirs.index(dim)]
        return np.arange(starts[dim], stops[dim], steps[dim])

    def _get_slice_list(self, shape, current_sl=None, pad=False):
        """ Create the grouped slice list used to transfer the data.

        :param shape: shape of the data.
        :param current_sl: if truthy, also create a list grouped by the
            maximum number of frames per process.
        :param pad: if truthy, group using the padded shape.
        :returns: ``(transfer_gsl, current_sl)``; the second item is the
            ``current_sl`` argument unchanged when it is falsy.
        """
        mft = self.pData._get_max_frames_transfer()
        pad = self._get_padded_shape(shape) if pad else False
        transfer_ssl = self._get_global_single_slice_list(shape)

        if transfer_ssl is None:
            # NOTE(review): neither ``get_current_pattern_name`` nor
            # ``get_slice_directions`` is defined in the visible classes -
            # confirm they exist, otherwise building this message fails.
            raise Exception("Data type %s does not support slicing in "
                            "directions %s" % (self.get_current_pattern_name(),
                                               self.get_slice_directions()))
        slice_dims = self.data.get_slice_dimensions()

        transfer_gsl = self._group_slice_list_in_multiple_dimensions(
            transfer_ssl, mft, slice_dims, pad=pad)

        if current_sl:
            mfp = self.pData._get_max_frames_process()
            current_sl = self._group_slice_list_in_multiple_dimensions(
                transfer_ssl, mfp, slice_dims, pad=pad)

        split_list = self.pData.split
        if split_list:
            # Fix: ``self.__split_frames`` written inside this class is
            # mangled to ``_GlobalData__split_frames``, which does not
            # exist.  The method is defined on SliceLists, so it has to be
            # reached through that class's mangled name.
            transfer_gsl = \
                self._SliceLists__split_frames(transfer_gsl, split_list)

        return transfer_gsl, current_sl

    def _get_padded_data(self, slice_list, end=False):
        """ Read the data for one slice list entry, padding any part that
        lies outside the dataset.

        Slices starting below zero or stopping beyond the data shape are
        clipped before reading and the missing amount is added afterwards
        with ``np.pad``.

        :param slice_list: sequence of slice objects, one per dimension.
        :param end: unused.
        :returns: the (possibly padded) data array.
        """
        slice_list = list(slice_list)
        pData = self.pData
        pad_dims = list(set(self.data.get_core_dimensions() +
                            (self.data.get_slice_dimensions())))
        pad_list = [[0, 0] for _ in range(len(slice_list))]

        data_dict = self.data.data_info.get_dictionary()
        shape = data_dict['orig_shape'] if 'orig_shape' in data_dict \
            else self.data.get_shape()

        # NOTE(review): this visits dimensions 0..len(pad_dims)-1 rather
        # than the values held in pad_dims - presumably equivalent because
        # the core and slice dimensions cover every dimension; confirm.
        for dim in range(len(pad_dims)):
            sl = slice_list[dim]
            if sl.start < 0:
                pad_list[dim][0] = -sl.start
                slice_list[dim] = slice(0, sl.stop, sl.step)
            diff = sl.stop - shape[dim]
            if diff > 0:
                pad_list[dim][1] = diff
                slice_list[dim] = \
                    slice(slice_list[dim].start, sl.stop - diff, sl.step)

        data = self.data.data[tuple(slice_list)]

        if np.sum(pad_list):
            mode = pData.padding.mode if pData.padding else 'edge'
            return np.pad(data, tuple(pad_list), mode=mode)
        return data
508