Code Duplication    Length = 26-28 lines in 2 locations

test/unit/test_paired_loader.py 1 location

@@ 134-161 (lines=28) @@
131
                assert np.allclose(indices_to_compare[0], indices_to_compare[2])
132
133
134
def test_close():
    """
    Test the close function. Only needed for H5 data loaders for now.

    Unlike the unpaired data loader, the paired data loader holds separate
    file loaders for the moving and the fixed images, so the H5 files of
    both loaders must be closed after calling close().
    """
    for key_file_loader, file_loader in FileLoaderDict.items():
        for split in ["train", "test"]:
            data_dir_path = [join(DataPaths[key_file_loader], split)]
            common_args = dict(
                file_loader=file_loader,
                labeled=True,
                sample_label="all",
                seed=None if split == "train" else 0,
            )

            data_loader = PairedDataLoader(
                data_dir_paths=data_dir_path,
                fixed_image_shape=fixed_image_shape,
                moving_image_shape=moving_image_shape,
                **common_args,
            )

            if key_file_loader == "h5":
                data_loader.close()
                # Check both image loaders: for paired data they are distinct
                # objects, so closing one does not imply the other is closed.
                for image_loader in (
                    data_loader.loader_moving_image,
                    data_loader.loader_fixed_image,
                ):
                    for f in image_loader.h5_files.values():
                        # A closed H5 handle (presumably h5py.File) is falsy.
                        assert not f
162

test/unit/test_unpaired_loader.py 1 location

@@ 84-109 (lines=26) @@
81
            assert data_loader.validate_data_files() is None
82
83
84
def test_close():
    """
    Test the close function. Only needed for H5 data loaders for now.

    Since the fixed/moving loaders are the same for the unpaired data
    loader, only the moving loader needs to be tested.
    """
    # Same shape for every loader/split, so build it once outside the loops.
    image_shape = (64, 64, 60)
    for key_file_loader, file_loader in FileLoaderDict.items():
        for split in ["train", "test"]:
            data_dir_path = [join(DataPaths[key_file_loader], split)]
            common_args = dict(
                file_loader=file_loader,
                labeled=True,
                sample_label="all",
                seed=None if split == "train" else 0,
            )

            data_loader = UnpairedDataLoader(
                data_dir_paths=data_dir_path, image_shape=image_shape, **common_args
            )

            if key_file_loader == "h5":
                data_loader.close()
                for f in data_loader.loader_moving_image.h5_files.values():
                    # A closed H5 handle (presumably h5py.File) is falsy.
                    assert not f
110