test.unit.test_nifti_loader.get_loader()   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 20
Code Lines 17

Duplication

Lines 17
Ratio 85 %

Importance

Changes 0
Metric Value
eloc 17
dl 17
loc 20
rs 9.55
c 0
b 0
f 0
cc 4
nop 1
1
"""
2
Tests functionality of the NiftiFileLoader
3
"""
4
import os
5
from test.unit.util import is_equal_np
6
7
import numpy as np
8
import pytest
9
10
from deepreg.dataset.loader.nifti_loader import NiftiFileLoader, load_nifti_file
11
12
13
def get_loader(loader_name):
    """
    Build a NiftiFileLoader over one of the Nifti test data configurations.

    :param loader_name: one of "paired", "unpaired", "grouped" or
        "multi_dirs_grouped".
    :return: a NiftiFileLoader; the tests in this module call ``close()`` on it.
    :raises ValueError: if ``loader_name`` is not a supported configuration.
    """
    if loader_name in ("paired", "unpaired", "grouped"):
        dir_paths = [f"./data/test/nifti/{loader_name}/test"]
        # paired data keeps its images under "fixed_images"; the others use "images"
        name = "fixed_images" if loader_name == "paired" else "images"
        grouped = loader_name == "grouped"
    elif loader_name == "multi_dirs_grouped":
        # grouped data read from two directories at once
        dir_paths = [
            "./data/test/nifti/grouped/train",
            "./data/test/nifti/grouped/test",
        ]
        name = "images"
        grouped = True
    else:
        raise ValueError(f"Unknown loader_name {loader_name!r}")
    return NiftiFileLoader(dir_paths=dir_paths, name=name, grouped=grouped)
33
34
35
@pytest.mark.parametrize(
    "path,shape",
    [
        ("./data/test/nifti/paired/test/fixed_images/case000026.nii.gz", (44, 59, 41)),
        ("./data/test/nifti/unit_test/case000026.nii", (44, 59, 41)),
    ],
)
def test_load_nifti_file(path, shape):
    """Both a .nii.gz and a .nii file load into an array of the expected shape."""
    loaded_shape = load_nifti_file(file_path=path).shape
    assert loaded_shape == shape
45
46
47
def test_load_nifti_file_err():
    """Passing a non-Nifti (.h5) path to load_nifti_file raises ValueError."""
    expected_msg = "Nifti file path must end with .nii or .nii.gz"
    with pytest.raises(ValueError) as exc:
        load_nifti_file(file_path="./data/test/h5/paired/test/fixed_images.h5")
    assert expected_msg in str(exc.value)
52
53
54
class TestNiftiFileLoader:
    """Tests for NiftiFileLoader instances built by ``get_loader``."""

    @pytest.mark.parametrize(
        "name,expected",
        [
            (
                "paired",
                [
                    [
                        ("./data/test/nifti/paired/test", "case000025", "nii.gz"),
                        ("./data/test/nifti/paired/test", "case000026", "nii.gz"),
                    ],
                    None,
                ],
            ),
            (
                "unpaired",
                [
                    [
                        ("./data/test/nifti/unpaired/test", "case000025", "nii.gz"),
                        ("./data/test/nifti/unpaired/test", "case000026", "nii"),
                    ],
                    None,
                ],
            ),
            (
                "grouped",
                [
                    [
                        (
                            "./data/test/nifti/grouped/test",
                            "group1",
                            "case000025",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/test",
                            "group1",
                            "case000026",
                            "nii.gz",
                        ),
                    ],
                    [[0, 1]],
                ],
            ),
            (
                "multi_dirs_grouped",
                [
                    [
                        (
                            "./data/test/nifti/grouped/test",
                            "group1",
                            "case000025",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/test",
                            "group1",
                            "case000026",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/train",
                            "group1",
                            "case000000",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/train",
                            "group1",
                            "case000001",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/train",
                            "group1",
                            "case000003",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/train",
                            "group1",
                            "case000008",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/train",
                            "group2",
                            "case000009",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/train",
                            "group2",
                            "case000011",
                            "nii.gz",
                        ),
                        (
                            "./data/test/nifti/grouped/train",
                            "group2",
                            "case000012",
                            "nii.gz",
                        ),
                    ],
                    [[0, 1], [2, 3, 4, 5], [6, 7, 8]],
                ],
            ),
        ],
    )
    def test_init(self, name, expected):
        """Check the collected data path splits and group structure."""
        loader = get_loader(name)
        got = [
            loader.data_path_splits,
            loader.group_struct,
        ]
        assert got == expected
        loader.close()

    @pytest.mark.parametrize(
        "name",
        [
            "paired",
            "unpaired",
            "grouped",
        ],
    )
    def test_init_duplicated_dirs(self, name):
        """Repeating an entry in dir_paths must raise ValueError."""
        loader = get_loader(name)
        # duplicated dir_paths
        dir_paths = loader.dir_paths * 2
        with pytest.raises(ValueError) as err_info:
            NiftiFileLoader(
                dir_paths=dir_paths, name=loader.name, grouped=loader.grouped
            )
        assert "dir_paths have repeated elements" in str(err_info.value)
        loader.close()

    @pytest.mark.parametrize(
        "name,err_msg",
        [
            (
                "images",
                "directory ./data/test/h5/paired/test/images does not exist",
            ),  # test not existed files
        ],
    )
    def test_set_data_structure_err1(self, name, err_msg):
        """A missing image sub-directory must raise an AssertionError."""
        dir_paths = ["./data/test/h5/paired/test"]
        with pytest.raises(AssertionError) as err_info:
            NiftiFileLoader(dir_paths=dir_paths, name=name, grouped=True)
        assert err_msg in str(err_info.value)

    def test_set_data_structure_err2(self):
        """A sub-directory holding no Nifti files must raise "No data collected"."""
        dir_paths = ["./data/test/nifti/paired/test"]
        name = "error"
        dir_path = os.path.join(dir_paths[0], name)
        os.makedirs(dir_path, exist_ok=True)
        try:
            with pytest.raises(ValueError) as err_info:
                NiftiFileLoader(dir_paths=dir_paths, name=name, grouped=False)
            assert "No data collected" in str(err_info.value)
        finally:
            # Always remove the temporary directory, even when the check fails.
            # os.rmdir deletes only this leaf; os.removedirs would also try to
            # delete any parent directory that happens to be empty.
            os.rmdir(dir_path)

    @pytest.mark.parametrize(
        "name,index,expected",
        [
            ("paired", 0, [(64, 64, 60), [255.0, 0.0, 68.359276, 65.84009]])
            if False
            else ("paired", 0, [(44, 59, 41), [255.0, 0.0, 68.359276, 65.84009]]),
            ("unpaired", 0, [(64, 64, 60), [255.0, 0.0, 60.073948, 47.27648]]),
            ("grouped", (0, 1), [(64, 64, 60), [255.0, 0.0, 85.67942, 49.193127]]),
            (
                "multi_dirs_grouped",
                (0, 1),
                [(64, 64, 60), [255.0, 0.0, 85.67942, 49.193127]],
            ),
        ],
    )
    def test_get_data(self, name, index, expected):
        """Check the shape and summary statistics (max, min, mean, std) of one image."""
        loader = get_loader(name)
        array = loader.get_data(index)
        got = [
            np.shape(array),
            [np.amax(array), np.amin(array), np.mean(array), np.std(array)],
        ]
        assert got[0] == expected[0]
        assert is_equal_np(got[1], expected[1])
        loader.close()

    @pytest.mark.parametrize(
        "name,expected",
        [
            (
                "paired",
                [
                    ("./data/test/nifti/paired/test", "case000025"),
                    ("./data/test/nifti/paired/test", "case000026"),
                ],
            ),
            (
                "unpaired",
                [
                    ("./data/test/nifti/unpaired/test", "case000025"),
                    ("./data/test/nifti/unpaired/test", "case000026"),
                ],
            ),
            (
                "grouped",
                [
                    ("./data/test/nifti/grouped/test", "group1", "case000025"),
                    ("./data/test/nifti/grouped/test", "group1", "case000026"),
                ],
            ),
            (
                "multi_dirs_grouped",
                [
                    ("./data/test/nifti/grouped/test", "group1", "case000025"),
                    ("./data/test/nifti/grouped/test", "group1", "case000026"),
                    ("./data/test/nifti/grouped/train", "group1", "case000000"),
                    ("./data/test/nifti/grouped/train", "group1", "case000001"),
                    ("./data/test/nifti/grouped/train", "group1", "case000003"),
                    ("./data/test/nifti/grouped/train", "group1", "case000008"),
                    ("./data/test/nifti/grouped/train", "group2", "case000009"),
                    ("./data/test/nifti/grouped/train", "group2", "case000011"),
                    ("./data/test/nifti/grouped/train", "group2", "case000012"),
                ],
            ),
        ],
    )
    def test_get_data_ids(self, name, expected):
        """Check the identifiers reported for every loaded image."""
        loader = get_loader(name)
        got = loader.get_data_ids()
        assert got == expected
        loader.close()

    @pytest.mark.parametrize(
        "index,err_type",
        [
            (-1, AssertionError),
            (64, IndexError),
            ((0, 1), AssertionError),
            ("wrong", ValueError),
        ],
    )
    def test_get_data_ids_check_err_with_paired(self, index, err_type):
        """Invalid indices on a non-grouped loader raise the listed error types."""
        # wrong index for paired
        loader = get_loader("paired")
        with pytest.raises(err_type):
            loader.get_data(index=index)
        loader.close()

    def test_get_data_ids_check_err_with_grouped(self):
        """A non-tuple index on a grouped loader raises AssertionError."""
        # wrong index for grouped
        loader = get_loader("grouped")
        with pytest.raises(AssertionError):
            # non-tuple data_index
            loader.get_data(index=1)
        loader.close()

    @pytest.mark.parametrize(
        "name,expected",
        [("paired", 2), ("unpaired", 2), ("grouped", 2), ("multi_dirs_grouped", 9)],
    )
    def test_get_num_images(self, name, expected):
        """Check the number of images found for each configuration."""
        loader = get_loader(name)
        got = loader.get_num_images()
        assert got == expected
        loader.close()

    @pytest.mark.parametrize(
        "name",
        [
            "paired",
            "unpaired",
            "grouped",
        ],
    )
    def test_close(self, name):
        """Closing a freshly built loader must not raise."""
        loader = get_loader(name)
        loader.close()
330