|
1
|
|
|
# coding=utf-8 |
|
2
|
|
|
|
|
3
|
|
|
""" |
|
4
|
|
|
Tests for deepreg/dataset/load.py in pytest style |
|
5
|
|
|
""" |
|
6
|
|
|
from typing import Optional |
|
7
|
|
|
|
|
8
|
|
|
import pytest |
|
9
|
|
|
import yaml |
|
10
|
|
|
|
|
11
|
|
|
import deepreg.dataset.load as load |
|
12
|
|
|
from deepreg.dataset.loader.unpaired_loader import UnpairedDataLoader |
|
13
|
|
|
from deepreg.registry import DATA_LOADER_CLASS, REGISTRY |
|
14
|
|
|
|
|
15
|
|
|
|
|
16
|
|
|
def load_yaml(file_path: str) -> dict:
    """
    Load a yaml file and return its content as a dictionary.

    :param file_path: path of the yaml file, must end with ".yaml".
    :return: the parsed content of the file.
    """
    assert file_path.endswith(".yaml"), f"expected a .yaml file, got {file_path}"
    # Explicit encoding so parsing does not depend on the platform's default
    # locale. safe_load builds only plain YAML types (dict/list/str/number/...);
    # the test configs are assumed to need nothing more — confirm if a config
    # ever relies on python-specific tags.
    with open(file_path, encoding="utf-8") as file:
        return yaml.safe_load(file)
|
25
|
|
|
|
|
26
|
|
|
|
|
27
|
|
|
class TestGetDataLoader:
    """Tests for load.get_data_loader using the yaml configs under config/test/."""

    @pytest.mark.parametrize("data_type", ["paired", "unpaired", "grouped"])
    @pytest.mark.parametrize("file_format", ["nifti", "h5"])
    def test_data_loader(self, data_type: str, file_format: str):
        """
        Test the data loader can be successfully built.

        :param data_type: name of data loader for registry
        :param file_format: name of file loader for registry
        """
        # one config per (data_type, file_format) combination; the built loader
        # must be an instance of the class registered under data_type
        config = load_yaml(f"config/test/{data_type}_{file_format}.yaml")
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        expected = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
        assert isinstance(got, expected)  # type: ignore

    def test_multi_dir_data_loader(self):
        """Test an unpaired data loader can be built from multiple dirs."""
        config = load_yaml("config/test/unpaired_nifti_multi_dirs.yaml")
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        assert isinstance(got, UnpairedDataLoader)

    @pytest.mark.parametrize("path", ["", None])
    def test_empty_path(self, path: Optional[str]):
        """
        Test None is returned when the training data path is empty or None.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["train"]["dir"] = path
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        assert got is None

    @pytest.mark.parametrize("split", ["train", "valid", "test"])
    def test_empty_config(self, split: str):
        """
        Test None is returned when the config has no section for the split.

        :param split: train or valid or test
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"].pop(split)
        got = load.get_data_loader(data_config=config["dataset"], split=split)
        assert got is None

    @pytest.mark.parametrize(
        "path", ["config/test/paired_nifti.yaml", "config/test/paired_nifti"]
    )
    def test_dir_err(self, path: str):
        """
        Check the error is raised when the training data path is not a valid dir.

        The first value points to an existing yaml file (not a directory); the
        second is presumably a path that does not exist.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["train"]["dir"] = path
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], split="train")
        assert "is not a directory or does not exist" in str(err_info.value)

    def test_mode_err(self):
        """Check the error is raised when the split is not train/valid/test."""
        config = load_yaml("config/test/paired_nifti.yaml")
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], split="example")
        assert "split must be one of ['train', 'valid', 'test']" in str(err_info.value)
|
94
|
|
|
|