test_get_inter_sample_indices()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 29
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 21
dl 0
loc 29
rs 9.376
c 0
b 0
f 0
cc 2
nop 0
1
"""
Tests for deepreg/dataset/loader/grouped_loader.py in
pytest style
"""
5
from os.path import join
6
from typing import List
7
8
import numpy as np
9
import pytest
10
11
from deepreg.dataset.loader.grouped_loader import GroupedDataLoader
12
from deepreg.dataset.loader.h5_loader import H5FileLoader
13
from deepreg.dataset.loader.nifti_loader import NiftiFileLoader
14
15
# File loader class for each supported data format, keyed by format name.
FileLoaderDict = {"nifti": NiftiFileLoader, "h5": H5FileLoader}
# Root directory of the grouped test data for each format; the tests below
# append the split name ("train" or "test") to it.
DataPaths = {"nifti": "data/test/nifti/grouped", "h5": "data/test/h5/grouped"}
# Image shape handed to the GroupedDataLoader instances built below.
image_shape = (64, 64, 60)
18
19
20
def sample_count(ni: List[int], direction: str) -> int:
    """
    Compute the expected total number of image pairs over all groups.

    A group holding n images contributes n * (n - 1) ordered pairs. That
    full count is used when direction is "unconstrained"; for any other
    direction value the per-group count is halved.

    :param ni: list, each element is the number of images in one group
    :param direction: unconstrained/forward/backward
    :return: total number of samples, as a python int
    """
    group_sizes = np.array(ni)
    ordered_pairs = group_sizes * (group_sizes - 1)
    if direction != "unconstrained":
        # only one ordering of each pair is counted
        return int(np.sum(ordered_pairs / 2))
    return int(np.sum(ordered_pairs))
34
35
36
def test_init():
    """
    Check GroupedDataLoader construction over all parameter combinations.

    Depending on the combination, the constructor must either raise a
    ValueError with the appropriate message or produce a loader with the
    expected sample count. Combinations matching none of the cases below
    are not checked.
    """
    for loader_name, loader_cls in FileLoaderDict.items():
        for split in ["test", "train"]:
            for intra_prob in [0, 0.5, 1]:
                for in_group in [True, False]:
                    loader_kwargs = dict(
                        data_dir_paths=[join(DataPaths[loader_name], split)],
                        image_shape=image_shape,
                        file_loader=loader_cls,
                        labeled=True,
                        sample_label="all",
                        intra_group_prob=intra_prob,
                        intra_group_option="forward",
                        sample_image_in_group=in_group,
                        seed=None,
                    )

                    # Decide what this combination is expected to do;
                    # None stands for "constructs successfully".
                    if split == "test" and intra_prob < 1:
                        # sample with fewer than 2 groups.
                        # In "test" we only have one group
                        expected_message = "we need at least two groups"
                    elif split == "train" and in_group:
                        expected_message = None
                    elif not in_group and 0 < intra_prob < 1:
                        # specifying conflicting intra/inter group parameters
                        expected_message = (
                            "Mixing intra and inter groups is not supported"
                        )
                    else:
                        continue

                    if expected_message is None:
                        # ensure sample count is accurate
                        # (only for train dir, test dir uses same logic)
                        loader = GroupedDataLoader(**loader_kwargs)
                        assert loader.sample_indices is None
                        assert loader._num_samples == 2
                        loader.close()
                        continue

                    with pytest.raises(ValueError) as err_info:
                        loader = GroupedDataLoader(**loader_kwargs)
                        # only reached when no error was raised
                        loader.close()
                    assert expected_message in str(err_info.value)
90
91
92
def test_validate_data_files():
    """
    Test validate_data_files function looks for inconsistencies
    in the fixed/moving image and label lists.

    If there is any issue it will raise an error, otherwise it returns None.
    The check is repeated for every file loader, for the train and test
    splits, and for labeled and unlabeled data. Each loader is closed
    after its check so that no file handles stay open between iterations.
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        for train_split in ["train", "test"]:
            for labeled in [True, False]:
                data_dir_paths = [join(DataPaths[key_file_loader], train_split)]
                common_args = dict(
                    file_loader=file_loader,
                    labeled=labeled,
                    sample_label="all",
                    intra_group_prob=1,
                    intra_group_option="forward",
                    sample_image_in_group=False,
                    seed=None if train_split == "train" else 0,
                )

                data_loader = GroupedDataLoader(
                    data_dir_paths=data_dir_paths,
                    image_shape=image_shape,
                    **common_args,
                )

                try:
                    assert data_loader.validate_data_files() is None
                finally:
                    # release the loader even when the assertion fails
                    data_loader.close()
119
120
121
def test_get_inter_sample_indices():
    """
    Test all possible intergroup sampling indices are correctly calculated

    For each file loader a loader is built on the train split with
    intra_group_prob=0 and sample_image_in_group=False. The test checks
    that the reported number of samples matches the expected count and
    that the sample indices contain no duplicates. The loader is closed
    afterwards and its sample_indices list is left unmodified.
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        data_dir_paths = [join(DataPaths[key_file_loader], "train")]
        common_args = dict(
            file_loader=file_loader,
            labeled=True,
            sample_label="all",
            intra_group_prob=0,
            intra_group_option="forward",
            sample_image_in_group=False,
            seed=None,
        )
        data_loader = GroupedDataLoader(
            data_dir_paths=data_dir_paths, image_shape=image_shape, **common_args
        )

        try:
            ni = np.array(data_loader.num_images_per_group)
            num_images = np.sum(ni)
            # all ordered pairs of images, minus the pairs whose two images
            # come from the same group
            num_samples = num_images * (num_images - 1) - np.sum(ni * (ni - 1))

            # sorted() works on copies, so the loader's own list is untouched
            sample_indices = sorted(data_loader.sample_indices)
            unique_indices = sorted(set(sample_indices))

            assert data_loader._num_samples == num_samples
            # equal only when sample_indices holds no duplicates
            assert sample_indices == unique_indices
        finally:
            # release the loader even when an assertion fails
            data_loader.close()
150
151
152
def test_get_intra_sample_indices():
    """
    Test all possible intragroup sampling indices are correctly calculated
    Ensure exception is thrown for unsupported group_option

    For every file loader and split, a loader is built with
    intra_group_prob=1 for each supported intra_group_option. The test
    checks the reported number of samples against sample_count and checks
    that the sample indices contain no duplicates. Each loader is closed
    after its checks and its sample_indices list is left unmodified.
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        for split in ["train", "test"]:
            data_dir_paths = [join(DataPaths[key_file_loader], split)]
            common_args = dict(
                file_loader=file_loader,
                labeled=True,
                sample_label="all",
                intra_group_prob=1,
                sample_image_in_group=False,
                seed=None,
            )
            # test feasible intra_group_option
            for intra_group_option in ["forward", "backward", "unconstrained"]:
                data_loader = GroupedDataLoader(
                    data_dir_paths=data_dir_paths,
                    image_shape=image_shape,
                    intra_group_option=intra_group_option,
                    **common_args,
                )

                try:
                    ni = data_loader.num_images_per_group
                    num_samples = sample_count(ni, intra_group_option)

                    # sorted() works on copies, so the loader's own list
                    # is untouched
                    sample_indices = sorted(data_loader.sample_indices)
                    unique_indices = sorted(set(sample_indices))

                    # test all possible indices are generated
                    assert data_loader._num_samples == num_samples
                    # equal only when sample_indices holds no duplicates
                    assert sample_indices == unique_indices
                finally:
                    # release the loader even when an assertion fails
                    data_loader.close()

            # test exception thrown for unsupported group option
            with pytest.raises(ValueError) as err_info:
                data_loader = GroupedDataLoader(
                    data_dir_paths=data_dir_paths,
                    image_shape=image_shape,
                    intra_group_option="wrong",
                    **common_args,
                )
                # only reached when no error was raised
                data_loader.close()
            assert "Unknown intra_group_option," in str(err_info.value)
199
200
201
def test_sample_index_generator():
    """
    Check that the train index generator is random across seeds and
    deterministic for a fixed seed.

    The test dir is not checked because it contains only a single group
    of 2 images.
    """
    supported_options = ["forward", "backward", "unconstrained"]

    for loader_name, loader_cls in FileLoaderDict.items():
        shared_kwargs = dict(
            image_shape=image_shape,
            data_dir_paths=[join(DataPaths[loader_name], "train")],
            file_loader=loader_cls,
            labeled=True,
            sample_label="all",
        )

        # test feasible intra_group_option
        for in_group in [False, True]:
            # a mixed probability is only used when sampling within groups
            intra_probs = [0, 1] if not in_group else [0, 0.5, 1]
            for intra_prob in intra_probs:
                for option in supported_options:
                    indices_per_seed = []

                    # seed 0 is used twice to check reproducibility
                    for seed in [0, 1, 0]:
                        loader = GroupedDataLoader(
                            intra_group_prob=intra_prob,
                            intra_group_option=option,
                            sample_image_in_group=in_group,
                            seed=seed,
                            **shared_kwargs,
                        )

                        collected = []
                        for sample in loader.sample_index_generator():
                            moving_index, fixed_index, indices = sample
                            assert isinstance(moving_index, tuple)
                            assert isinstance(fixed_index, tuple)
                            assert isinstance(indices, list)
                            collected.extend(indices)

                        loader.close()
                        indices_per_seed.append(collected)

                    first_run, other_seed_run, repeated_run = indices_per_seed
                    # test different seeds give different indices
                    assert not np.allclose(first_run, other_seed_run)
                    # test same seeds give the same indices
                    assert np.allclose(first_run, repeated_run)

        # test exception thrown for unsupported intra_group_option option
        loader = GroupedDataLoader(
            intra_group_prob=1,
            intra_group_option="wrong",
            sample_image_in_group=True,
            seed=0,
            **shared_kwargs,
        )
        with pytest.raises(ValueError) as err_info:
            # the error is expected when the first sample is requested
            next(loader.sample_index_generator())
        loader.close()
        assert "Unknown intra_group_option" in str(err_info.value)
263
264
265
def test_close():
    """
    Test the close function
    Since fixed and moving loaders are the same only need to test the moving

    close() is called for every file loader and split, so each loader is
    released and close() is exercised for both formats. The state of the
    underlying files is only inspected for the h5 loader, through its
    h5_files mapping.
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        for split in ["train", "test"]:
            data_dir_paths = [join(DataPaths[key_file_loader], split)]

            data_loader = GroupedDataLoader(
                data_dir_paths=data_dir_paths,
                image_shape=image_shape,
                file_loader=file_loader,
                labeled=True,
                sample_label="all",
                intra_group_prob=1,
                intra_group_option="forward",
                sample_image_in_group=True,
                seed=0,
            )

            # close every loader, not only the h5 one
            data_loader.close()

            if key_file_loader == "h5":
                for f in data_loader.loader_moving_image.h5_files.values():
                    # a file object is expected to be falsy once closed
                    assert not f
290