test_sample_index_generator()   B
last analyzed

Complexity

Conditions 5

Size

Total Lines 39
Code Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 27
dl 0
loc 39
rs 8.7653
c 0
b 0
f 0
cc 5
nop 0
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/dataset/loader/unpaired_loader.py in
5
pytest style
6
"""
7
from os.path import join
8
9
import numpy as np
10
11
from deepreg.dataset.loader.h5_loader import H5FileLoader
12
from deepreg.dataset.loader.nifti_loader import NiftiFileLoader
13
from deepreg.dataset.loader.unpaired_loader import UnpairedDataLoader
14
15
# File loader class to exercise, keyed by data format name.
FileLoaderDict = {"nifti": NiftiFileLoader, "h5": H5FileLoader}
# Root directory of the unpaired test data, keyed by the same format names;
# the tests append the split name ("train" / "test") to these paths.
DataPaths = {"nifti": "data/test/nifti/unpaired", "h5": "data/test/h5/unpaired"}
17
18
19
def test_sample_index_generator():
    """
    Check that the sample index generator is driven by the seed.

    For every file loader and for both the train and test splits, three
    loaders are built with seeds 0, 1 and 0. The indices they generate must
    differ between seed 0 and seed 1, and must match between the two
    loaders built with seed 0.
    """
    shape = (64, 64, 60)

    def collect_indices(dir_paths, loader_cls, seed):
        # Build one loader for the given seed and flatten every ``indices``
        # list it yields into a single list, type-checking each item.
        loader = UnpairedDataLoader(
            data_dir_paths=dir_paths,
            image_shape=shape,
            file_loader=loader_cls,
            labeled=True,
            sample_label="all",
            seed=seed,
        )
        collected = []
        for moving_index, fixed_index, indices in loader.sample_index_generator():
            assert isinstance(moving_index, int)
            assert isinstance(fixed_index, int)
            assert isinstance(indices, list)
            collected.extend(indices)
        return collected

    for loader_name, loader_cls in FileLoaderDict.items():
        for split in ("train", "test"):
            dir_paths = [join(DataPaths[loader_name], split)]
            # Seeds are evaluated in the order 0, 1, 0.
            first, other, repeat = [
                collect_indices(dir_paths, loader_cls, seed) for seed in (0, 1, 0)
            ]

            # different seeds must give different indices
            assert not np.allclose(first, other)
            # the same seed must reproduce the same indices
            assert np.allclose(first, repeat)
58
59
60
def test_validate_data_files():
    """
    Check that validate_data_files returns None on the test data.

    A loader is built for every file loader and for both the train and
    test splits; the train split uses seed=None, the test split seed=0.
    """
    shape = (64, 64, 60)

    for loader_name, loader_cls in FileLoaderDict.items():
        for split in ("train", "test"):
            # Only the train split is built without a fixed seed.
            seed = 0 if split != "train" else None

            loader = UnpairedDataLoader(
                data_dir_paths=[join(DataPaths[loader_name], split)],
                image_shape=shape,
                file_loader=loader_cls,
                labeled=True,
                sample_label="all",
                seed=seed,
            )

            assert loader.validate_data_files() is None
82
83
84 View Code Duplication
def test_close():
    """
    Test the close function. Only needed for H5 data loaders for now.

    A loader is built for every file loader and for both the train and
    test splits (seed=None for train, seed=0 for test). For the "h5"
    loader, close() is called and every file handle held by the moving
    image loader must then evaluate as false.
    Since fixed/moving loaders are the same for the unpaired data loader,
    only the moving loader is checked.
    """
    image_shape = (64, 64, 60)

    for key_file_loader, file_loader in FileLoaderDict.items():
        for split in ["train", "test"]:
            data_dir_path = [join(DataPaths[key_file_loader], split)]
            common_args = dict(
                file_loader=file_loader,
                labeled=True,
                sample_label="all",
                seed=None if split == "train" else 0,
            )

            data_loader = UnpairedDataLoader(
                data_dir_paths=data_dir_path, image_shape=image_shape, **common_args
            )

            if key_file_loader == "h5":
                data_loader.close()
                for f in data_loader.loader_moving_image.h5_files.values():
                    # Rely on the handle's truthiness instead of calling the
                    # __bool__ dunder directly.
                    assert not f
110