test.unit.test_paired_loader   A
last analyzed

Complexity

Total Complexity 19

Size/Duplication

Total Lines 162
Duplicated Lines 17.28 %

Importance

Changes 0
Metric Value
wmc 19
eloc 102
dl 28
loc 162
rs 10
c 0
b 0
f 0

4 Functions

Rating   Name   Duplication   Size   Complexity  
A test_init() 0 39 3
A test_validate_data_files_label() 0 25 4
B test_sample_index_generator() 0 42 6
B test_close() 28 28 6

How to fix   Duplicated Code   

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

Common duplication problems and their corresponding solutions are:

1
"""
2
Tests functionality of the PairedDataLoader
3
"""
4
from os.path import join
5
6
import numpy as np
7
import pytest
8
9
from deepreg.dataset.loader.h5_loader import H5FileLoader
10
from deepreg.dataset.loader.nifti_loader import NiftiFileLoader
11
from deepreg.dataset.loader.paired_loader import PairedDataLoader
12
13
# Image shapes passed to PairedDataLoader in the tests of this module
moving_image_shape = (64, 64, 60)
fixed_image_shape = (32, 32, 60)

# File loader classes under test, keyed by data format
FileLoaderDict = {"nifti": NiftiFileLoader, "h5": H5FileLoader}
# Directory holding the paired test data for each data format
DataPaths = {"nifti": "data/test/nifti/paired", "h5": "data/test/h5/paired"}
19
20
21
def test_init():
    """
    Check that data loader __init__() method is correct:

    - the four file loaders held by the data loader are instances of the
      requested file loader class;
    - an invalid data_dir_paths list (``[0, "1", 2, 3]``) raises an
      AssertionError.
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        data_dir_path = [
            join(DataPaths[key_file_loader], "train"),
            join(DataPaths[key_file_loader], "test"),
        ]
        common_args = dict(
            file_loader=file_loader, labeled=True, sample_label="all", seed=None
        )
        data_loader = PairedDataLoader(
            data_dir_paths=data_dir_path,
            fixed_image_shape=fixed_image_shape,
            moving_image_shape=moving_image_shape,
            **common_args,
        )

        try:
            # Check that file loaders are initialized correctly.
            # `file_loader` is already the class, so compare against it directly
            # instead of building (and never closing) an extra loader instance.
            assert isinstance(data_loader.loader_moving_image, file_loader)
            assert isinstance(data_loader.loader_fixed_image, file_loader)
            assert isinstance(data_loader.loader_moving_label, file_loader)
            assert isinstance(data_loader.loader_fixed_label, file_loader)
        finally:
            # Close the loader even when one of the assertions above fails.
            data_loader.close()

        # Check the data_dir_path variable assertion error.
        data_dir_path_int = [0, "1", 2, 3]
        with pytest.raises(AssertionError):
            PairedDataLoader(
                data_dir_paths=data_dir_path_int,
                fixed_image_shape=fixed_image_shape,
                moving_image_shape=moving_image_shape,
                **common_args,
            )
61
62
63
def test_validate_data_files_label():
    """
    Test the validate_data_files function of the paired data loader.

    According to the original test description, the function looks for
    inconsistencies in the fixed/moving image and label lists and raises
    an error if there is any issue. Here it is expected to return None
    for every file loader and for both the train and test splits.
    """
    for loader_name, loader_cls in FileLoaderDict.items():
        for split in ["train", "test"]:
            # The train split is built without a seed, the test split with 0.
            if split == "train":
                seed = None
            else:
                seed = 0

            data_loader = PairedDataLoader(
                data_dir_paths=[join(DataPaths[loader_name], split)],
                fixed_image_shape=fixed_image_shape,
                moving_image_shape=moving_image_shape,
                file_loader=loader_cls,
                labeled=True,
                sample_label="all",
                seed=seed,
            )

            validation_result = data_loader.validate_data_files()
            assert validation_result is None
            data_loader.close()
88
89
90
def test_sample_index_generator():
    """
    Test to check the randomness and deterministic index generator
    for train/test respectively.

    For each file loader and split, data loaders are built with the seeds
    0, 1 and 0. Every sampled (moving_index, fixed_index, indices) triple is
    checked for its types and for moving_index == fixed_index. When there is
    more than one image, different seeds must give different indices and the
    same seed must give the same indices.
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        for split in ["train", "test"]:
            data_dir_path = [join(DataPaths[key_file_loader], split)]
            indices_to_compare = []
            # Captured inside the seed loop, while the loader is still open,
            # so the check below does not rely on a loop-leaked, closed loader.
            num_images = 0

            for seed in [0, 1, 0]:
                data_loader = PairedDataLoader(
                    data_dir_paths=data_dir_path,
                    fixed_image_shape=fixed_image_shape,
                    moving_image_shape=moving_image_shape,
                    file_loader=file_loader,
                    labeled=True,
                    sample_label="all",
                    seed=seed,
                )

                data_indices = []
                for (
                    moving_index,
                    fixed_index,
                    indices,
                ) in data_loader.sample_index_generator():
                    assert isinstance(moving_index, int)
                    assert isinstance(fixed_index, int)
                    assert isinstance(indices, list)
                    assert moving_index == fixed_index
                    data_indices += indices

                indices_to_compare.append(data_indices)
                num_images = data_loader.num_images
                data_loader.close()

            if num_images > 1:
                # test different seeds give different indices
                assert not (np.allclose(indices_to_compare[0], indices_to_compare[1]))
                # test same seeds give the same indices
                assert np.allclose(indices_to_compare[0], indices_to_compare[2])
132
133
134 View Code Duplication
def test_close():
    """
    Test the close function of the paired data loader.

    close() is called for every file loader and split. The closed-file
    check is only done for H5 data loaders, and only on the files held by
    the moving image loader (``loader_moving_image.h5_files``).
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        for split in ["train", "test"]:

            data_dir_path = [join(DataPaths[key_file_loader], split)]
            common_args = dict(
                file_loader=file_loader,
                labeled=True,
                sample_label="all",
                seed=None if split == "train" else 0,
            )

            data_loader = PairedDataLoader(
                data_dir_paths=data_dir_path,
                fixed_image_shape=fixed_image_shape,
                moving_image_shape=moving_image_shape,
                **common_args,
            )

            # Always close, so no loader created by this test is left open.
            data_loader.close()

            if key_file_loader == "h5":
                # The test expects a closed h5 file handle to be falsy.
                for f in data_loader.loader_moving_image.h5_files.values():
                    assert not f
162