TestGetDataLoader.test_empty_config()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 11
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 11
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/dataset/load.py in pytest style
5
"""
6
from typing import Optional
7
8
import pytest
9
import yaml
10
11
import deepreg.dataset.load as load
12
from deepreg.dataset.loader.unpaired_loader import UnpairedDataLoader
13
from deepreg.registry import DATA_LOADER_CLASS, REGISTRY
14
15
16
def load_yaml(file_path: str) -> dict:
    """
    Load a yaml file and return its content as a dictionary.

    :param file_path: path of the yaml file, must end with ".yaml".
    :return: the parsed yaml content.
    """
    assert file_path.endswith(".yaml"), f"{file_path} is not a .yaml file"
    # explicit encoding so parsing does not depend on the platform locale;
    # FullLoader is not meant for untrusted input, only use on trusted configs
    with open(file_path, encoding="utf-8") as file:
        return yaml.load(file, Loader=yaml.FullLoader)
25
26
27
class TestGetDataLoader:
    @pytest.mark.parametrize("data_type", ["paired", "unpaired", "grouped"])
    @pytest.mark.parametrize("file_format", ["nifti", "h5"])
    def test_data_loader(self, data_type: str, file_format: str):
        """
        Test the data loader can be successfully built.

        :param data_type: name of data loader for registry
        :param file_format: name of file loader for registry
        """
        # build the train-split loader for this data type / file format
        # combination and check it is the class registered under data_type
        config = load_yaml(f"config/test/{data_type}_{file_format}.yaml")
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        expected = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
        assert isinstance(got, expected)  # type: ignore

    def test_multi_dir_data_loader(self):
        """Test the unpaired data loader can be built with multiple data dirs."""
        config = load_yaml("config/test/unpaired_nifti_multi_dirs.yaml")
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        assert isinstance(got, UnpairedDataLoader)

    @pytest.mark.parametrize("path", ["", None])
    def test_empty_path(self, path: Optional[str]):
        """
        Test None is returned when the train data dir is empty or None.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["train"]["dir"] = path
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        assert got is None

    @pytest.mark.parametrize("split", ["train", "valid", "test"])
    def test_empty_config(self, split: str):
        """
        Test None is returned when the dataset config has no section for the split.

        :param split: train or valid or test
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"].pop(split)
        got = load.get_data_loader(data_config=config["dataset"], split=split)
        assert got is None

    @pytest.mark.parametrize(
        "path", ["config/test/paired_nifti.yaml", "config/test/paired_nifti"]
    )
    def test_dir_err(self, path: str):
        """
        Check a ValueError is raised when the data dir is not an existing directory.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["train"]["dir"] = path
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], split="train")
        assert "is not a directory or does not exist" in str(err_info.value)

    def test_mode_err(self):
        """Check a ValueError is raised when the split is not train/valid/test."""
        config = load_yaml("config/test/paired_nifti.yaml")
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], split="example")
        assert "split must be one of ['train', 'valid', 'test']" in str(err_info.value)
94