Test Failed
Pull Request — master (#708)
by Daniil
03:33
created

savu.data.transport_data.slice_lists   F

Complexity

Total Complexity 88

Size/Duplication

Total Lines 470
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 336
dl 0
loc 470
rs 2
c 0
b 0
f 0
wmc 88

32 Methods

Rating   Name   Duplication   Size   Complexity  
A SliceLists._get_local_single_slice_list() 0 16 3
A SliceLists._get_global_single_slice_list() 0 12 2
A SliceLists._get_frames_per_process() 0 10 2
A LocalData.__get_unpad_slice_list() 0 9 3
A GlobalData._get_dict() 0 3 2
A SliceLists._get_slice_dirs_index() 0 16 2
A SliceLists.__get_split_length() 0 8 3
A SliceLists.__init__() 0 4 1
A SliceLists._get_process_data() 0 2 1
A SliceLists.__get_shape_of_slice_dirs() 0 11 4
A SliceLists._split_list() 0 2 1
A GlobalData._get_dict_in() 0 13 2
A LocalData._get_dict() 0 3 2
A SliceLists._single_slice_list() 0 15 4
B SliceLists.__get_split_frame_entries() 0 28 7
B GlobalData._get_padded_data() 0 31 8
A SliceLists._fix_list_length() 0 6 1
A LocalData._get_slice_list() 0 16 2
A SliceLists._get_core_slices() 0 20 5
A SliceLists.__split_frames() 0 14 3
A SliceLists._banked_list() 0 18 4
A SliceLists.__chunk_length_repeat() 0 22 3
A LocalData._get_dict_in() 0 6 1
A LocalData.__init__() 0 7 1
A SliceLists._pad_slice_list() 0 22 5
A GlobalData._get_slice_list() 0 22 4
A SliceLists._group_dimension() 0 6 1
A GlobalData._get_dict_out() 0 5 1
A LocalData._get_dict_out() 0 5 1
A GlobalData.__init__() 0 6 1
A SliceLists._group_slice_list_in_multiple_dimensions() 0 19 4
A SliceLists._group_slice_list_in_one_dimension() 0 15 4

How to fix   Complexity   

Complexity

Complex classes like savu.data.transport_data.slice_lists often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: slice_lists
17
   :platform: Unix
18
   :synopsis: Contains classes for creating global and local slice lists
19
20
.. moduleauthor:: Nicola Wadeson <[email protected]>
21
22
"""
23
24
import numpy as np
25
26
27
class SliceLists(object):
28
    """
29
    SliceLists class creates global and local slices lists used to transfer
30
    the data
31
    """
32
33
    def __init__(self, name='SliceLists'):
34
        super(SliceLists, self).__init__()
35
        self.pad = False
36
        self.transfer_data = None
37
38
    def _get_process_data(self):
39
        return self.process_data
40
41
    def _single_slice_list(self, nSlices, nDims, core_slice, core_dirs,
42
                           slice_dirs, fix, index):
43
44
        fix_dirs, value = fix
45
        slice_list = []
46
        for i in range(nSlices):
47
            getitem = np.array([slice(None)]*nDims)
48
            getitem[core_dirs] = core_slice[np.arange(len(core_dirs))]
49
            for f in range(len(fix_dirs)):
50
                getitem[fix_dirs[f]] = slice(value[f], value[f] + 1, 1)
51
            for sdir in range(len(slice_dirs)):
52
                getitem[slice_dirs[sdir]] = slice(index[sdir, i],
53
                                                  index[sdir, i] + 1, 1)
54
            slice_list.append(tuple(getitem))
55
        return slice_list
56
57
    def _get_slice_dirs_index(self, slice_dirs, shape, value, calc=None):
        """
        Returns a 2D-style numpy array built from one integer index array per
        slice dimension, where each array gives the indices for that slice
        dimension.

        :param slice_dirs: sequence of slice dimensions
        :param shape: data shape, passed on to the chunk/length/repeat helper
        :param value: a string holding a Python expression that is passed to
            ``eval`` once per slice dimension.  The expression is evaluated
            in this method's scope, so it may refer to the local names
            ``self``, ``slice_dirs``, ``shape`` and ``i`` (callers in this
            file do exactly that) - do not rename these locals.
        :param calc: not used in this method
        :returns: ``np.array`` of the per-dimension index arrays (an empty
            array when there are no slice dimensions)
        """
        # create the indexing array
        chunk, length, repeat = self.__chunk_length_repeat(slice_dirs, shape)
        values = None
        idx_list = []
        for i in range(len(slice_dirs)):
            c = chunk[i]
            r = repeat[i]
            # NOTE(review): eval of a caller-supplied expression string; the
            # strings passed within this file are hard-coded literals.
            values = eval(value)
            # kron with ones((r, c)) repeats each value c times in a row and
            # lays out r copies of that sequence; ravel flattens to 1D.
            idx = np.ravel(np.kron(values, np.ones((r, c))))
            idx_list.append(idx.astype(int))
        return np.array(idx_list)
73
74
    def __chunk_length_repeat(self, slice_dirs, shape):
75
        """
76
        For each slice dimension, determine 3 values relevant to the slicing.
77
78
        :returns: chunk, length, repeat
79
            chunk: how many repeats of the same index value before an increment
80
            length: the slice dimension length (sequence length)
81
            repeat: how many times does the sequence of chunked numbers repeat
82
        :rtype: [int, int, int]
83
        """
84
        sshape = self.__get_shape_of_slice_dirs(slice_dirs, shape)
85
        if not slice_dirs:
86
            return [1], [1], [1]
87
88
        chunk = []
89
        length = []
90
        repeat = []
91
        for dim in range(len(slice_dirs)):
92
            chunk.append(int(np.prod(sshape[0:dim])))
93
            length.append(sshape[dim])
94
            repeat.append(int(np.prod(sshape[dim+1:])))
95
        return chunk, length, repeat
96
97
    def __get_shape_of_slice_dirs(self, slice_dirs, shape):
98
        sshape = [shape[sslice] for sslice in slice_dirs]
99
        if 'var' in sshape:
100
            shape = list(shape)
101
            for index, value in enumerate(shape):
102
                if isinstance(value, str):
103
                    shape[index] = \
104
                        len(self.data_info.get('axis_labels')[index])
105
            shape = tuple(shape)
106
            sshape = [shape[sslice] for sslice in slice_dirs]
107
        return sshape
108
109
    def _get_core_slices(self, core_dirs):
110
        core_slice = []
111
        starts, stops, steps, chunks = \
112
            self.data.get_preview().get_starts_stops_steps()
113
114
        for c in core_dirs:
115
            if (chunks[c]) > 1:
116
                if (stops[c] - starts[c] == 1):
117
                    start = starts[c] - int(chunks[c] / 2.0)
118
                    if start < 0:
119
                        raise Exception('Cannot have a negative value in the '
120
                                        'slice list.')
121
                    stop = starts[c] + (chunks[c] - int(chunks[c] / 2.0))
122
                    core_slice.append(slice(start, stop, 1))
123
                else:
124
                    raise Exception("The core dimension does not support "
125
                                    "multiple chunks.")
126
            else:
127
                core_slice.append(slice(starts[c], stops[c], steps[c]))
128
        return np.array(core_slice)
129
130
    def _banked_list(self, slice_list, max_frames, pad=False):
131
        shape = self.data.get_shape()
132
        slice_dirs = self.data.get_slice_dimensions()
133
        chunk, length, repeat = self.__chunk_length_repeat(slice_dirs, shape)
134
        sdir_shape = [shape[i] for i in slice_dirs]
135
        split, split_dim = self.__get_split_length(max_frames, sdir_shape)
136
        split_list = self._split_list(slice_list, split)
137
138
        banked = []
139
        for s in split_list:
140
            b = self._split_list(s, max_frames)
141
            if pad:
142
                diff = max_frames - len(b[-1])
143
                #print("diff", diff, max_frames, len(b[-1]))
144
                b[-1][-1] = self._fix_list_length(b[-1][-1], diff, split_dim) if diff \
145
                    else b[-1][-1]
146
            banked.extend(b)
147
        return banked
148
149
    def __get_split_length(self, max_frames, shape):
150
        nDims = 0
151
        while(nDims < len(shape)):
152
            nDims += 1
153
            prod = np.prod([shape[i] for i in range(nDims)])
154
            if prod/float(max_frames) >= 1:
155
                break
156
        return prod, nDims-1
0 ignored issues
show
introduced by
The variable prod does not seem to be defined in case the while loop on line 151 is not entered. Are you sure this can never be the case?
Loading history...
157
158
    def _group_dimension(self, sl, dim, step):
159
        start = sl[0][dim].start
160
        stop = sl[-1][dim].stop
161
        working_slice = list(sl[0])
162
        working_slice[dim] = slice(start, stop, step)
163
        return tuple(working_slice)
164
165
    def _split_list(self, the_list, size):
166
        return [the_list[x:x+size] for x in range(0, len(the_list), size)]
167
168
    # This method only works if the split dimensions in the slice list contain
169
    # slice objects
170
    def __split_frames(self, slice_list, split_list):
171
        split = [list(map(int, a.split('.'))) for a in split_list]
172
        dims = [s[0] for s in split]
173
        length = [s[1] for s in split]
174
        replace = self.__get_split_frame_entries(slice_list, dims, length)
175
        # now replace each slice list entry with multiple entries
176
        array_list = []
177
        for sl in slice_list:
178
            new_list = np.array([sl for i in range(len(replace[0]))])
179
            for d, i in zip(dims, list(range(len(dims)))):
180
                new_list[:, d] = replace[i]
181
            array_list += [tuple(a) for a in new_list]
182
183
        return tuple(array_list)
184
185
    def __get_split_frame_entries(self, slice_list, dims, length):
186
        shape = self.get_shape
187
        replace = []
188
        seq_len = []
189
190
        # get the new entries
191
        for d, l in zip(dims, length):
192
            sl = slice_list[0][d]
193
            start = 0 if sl.start is None else sl.start
194
            stop = shape[d] if sl.stop is None else sl.stop
195
            inc = l*sl.step if sl.step else l
196
            temp_list = [slice(a, a+inc) for a in np.arange(start, stop, inc)]
197
            if temp_list[-1].stop > stop:
198
                temp = temp_list[-1]
199
                temp_list[-1] = slice(temp.start, stop, temp.step)
200
            replace.append(temp_list)
201
            seq_len.append(len(temp_list))
202
203
        # calculate the permutations
204
        length = np.array(seq_len)
205
        chunk = [int(np.prod(length[0:dim])) for dim in range(len(dims))]
206
        repeat = [int(np.prod(length[dim+1:])) for dim in range(len(dims))]
207
        full_replace = []
208
        for d in range(len(dims)):
209
            temp = [[replace[d][i]]*chunk[d] for x in range(repeat[d]) for i
210
                    in range(len(replace[d]))]
211
            full_replace.append([t for sub in temp for t in sub])
212
        return full_replace
213
214
    def _get_frames_per_process(self, slice_list):
215
        processes = self.data.exp.meta_data.get("processes")
216
        process = self.data.exp.meta_data.get("process")
217
        frame_idx = np.arange(len(slice_list))
218
        try:
219
            frames = np.array_split(frame_idx, len(processes))[process]
220
            slice_list = slice_list[frames[0]:frames[-1]+1]
221
        except IndexError:
222
            slice_list = []
223
        return slice_list, frames
224
225
    def _pad_slice_list(self, slice_list, inc_start_str: str, inc_stop_str: str):
        """ Amend the slice lists to include padding.  Includes variations for
        transfer and process slice lists.

        :param slice_list: list of tuples of slice objects.  The list is
            modified in place (its entries are replaced) and also returned.
        :param inc_start_str: Python expression, evaluated with ``eval`` for
            each padded dimension, giving the amount added to the slice start.
            It may refer to the local name ``value``, the padding entry for
            that dimension (callers in this file pass e.g. ``'0'`` or
            ``"-value['before']"``).
        :param inc_stop_str: as ``inc_start_str`` but added to the slice stop.
        :returns: ``slice_list``, unchanged when the plugin data has no
            padding.
        """
        pData = self.data._get_plugin_data()
        if not pData.padding:
            return slice_list

        pad_dict = pData.padding._get_padding_directions()

        shape = self.data.get_shape()
        for ddir, value in pad_dict.items():
            # NOTE(review): eval of caller-supplied expression strings; the
            # loop variable must stay named ``value`` for them to resolve.
            inc_start = eval(inc_start_str)
            inc_stop = eval(inc_stop_str)
            for i in range(len(slice_list)):
                # tuples are immutable: convert, edit one entry, convert back
                slice_list[i] = list(slice_list[i])
                sl = slice_list[i][ddir]
                if sl.start is None:
                    # open-ended slice: substitute the full extent so that
                    # the increments can be added to real numbers
                    sl = slice(0, shape[ddir], 1)
                slice_list[i][ddir] = \
                    slice(sl.start + inc_start, sl.stop + inc_stop, sl.step)
                slice_list[i] = tuple(slice_list[i])
        return slice_list
247
248
    def _fix_list_length(self, sl, length, dim):
249
        sdir = self.data.get_slice_dimensions()
250
        sl = list(sl)
251
        s = sl[sdir[dim]]
252
        sl[sdir[dim]] = slice(s.start, s.stop + length*s.step, s.step)
253
        return tuple(sl)
254
255
    def _get_local_single_slice_list(self, shape):
256
        slice_dirs = self.data.get_slice_dimensions()
257
        core_dirs = np.array(self.data.get_core_dimensions())
258
        fix = [[]]*2
259
        core_slice = np.array([slice(None)]*len(core_dirs))
260
        shape = tuple([shape[i] for i in range(len(shape))])
261
        values = 'np.arange(shape[slice_dirs[i]])'
262
        index = self._get_slice_dirs_index(slice_dirs, shape, values)
263
        # there may be no slice dirs
264
        index = index if index.size else np.array([[0]])
265
        nSlices = index.shape[1] if index.size else len(fix[0])
266
        nDims = len(shape)
267
268
        ssl = self._single_slice_list(
269
            nSlices, nDims, core_slice, core_dirs, slice_dirs, fix, index)
270
        return ssl
271
272
    def _group_slice_list_in_one_dimension(self, slice_list, max_frames,
273
                                           group_dim, pad=False):
274
        """ Group the slice list in one dimension, stopping at \
275
        boundaries - prepare a slice list for multi-frame plugin processing.
276
        """
277
        if group_dim is None:
278
            return slice_list
279
280
        banked = self._banked_list(slice_list, max_frames, pad=pad)
281
        grouped = []
282
        for group in banked:
283
            sub_groups = self._split_list(group, max_frames)
284
            for sub in sub_groups:
285
                grouped.append(self._group_dimension(sub, group_dim, 1))
286
        return grouped
287
288
    def _get_global_single_slice_list(self, shape):
289
        slice_dirs = self.data.get_slice_dimensions()
290
        core_dirs = np.array(self.data.get_core_dimensions())
291
        fix = self.data._get_plugin_data()._get_fixed_dimensions()
292
        core_slice = self._get_core_slices(core_dirs)
293
        values = 'self._get_slice_dir_index(slice_dirs[i])'
294
        index = self._get_slice_dirs_index(slice_dirs, shape, values)
295
        nSlices = index.shape[1] if index.size else len(fix[0])
296
        nDims = len(shape)
297
        ssl = self._single_slice_list(
298
            nSlices, nDims, core_slice, core_dirs, slice_dirs, fix, index)
299
        return ssl
300
301
    def _group_slice_list_in_multiple_dimensions(self, slice_list, max_frames,
302
                                                 group_dim, pad=False):
303
        """ Group the slice list in multiple dimensions - prepare a slice list\
304
        for file transfer.
305
        """
306
        if group_dim is None:
307
            return slice_list
308
309
        steps = self.data.get_preview().get_starts_stops_steps('steps')
310
        sub_groups = self._banked_list(slice_list, max_frames, pad=pad)
311
312
        grouped = []
313
        for sub in sub_groups:
314
            temp = list(sub[0])
315
            for dim in group_dim:
316
                temp[dim] = self._group_dimension(sub, dim, steps[dim])[dim]
317
            grouped.append(tuple(temp))
318
319
        return grouped
320
321
class LocalData(object):
    """ The LocalData class organises the slicing of transferred data to
    give the shape requested by a plugin for each run of 'process_frames'.
    """

    def __init__(self, dtype, transport_data):
        """Keep references to the transport object and its data.

        :param dtype: ``'in'`` selects the input dictionary in ``_get_dict``;
            any other value selects the output dictionary
        :param transport_data: object providing ``data`` and the slice-list
            helper methods called by this class
        """
        data = transport_data.data
        self.dtype = dtype
        self.td = transport_data
        self.data = data
        self.pData = data._get_plugin_data()
        self.shape = data.get_shape()
        self.sdir = None  # set by _get_slice_list

    def _get_dict(self):
        """Return the slice-list dictionary matching ``self.dtype``."""
        if self.dtype == 'in':
            return self._get_dict_in()
        return self._get_dict_out()

    def _get_dict_in(self):
        """Return ``{'process': ...}`` holding the padded slice list."""
        padded = self.td._pad_slice_list(
            self._get_slice_list(), '0', 'sum(value.values())')
        return {'process': padded}

    def _get_dict_out(self):
        """Return the process slice list together with one unpad entry per
        process entry."""
        process = self._get_slice_list()
        return {'process': process,
                'unpad': self.__get_unpad_slice_list(len(process))}

    def _get_slice_list(self):
        """ Splits a file transfer slice list into a list of (padded) slices
        required for each loop of process_frames.
        """
        slice_dirs = self.data.get_slice_dimensions()
        self.sdir = slice_dirs[0] if len(slice_dirs) > 0 else None

        mf_process = self.pData.meta_data.get('max_frames_process')
        transfer_shape = self.pData.get_shape_transfer()
        single = self.td._get_local_single_slice_list(transfer_shape)

        # pad if mfp is > nSlices (e.g. mfp = fixed int)
        return self.td._group_slice_list_in_one_dimension(
            single, mf_process, self.sdir, pad=True)

    def __get_unpad_slice_list(self, reps):
        """Return ``reps`` copies of the slice tuple that strips padding.

        Without padding every entry is ``slice(None)``; otherwise each padded
        dimension is cut from ``before`` to ``-after``.
        """
        # original author's note: setting process slice list unpad here -
        # not currently working for 4D data
        unpad = [slice(None)]*len(self.pData.get_shape_transfer())
        if self.pData.padding:
            directions = self.pData.padding._get_padding_directions()
            for ddir, value in directions.items():
                unpad[ddir] = slice(value['before'], -value['after'])
        return tuple([tuple(unpad)]*reps)
377
378
379
class GlobalData(object):
    """
    The GlobalData class organises the movement and slicing of the data from
    file.
    """

    def __init__(self, dtype, transport):
        """Keep references to the transport object and its data.

        :param dtype: ``'in'`` selects the input dictionary in ``_get_dict``;
            any other value selects the output dictionary
        :param transport: object providing ``data``, ``pad`` and the
            slice-list helper methods called by this class
        """
        self.dtype = dtype
        self.trans = transport
        self.data = transport.data
        self.pData = self.data._get_plugin_data()
        self.shape = self.data.get_shape()

    def _get_dict(self):
        """Return the slice-list dictionary matching ``self.dtype``."""
        return self._get_dict_in() if self.dtype == 'in' else \
            self._get_dict_out()

    def _get_dict_in(self):
        """Return the ``current``, ``frames`` and ``transfer`` entries for
        this process; the transfer list is padded when the transport's
        ``pad`` flag is set."""
        sl_dict = {}
        sl, current = \
            self._get_slice_list(self.shape, current_sl=True, pad=True)

        sl_dict['current'], _ = self.trans._get_frames_per_process(current)
        sl, sl_dict['frames'] = self.trans._get_frames_per_process(sl)

        if self.trans.pad:
            sl = self.trans._pad_slice_list(
                sl, "-value['before']", "value['after']")
        sl_dict['transfer'] = sl
        return sl_dict

    def _get_dict_out(self):
        """Return ``{'transfer': ...}`` with this process's slice list."""
        sl_dict = {}
        sl, _ = self._get_slice_list(self.shape)
        sl_dict['transfer'], _ = self.trans._get_frames_per_process(sl)
        return sl_dict

    def _get_slice_list(self, shape, current_sl=None, pad=False):
        """Create the grouped slice list used for file transfer.

        :param shape: shape of the data
        :param current_sl: when truthy, a second list grouped by the maximum
            number of process frames is also created
        :param pad: passed on to the grouping helper
        :returns: ``(transfer_gsl, current_sl)``; ``current_sl`` is returned
            as passed in when it was falsy
        :raises Exception: if the transport returns no single slice list
        """
        mft = self.pData._get_max_frames_transfer()
        transfer_ssl = self.trans._get_global_single_slice_list(shape)

        if transfer_ssl is None:
            # Fix: the original called get_current_pattern_name() and
            # get_slice_directions() on self, neither of which this class
            # defines, so the intended message was masked by AttributeError.
            name_of = getattr(self.data, 'get_current_pattern_name', None)
            pattern = name_of() if callable(name_of) else 'unknown'
            raise Exception("Data type %s does not support slicing in "
                            "directions %s" % (
                                pattern, self.data.get_slice_dimensions()))
        slice_dims = self.data.get_slice_dimensions()
        transfer_gsl = self.trans._group_slice_list_in_multiple_dimensions(
                transfer_ssl, mft, slice_dims, pad=pad)

        if current_sl:
            mfp = self.pData._get_max_frames_process()
            current_sl = self.trans._group_slice_list_in_multiple_dimensions(
                    transfer_ssl, mfp, slice_dims, pad=pad)

        split_list = self.pData.split
        if split_list:
            # Fix: ``self.__split_frames`` is name-mangled inside this class
            # to ``_GlobalData__split_frames``, which does not exist.  The
            # method is private to SliceLists, so it has to be reached on the
            # transport object through its own mangled name.
            transfer_gsl = self.trans._SliceLists__split_frames(
                transfer_gsl, split_list)

        return transfer_gsl, current_sl

    def _get_padded_data(self, slice_list, end=False):
        """Read the data for ``slice_list``, padding any part that lies
        outside the stored array.

        Negative starts and stops beyond the shape are clipped before
        reading, and the clipped amounts are added back with ``np.pad``.

        :param slice_list: sequence of slice objects with explicit start and
            stop values
        :param end: not used in this method
        :returns: the data, padded when any clipping took place
        """
        slice_list = list(slice_list)
        pData = self.pData
        # only the number of distinct core/slice dimensions is used below
        pad_dims = list(set(self.data.get_core_dimensions() +
                            (self.data.get_slice_dimensions())))
        pad_list = [[0, 0] for _ in slice_list]

        data_dict = self.data.data_info.get_dictionary()
        shape = data_dict['orig_shape'] if 'orig_shape' in data_dict \
            else self.data.get_shape()

        for dim in range(len(pad_dims)):
            sl = slice_list[dim]
            if sl.start < 0:
                pad_list[dim][0] = -sl.start
                slice_list[dim] = slice(0, sl.stop, sl.step)
            diff = sl.stop - shape[dim]
            if diff > 0:
                pad_list[dim][1] = diff
                slice_list[dim] = \
                    slice(slice_list[dim].start, sl.stop - diff, sl.step)

        data = self.data.data[tuple(slice_list)]

        if np.sum(pad_list):
            mode = pData.padding.mode if pData.padding else 'edge'
            return np.pad(data, tuple(pad_list), mode=mode)
        return data
470