Test Failed
Pull Request — master (#830)
by Daniil, created 04:40

savu.plugins.loaders.random_hdf5_loader (rating: A)

Complexity

Total Complexity 24

Size/Duplication

Total Lines 172
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 118
dl 0
loc 172
rs 10
c 0
b 0
f 0
wmc 24

8 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomHdf5Loader._get_n_entries() 0 2 1
A RandomHdf5Loader._set_rotation_angles() 0 17 4
A RandomHdf5Loader.__convert_patterns() 0 12 2
A RandomHdf5Loader.__parameter_checks() 0 4 2
A RandomHdf5Loader.__get_start_slice_list() 0 15 5
A RandomHdf5Loader.setup() 0 17 1
A RandomHdf5Loader.__init__() 0 2 1
C RandomHdf5Loader.__get_backing_file() 0 58 8
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: random_hdf5_loader
   :platform: Unix
   :synopsis: A loader that creates an hdf5 dataset of random numbers of \
       any size.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""

import os
import h5py
import logging
import numpy as np

from savu.data.chunking import Chunking
from savu.plugins.utils import register_plugin
from savu.plugins.loaders.base_loader import BaseLoader
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils

@register_plugin
class RandomHdf5Loader(BaseLoader):
    def __init__(self, name='RandomHdf5Loader'):
        super(RandomHdf5Loader, self).__init__(name)

    def setup(self):
        exp = self.exp
        data_obj = exp.create_data_object('in_data',
                                          self.parameters['dataset_name'])

        data_obj.set_axis_labels(*self.parameters['axis_labels'])
        self.__convert_patterns(data_obj)
        self.__parameter_checks(data_obj)

        data_obj.backing_file = self.__get_backing_file(data_obj)
        data_obj.data = data_obj.backing_file['/']['test']
        data_obj.data.dtype # Need to do something to .data to keep the file open!

        data_obj.set_shape(data_obj.data.shape)
        self.n_entries = data_obj.get_shape()[0]
        self._set_rotation_angles(data_obj, self._get_n_entries())
        return data_obj

    def __get_backing_file(self, data_obj):
        fname = '%s/%s.h5' % \
            (self.exp.get('out_path'), self.parameters['file_name'])

        if os.path.exists(fname):
            return h5py.File(fname, 'r')

        self.hdf5 = Hdf5Utils(self.exp)

        size = tuple(self.parameters['size'])
Issue (Comprehensibility Best Practice): The variable tuple does not seem to be defined.

        patterns = data_obj.get_data_patterns()
        p_name = self.parameters['pattern'] if \
            self.parameters['pattern'] is not None else list(patterns.keys())[0]
        p_dict = patterns[p_name]
        p_dict['max_frames_transfer'] = 1
        nnext = {p_name: p_dict}

        pattern_idx = {'current': nnext, 'next': nnext}
        chunking = Chunking(self.exp, pattern_idx)
        chunks = chunking._calculate_chunking(size, np.int16)

        h5file = self.hdf5._open_backing_h5(fname, 'w')
        dset = h5file.create_dataset('test', size, chunks=chunks)

        # need an mpi barrier after creating the file, before populating it
        self.exp._barrier()

        slice_dirs = list(nnext.values())[0]['slice_dims']
        nDims = len(dset.shape)
        total_frames = np.prod([dset.shape[i] for i in slice_dirs])
        sub_size = \
            [1 if i in slice_dirs else dset.shape[i] for i in range(nDims)]

        # calculate the first slice for this process
        idx = 0
        sl, total_frames = \
            self.__get_start_slice_list(slice_dirs, dset.shape, total_frames)

        for i in range(total_frames):
            low, high = self.parameters['range']
            dset[tuple(sl)] = np.random.randint(
                low, high=high, size=sub_size, dtype=self.parameters['dtype_'])
            if sl[slice_dirs[idx]].stop == dset.shape[slice_dirs[idx]]:
                idx += 1
                if idx == len(slice_dirs):
                    break
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start+1, tmp.stop+1)

        self.exp._barrier()

        try:
            h5file.close()
        except Exception:
            logging.debug('There was a problem trying to close the file in '
                          'random_hdf5_loader')

        return self.hdf5._open_backing_h5(fname, 'r')

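    # Illustration of the population loop above (values are hypothetical,
    # not from the source): with size=(3, 4, 5) and slice_dims=(0,),
    # sub_size is [1, 4, 5] and total_frames is 3, so each iteration writes
    # one random (1, 4, 5) block, with sl advancing from
    # [slice(0, 1), :, :] through [slice(2, 3), :, :].
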
    def __get_start_slice_list(self, slice_dirs, shape, n_frames):
        n_processes = len(self.exp.get('processes'))
        rank = self.exp.get('process')
        frames = np.array_split(np.arange(n_frames), n_processes)[rank]
        f_range = list(range(0, frames[0])) if len(frames) else []
        sl = [slice(0, 1) if i in slice_dirs else slice(None)
              for i in range(len(shape))]
        idx = 0
        for i in f_range:
            if sl[slice_dirs[idx]].stop == shape[slice_dirs[idx]]:
                idx += 1
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start+1, tmp.stop+1)

        return sl, len(frames)

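    # Worked example of the frame split above (hypothetical values): with
    # n_frames=10, two processes and rank=1, np.array_split(np.arange(10), 2)
    # assigns frames [5..9] to this rank; f_range is then range(0, 5), the
    # start slice in a single slice dimension advances to slice(5, 6), and
    # len(frames) == 5 frames are left for this process to write.
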
    def __convert_patterns(self, data_obj):
        pattern_list = self.parameters['patterns']
        for p in pattern_list:
            p_split = p.split('.')
            name = p_split[0]
            dims = p_split[1:]
            core_dims = tuple([int(i[0]) for i in [d.split('c') for d in dims]
                              if len(i) == 2])
            slice_dims = tuple([int(i[0]) for i in [d.split('s') for d in dims]
                               if len(i) == 2])
            data_obj.add_pattern(
                    name, core_dims=core_dims, slice_dims=slice_dims)

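    # Sketch of how one 'patterns' entry is parsed above (hypothetical
    # value): 'SINOGRAM.0c.1s.2c' splits into name='SINOGRAM' and
    # dims=['0c', '1s', '2c']; '0c'.split('c') == ['0', ''] has length 2,
    # so dimension 0 is a core dimension, while '1s'.split('c') == ['1s']
    # is skipped. The result is core_dims=(0, 2) and slice_dims=(1,).
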
    def _set_rotation_angles(self, data_obj, n_entries):
        angles = self.parameters['angles']

        if angles is None:
            angles = np.linspace(0, 180, n_entries)
        else:
            try:
                angles = eval(angles)
            except Exception:
                raise Exception('Cannot set angles in loader.')

        n_angles = len(angles)
        data_angles = n_entries
        if data_angles != n_angles:
            raise Exception("The number of angles %s does not match the data "
                            "dimension length %s" % (n_angles, data_angles))
        data_obj.meta_data.set("rotation_angle", angles)

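    # Note on the 'angles' parameter above: it is passed through eval(), so
    # a hypothetical value such as the string 'np.linspace(0, 180, 91)'
    # yields 91 equally spaced angles, whose count must equal n_entries.
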
    def __parameter_checks(self, data_obj):
        if not self.parameters['size']:
            raise Exception(
                    'Please specify the size of the dataset to create.')

    def _get_n_entries(self):
        return self.n_entries
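For reference, a minimal sketch of a parameter set this loader consumes; the
keys are taken from the code above, but every value shown is hypothetical:

parameters = {
    'dataset_name': 'tomo',                    # name of the created data object
    'axis_labels': ['rotation_angle.degrees',  # one 'name.unit' label per axis
                    'detector_y.pixels',
                    'detector_x.pixels'],
    'patterns': ['SINOGRAM.0c.1s.2c',          # dotted strings parsed by
                 'PROJECTION.0s.1c.2c'],       # __convert_patterns()
    'pattern': None,                           # None selects the first pattern
    'file_name': 'random_data',                # backing hdf5 file name
    'size': [91, 100, 120],                    # shape of the dataset to create
    'range': [1, 10],                          # low/high for np.random.randint
    'dtype_': 'int16',                         # dtype passed to np.random.randint
    'angles': 'np.linspace(0, 180, 91)',       # eval'd by _set_rotation_angles
}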