GroupedDataLoader.sample_index_generator()   C
last analyzed

Complexity

Conditions 9

Size

Total Lines 75
Code Lines 46

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 46
dl 0
loc 75
rs 6.4339
c 0
b 0
f 0
cc 9
nop 1

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

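Commonly applied refactorings include Extract Method; if many parameters or temporary variables are involved, Replace Method with Method Object or Introduce Parameter Object.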

1
"""
2
Load grouped data.
3
Supported formats: h5 and Nifti.
4
Image data can be labeled or unlabeled.
5
Read https://deepreg.readthedocs.io/en/latest/api/loader.html#module-deepreg.dataset.loader.grouped_loader for more details.
6
"""
7
import random
8
from copy import deepcopy
9
from typing import List, Optional, Tuple, Union
10
11
from deepreg.dataset.loader.interface import (
12
    AbstractUnpairedDataLoader,
13
    GeneratorDataLoader,
14
)
15
from deepreg.dataset.util import check_difference_between_two_lists
16
from deepreg.registry import REGISTRY
17
18
19
@REGISTRY.register_data_loader(name="grouped")
class GroupedDataLoader(AbstractUnpairedDataLoader, GeneratorDataLoader):
    """
    Load grouped data.

    Yield indexes of images to load using
    sample_index_generator from GeneratorDataLoader.
    AbstractUnpairedLoader handles different file formats
    """

    # accepted values of intra_group_option
    _INTRA_GROUP_OPTIONS = ("forward", "backward", "unconstrained")

    def __init__(
        self,
        file_loader,
        data_dir_paths: List[str],
        labeled: bool,
        sample_label: Optional[str],
        intra_group_prob: float,
        intra_group_option: str,
        sample_image_in_group: bool,
        seed: Optional[int],
        image_shape: Union[Tuple[int, ...], List[int]],
    ):
        """
        :param file_loader: a subclass of FileLoader
        :param data_dir_paths: paths of the directory storing data,
          the data has to be saved under two different sub-directories:

          - images
          - labels

        :param labeled: bool, true if the data is labeled, false if unlabeled
        :param sample_label: "sample" or "all", read `get_label_indices`
            in deepreg/dataset/util.py for more details.
        :param intra_group_prob: float between 0 and 1,

          - 0 means generating only inter-group samples,
          - 1 means generating only intra-group samples

        :param intra_group_option: str, "forward", "backward", or "unconstrained"
        :param sample_image_in_group: bool,

          - if true, only one image pair will be yielded for each group,
            so one epoch has num_groups pairs of data,
          - if false, iterate through this loader will generate all possible pairs

        :param seed: controls the randomness in sampling,
            if seed=None, then the randomness is not fixed
        :param image_shape: list or tuple of length 3,
            corresponding to (dim1, dim2, dim3) of the 3D image
        :raises ValueError: if inter-group sampling may happen but there are
            fewer than two groups, or if intra and inter group sampling are
            mixed (0 < intra_group_prob < 1) while not sampling pairs
        """
        super().__init__(
            image_shape=image_shape,
            labeled=labeled,
            sample_label=sample_label,
            seed=seed,
        )
        assert isinstance(
            data_dir_paths, list
        ), f"data_dir_paths must be list of strings, got {data_dir_paths}"
        # init
        # the indices for identifying an image pair is (group1, sample1, group2, sample2, label)
        self.num_indices = 5
        self.intra_group_option = intra_group_option
        self.intra_group_prob = intra_group_prob
        self.sample_image_in_group = sample_image_in_group
        # set file loaders
        # grouped data are not paired data, so moving/fixed share the same file loader for images/labels
        loader_image = file_loader(
            dir_paths=data_dir_paths, name="images", grouped=True
        )
        self.loader_moving_image = loader_image
        self.loader_fixed_image = loader_image
        if self.labeled is True:
            loader_label = file_loader(
                dir_paths=data_dir_paths, name="labels", grouped=True
            )
            self.loader_moving_label = loader_label
            self.loader_fixed_label = loader_label
        self.validate_data_files()
        # get group related stats
        self.num_groups = self.loader_moving_image.get_num_groups()
        self.num_images_per_group = self.loader_moving_image.get_num_images_per_group()
        # any probability below 1 may produce inter-group samples,
        # which require a second group to draw from
        if self.intra_group_prob < 1 and self.num_groups < 2:
            raise ValueError(
                f"There are {self.num_groups} groups, "
                f"we need at least two groups for inter group sampling"
            )
        # calculate number of samples and save pre-calculated sample indices
        if self.sample_image_in_group is True:
            # one image pair in each group (pair) will be yielded
            self.sample_indices = None
            self._num_samples = self.num_groups
        else:
            # all possible pair in each group (pair) will be yielded
            if intra_group_prob not in [0, 1]:
                raise ValueError(
                    "Mixing intra and inter groups is not supported"
                    " when not sampling pairs."
                )
            if intra_group_prob == 0:  # inter group
                self.sample_indices = self.get_inter_sample_indices()
            else:  # intra group
                self.sample_indices = self.get_intra_sample_indices()

            self._num_samples = len(self.sample_indices)  # type: ignore

    def validate_data_files(self):
        """If the data are labeled, verify image loader and label loader have the same files."""
        if self.labeled is True:
            image_ids = self.loader_moving_image.get_data_ids()
            label_ids = self.loader_moving_label.get_data_ids()
            check_difference_between_two_lists(
                list1=image_ids,
                list2=label_ids,
                name="images and labels in grouped loader",
            )

    def _check_intra_group_option(self):
        """
        Validate self.intra_group_option.

        :raises ValueError: if it is not forward/backward/unconstrained
        """
        if self.intra_group_option not in self._INTRA_GROUP_OPTIONS:
            raise ValueError(
                f"Unknown intra_group_option, "
                f"must be forward/backward/unconstrained, "
                f"got {self.intra_group_option}"
            )

    def get_intra_sample_indices(self) -> list:
        """
        Calculate the sample indices for intra-group sampling
        The index to identify a sample is (group1, image1, group2, image2), means
        - image1 of group1 is moving image
        - image2 of group2 is fixed image

        Assuming group i has ni images,
        then in total the number of samples are
        - sum( ni * (ni-1) / 2 ) for forward/backward
        - sum( ni * (ni-1) ) for unconstrained

        :return: a list of sample indices
        :raises ValueError: if intra_group_option is not forward/backward/unconstrained
        """
        self._check_intra_group_option()
        # forward pairs have moving index < fixed index, backward the opposite;
        # unconstrained keeps both
        add_forward = self.intra_group_option in ("forward", "unconstrained")
        add_backward = self.intra_group_option in ("backward", "unconstrained")

        intra_sample_indices = []
        for group_index in range(self.num_groups):
            num_images_in_group = self.num_images_per_group[group_index]
            for i in range(num_images_in_group):
                for j in range(i):
                    # j < i; the forward pair is appended before the backward
                    # one so the list order (and thus seeded shuffles) is stable
                    if add_forward:
                        intra_sample_indices.append((group_index, j, group_index, i))
                    if add_backward:
                        intra_sample_indices.append((group_index, i, group_index, j))
        return intra_sample_indices

    def get_inter_sample_indices(self) -> list:
        """
        Calculate the sample indices for inter-group sampling
        The index to identify a sample is (group1, image1, group2, image2), means

          - image1 of group1 is moving image
          - image2 of group2 is fixed image

        All pairs of images in the dataset are registered.
        Assuming group i has ni images and that N=[n1, n2, ..., nI],
        then in total the number of samples are:
        sum(N) * (sum(N)-1) - sum( N * (N-1) )

        :return: a list of sample indices
        """
        inter_sample_indices = []
        for group_index1 in range(self.num_groups):
            num_images_in_group1 = self.num_images_per_group[group_index1]
            for group_index2 in range(self.num_groups):
                if group_index1 == group_index2:  # do not sample from the same group
                    continue
                num_images_in_group2 = self.num_images_per_group[group_index2]
                for image_index1 in range(num_images_in_group1):
                    for image_index2 in range(num_images_in_group2):
                        inter_sample_indices.append(
                            (group_index1, image_index1, group_index2, image_index2)
                        )
        return inter_sample_indices

    def sample_index_generator(self):
        """
        Yield (moving_index, fixed_index, image_indices) sequentially, where

          - moving_index = (group1, image1)
          - fixed_index = (group2, image2)
          - image_indices = [group1, image1, group2, image2]

        All randomness comes from a random.Random seeded with self.seed,
        so a fixed seed reproduces the same sequence.
        """
        rnd = random.Random(self.seed)  # set random seed
        if self.sample_image_in_group is True:
            # for each group sample one image pair only
            sample_indices = self._sample_one_pair_per_group(rnd)
        else:
            # sample indices are pre-calculated
            sample_indices = self._shuffle_pre_calculated_indices(rnd)
        for sample_index in sample_indices:
            yield self._format_sample_index(*sample_index)

    def _sample_one_pair_per_group(self, rnd: random.Random):
        """
        Yield one (group1, image1, group2, image2) per group, visiting groups in shuffled order.

        :param rnd: random number generator used for all sampling decisions
        """
        group_indices = list(range(self.num_groups))
        rnd.shuffle(group_indices)
        for group_index in group_indices:
            # rnd.random() lies in [0, 1): a strict comparison guarantees that
            # intra_group_prob=0 never and intra_group_prob=1 always samples intra-group
            if rnd.random() < self.intra_group_prob:
                sample_index = self._sample_intra_group_pair(rnd, group_index)
                if sample_index is None:
                    # skip groups having <2 images
                    # currently have not encountered
                    continue  # pragma: no cover
            else:
                sample_index = self._sample_inter_group_pair(rnd, group_index)
            yield sample_index

    def _sample_intra_group_pair(
        self, rnd: random.Random, group_index: int
    ) -> Optional[Tuple[int, int, int, int]]:
        """
        Sample two distinct images inside one group as moving/fixed.

        :param rnd: random number generator
        :param group_index: the group to sample from
        :return: (group_index, image1, group_index, image2), or None if the
            group has fewer than two images
        :raises ValueError: if intra_group_option is not forward/backward/unconstrained
        """
        num_images_in_group = self.num_images_per_group[group_index]
        if num_images_in_group < 2:
            return None

        # sample two unique indices
        image_index1, image_index2 = rnd.sample(range(num_images_in_group), 2)
        self._check_intra_group_option()
        if self.intra_group_option == "forward":
            # image_index1 < image_index2
            image_index1, image_index2 = sorted((image_index1, image_index2))
        elif self.intra_group_option == "backward":
            # image_index1 > image_index2
            image_index1, image_index2 = sorted(
                (image_index1, image_index2), reverse=True
            )
        # "unconstrained" keeps the sampled order
        return group_index, image_index1, group_index, image_index2

    def _sample_inter_group_pair(
        self, rnd: random.Random, group_index: int
    ) -> Tuple[int, int, int, int]:
        """
        Pick another group, then sample one image from each group.

        :param rnd: random number generator
        :param group_index: the group providing the moving image
        :return: (group_index, image1, other_group_index, image2)
        """
        group_index2 = rnd.choice(
            [i for i in range(self.num_groups) if i != group_index]
        )
        image_index1 = rnd.choice(range(self.num_images_per_group[group_index]))
        image_index2 = rnd.choice(range(self.num_images_per_group[group_index2]))
        return group_index, image_index1, group_index2, image_index2

    def _shuffle_pre_calculated_indices(self, rnd: random.Random) -> list:
        """
        Return a shuffled copy of the pre-calculated sample indices.

        :param rnd: random number generator
        :return: shuffled list of (group1, image1, group2, image2)
        """
        assert self.sample_indices is not None
        # copy so that self.sample_indices keeps its original order across epochs
        sample_indices = deepcopy(self.sample_indices)
        rnd.shuffle(sample_indices)  # shuffle in place
        return sample_indices

    @staticmethod
    def _format_sample_index(
        group_index1: int, image_index1: int, group_index2: int, image_index2: int
    ):
        """
        Convert a flat sample index into the tuple yielded by sample_index_generator.

        :return: (moving_index, fixed_index, image_indices)
        """
        moving_index = (group_index1, image_index1)
        fixed_index = (group_index2, image_index2)
        image_indices = [group_index1, image_index1, group_index2, image_index2]
        return moving_index, fixed_index, image_indices

    def close(self):
        """Close file loaders"""
        self.loader_moving_image.close()
        if self.labeled is True:
            self.loader_moving_label.close()