GroupedDataLoader.__init__()   B
last analyzed

Complexity

Conditions 7

Size

Total Lines 96
Code Lines 49

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 49
dl 0
loc 96
rs 7.269
c 0
b 0
f 0
cc 7
nop 10

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists: Replace Parameter with Method Call, Preserve Whole Object, and Introduce Parameter Object.

1
"""
2
Load grouped data.
3
Supported formats: h5 and Nifti.
4
Image data can be labeled or unlabeled.
5
Read https://deepreg.readthedocs.io/en/latest/api/loader.html#module-deepreg.dataset.loader.grouped_loader for more details.
6
"""
7
import random
8
from copy import deepcopy
9
from typing import List, Optional, Tuple, Union
10
11
from deepreg.dataset.loader.interface import (
12
    AbstractUnpairedDataLoader,
13
    GeneratorDataLoader,
14
)
15
from deepreg.dataset.util import check_difference_between_two_lists
16
from deepreg.registry import REGISTRY
17
18
19
@REGISTRY.register_data_loader(name="grouped")
20
class GroupedDataLoader(AbstractUnpairedDataLoader, GeneratorDataLoader):
21
    """
22
    Load grouped data.
23
24
    Yield indexes of images to load using
25
    sample_index_generator from GeneratorDataLoader.
26
    AbstractUnpairedLoader handles different file formats
27
    """
28
29
    def __init__(
        self,
        file_loader,
        data_dir_paths: List[str],
        labeled: bool,
        sample_label: Optional[str],
        intra_group_prob: float,
        intra_group_option: str,
        sample_image_in_group: bool,
        seed: Optional[int],
        image_shape: Union[Tuple[int, ...], List[int]],
    ):
        """
        Create the file loaders and prepare the sampling plan.

        :param file_loader: a subclass of FileLoader
        :param data_dir_paths: paths of the directory storing data,
          the data has to be saved under two different sub-directories:

          - images
          - labels

        :param labeled: bool, true if the data is labeled, false if unlabeled
        :param sample_label: "sample" or "all", read `get_label_indices`
            in deepreg/dataset/util.py for more details.
        :param intra_group_prob: float between 0 and 1,

          - 0 means generating only inter-group samples,
          - 1 means generating only intra-group samples

        :param intra_group_option: str, "forward", "backward", or "unconstrained"
        :param sample_image_in_group: bool,

          - if true, only one image pair will be yielded for each group,
            so one epoch has num_groups pairs of data,
          - if false, iterate through this loader will generate all possible pairs

        :param seed: controls the randomness in sampling,
            if seed=None, then the randomness is not fixed
        :param image_shape: list or tuple of length 3,
            corresponding to (dim1, dim2, dim3) of the 3D image
        :raises AssertionError: if data_dir_paths is not a list
        :raises ValueError: if inter-group sampling is requested
            (intra_group_prob < 1) with fewer than two groups, or if
            sample_image_in_group is false and intra_group_prob is neither 0 nor 1
        """
        super().__init__(
            image_shape=image_shape,
            labeled=labeled,
            sample_label=sample_label,
            seed=seed,
        )
        assert isinstance(
            data_dir_paths, list
        ), f"data_dir_paths must be list of strings, got {data_dir_paths}"

        # the indices for identifying an image pair is
        # (group1, sample1, group2, sample2, label)
        self.num_indices = 5
        self.intra_group_option = intra_group_option
        self.intra_group_prob = intra_group_prob
        self.sample_image_in_group = sample_image_in_group

        self._init_file_loaders(file_loader=file_loader, data_dir_paths=data_dir_paths)
        self.validate_data_files()

        # group related stats come from the image loader
        self.num_groups = self.loader_moving_image.get_num_groups()
        self.num_images_per_group = self.loader_moving_image.get_num_images_per_group()
        if self.intra_group_prob < 1 and self.num_groups < 2:
            # any probability below 1 may need a second group to pair with
            raise ValueError(
                f"There are {self.num_groups} groups, "
                f"we need at least two groups for inter group sampling"
            )

        self._init_sample_indices()

    def _init_file_loaders(self, file_loader, data_dir_paths: List[str]):
        """
        Instantiate the image loader and, for labeled data, the label loader.

        Grouped data are not paired data, so moving/fixed share the same
        file loader instance for images, and likewise for labels.

        :param file_loader: callable building a loader from
            dir_paths, name and grouped keyword arguments
        :param data_dir_paths: directories handed to the file loader
        """
        loader_image = file_loader(
            dir_paths=data_dir_paths, name="images", grouped=True
        )
        self.loader_moving_image = loader_image
        self.loader_fixed_image = loader_image
        if self.labeled is True:
            loader_label = file_loader(
                dir_paths=data_dir_paths, name="labels", grouped=True
            )
            self.loader_moving_label = loader_label
            self.loader_fixed_label = loader_label

    def _init_sample_indices(self):
        """
        Set sample_indices and _num_samples according to the sampling mode.

        :raises ValueError: if all pairs are requested while
            intra_group_prob is neither 0 nor 1
        """
        if self.sample_image_in_group is True:
            # one image pair in each group (pair) will be yielded,
            # the pair is drawn lazily so nothing is pre-calculated
            self.sample_indices = None
            self._num_samples = self.num_groups
            return

        # all possible pair in each group (pair) will be yielded
        if self.intra_group_prob not in [0, 1]:
            raise ValueError(
                "Mixing intra and inter groups is not supported"
                " when not sampling pairs."
            )
        if self.intra_group_prob == 0:  # inter group
            self.sample_indices = self.get_inter_sample_indices()
        else:  # intra group
            self.sample_indices = self.get_intra_sample_indices()
        self._num_samples = len(self.sample_indices)  # type: ignore
125
126
    def validate_data_files(self):
        """
        Compare the data ids of the image loader and the label loader.

        Nothing is checked for unlabeled data. For labeled data both id lists
        are handed to check_difference_between_two_lists.
        """
        if self.labeled is not True:
            # no label loader exists, so there is nothing to compare
            return
        check_difference_between_two_lists(
            list1=self.loader_moving_image.get_data_ids(),
            list2=self.loader_moving_label.get_data_ids(),
            name="images and labels in grouped loader",
        )
136
137
    def get_intra_sample_indices(self) -> list:
        """
        Calculate the sample indices for intra-group sampling.

        The index to identify a sample is (group1, image1, group2, image2), means

        - image1 of group1 is moving image
        - image2 of group2 is fixed image

        Here group1 and group2 are always the same group.
        With j < i being two image indices of one group,

        - "forward" yields (group, j, group, i),
        - "backward" yields (group, i, group, j),
        - "unconstrained" yields both, forward one first.

        Assuming group i has ni images,
        then in total the number of samples are

        - sum( ni * (ni-1) / 2 ) for forward/backward
        - sum( ni * (ni-1) ) for unconstrained

        :return: a list of sample indices
        :raises ValueError: if intra_group_option is not one of
            forward/backward/unconstrained, even when there is no group
        """
        option = self.intra_group_option
        # validate once up front, so a bad option is reported
        # regardless of how many groups or images exist
        if option not in ("forward", "backward", "unconstrained"):
            raise ValueError(
                "Unknown intra_group_option, must be forward/backward/unconstrained, "
                f"got {option}"
            )
        with_forward = option in ("forward", "unconstrained")
        with_backward = option in ("backward", "unconstrained")

        intra_sample_indices = []
        for group_index in range(self.num_groups):
            num_images_in_group = self.num_images_per_group[group_index]
            for i in range(num_images_in_group):
                for j in range(i):
                    # j < i
                    if with_forward:
                        intra_sample_indices.append((group_index, j, group_index, i))
                    if with_backward:
                        intra_sample_indices.append((group_index, i, group_index, j))
        return intra_sample_indices
175
176
    def get_inter_sample_indices(self) -> list:
        """
        Enumerate every ordered image pair whose images come from different groups.

        Each entry is (group1, image1, group2, image2), where

          - image1 of group1 is the moving image
          - image2 of group2 is the fixed image

        Both orderings of two groups are produced, so with group i holding
        ni images and N=[n1, n2, ..., nI] the number of entries is
        sum(N) * (sum(N)-1) - sum( N * (N-1) )

        :return: a list of sample indices
        """
        pairs = []
        for moving_group in range(self.num_groups):
            for fixed_group in range(self.num_groups):
                if moving_group != fixed_group:  # never pair a group with itself
                    num_moving = self.num_images_per_group[moving_group]
                    num_fixed = self.num_images_per_group[fixed_group]
                    pairs.extend(
                        (moving_group, moving_image, fixed_group, fixed_image)
                        for moving_image in range(num_moving)
                        for fixed_image in range(num_fixed)
                    )
        return pairs
204
205
    def sample_index_generator(self):
        """
        Yield (moving_index, fixed_index, image_indices) sequentially, where

          - moving_index = (group1, image1)
          - fixed_index = (group2, image2)
          - image_indices = [group1, image1, group2, image2]

        If sample_image_in_group is true, groups are visited in random order
        and at most one pair is drawn per group: an intra-group pair with
        probability intra_group_prob, otherwise an inter-group pair.
        A group having fewer than two images is skipped when an intra-group
        pair is requested for it.
        Otherwise the pre-calculated sample_indices are yielded in random order.

        A new random generator seeded with self.seed is created on each call.

        :raises ValueError: if an intra-group pair is drawn while
            intra_group_option is not forward/backward/unconstrained
        """
        rnd = random.Random(self.seed)  # set random seed

        if self.sample_image_in_group is not True:
            # sample indices are pre-calculated, shuffle a copy
            # so the stored list keeps its order
            assert self.sample_indices is not None
            sample_indices = deepcopy(self.sample_indices)
            rnd.shuffle(sample_indices)  # shuffle in place
            for sample_index in sample_indices:
                group_index1, image_index1, group_index2, image_index2 = sample_index
                yield (
                    (group_index1, image_index1),
                    (group_index2, image_index2),
                    [group_index1, image_index1, group_index2, image_index2],
                )
            return

        # for each group sample one image pair only
        group_indices = list(range(self.num_groups))
        rnd.shuffle(group_indices)
        for group_index in group_indices:
            group_index1 = group_index
            # random() lies in [0, 1) and may return exactly 0.0, a strict
            # comparison guarantees that intra_group_prob = 0 never goes intra
            if rnd.random() < self.intra_group_prob:
                # intra-group sampling
                # inside the group_index-th group, we sample two images
                group_index2 = group_index
                num_images_in_group = self.num_images_per_group[group_index]
                if num_images_in_group < 2:
                    # skip groups having <2 images
                    continue  # pragma: no cover

                # sample two unique indices
                image_index1, image_index2 = rnd.sample(range(num_images_in_group), 2)
                if self.intra_group_option == "forward":
                    # image_index1 < image_index2
                    image_index1, image_index2 = sorted((image_index1, image_index2))
                elif self.intra_group_option == "backward":
                    # image_index1 > image_index2
                    image_index2, image_index1 = sorted((image_index1, image_index2))
                elif self.intra_group_option != "unconstrained":
                    raise ValueError(
                        f"Unknown intra_group_option, "
                        f"must be forward/backward/unconstrained, "
                        f"got {self.intra_group_option}"
                    )
            else:
                # inter-group sampling
                # we sample another group, then in each group we sample one image
                group_index2 = rnd.choice(
                    [i for i in range(self.num_groups) if i != group_index]
                )
                image_index1 = rnd.choice(
                    range(self.num_images_per_group[group_index1])
                )
                image_index2 = rnd.choice(
                    range(self.num_images_per_group[group_index2])
                )

            yield (
                (group_index1, image_index1),
                (group_index2, image_index2),
                [group_index1, image_index1, group_index2, image_index2],
            )
280
281
    def close(self):
        """
        Close the image loader and, for labeled data, the label loader.

        Only the moving loaders are closed here.
        """
        self.loader_moving_image.close()
        if self.labeled is not True:
            return
        self.loader_moving_label.close()
286