test.unit.test_dataset_load   A
last analyzed

Complexity

Total Complexity 10

Size/Duplication

Total Lines 94
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 10
eloc 48
dl 0
loc 94
rs 10
c 0
b 0
f 0

6 Methods

Rating   Name   Duplication   Size   Complexity  
A TestGetDataLoader.test_data_loader() 0 14 1
A TestGetDataLoader.test_dir_err() 0 14 2
A TestGetDataLoader.test_empty_config() 0 11 1
A TestGetDataLoader.test_mode_err() 0 6 2
A TestGetDataLoader.test_multi_dir_data_loader() 0 5 1
A TestGetDataLoader.test_empty_path() 0 11 1

1 Function

Rating   Name   Duplication   Size   Complexity  
A load_yaml() 0 9 2
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/dataset/load.py in pytest style
5
"""
6
from typing import Optional
7
8
import pytest
9
import yaml
10
11
import deepreg.dataset.load as load
12
from deepreg.dataset.loader.unpaired_loader import UnpairedDataLoader
13
from deepreg.registry import DATA_LOADER_CLASS, REGISTRY
14
15
16
def load_yaml(file_path: str) -> dict:
    """
    Load a yaml file and return its content as a dictionary.

    :param file_path: path of the yaml file, must end with ".yaml".
    :return: the parsed yaml content.
    """
    assert file_path.endswith(".yaml"), f"{file_path} is not a .yaml file"
    # explicit encoding so parsing does not depend on the platform locale
    with open(file_path, encoding="utf-8") as file:
        # safe_load only builds plain YAML types (dict, list, str, numbers, ...),
        # which is all a config file needs; unlike FullLoader it never
        # constructs arbitrary Python objects from tags.
        return yaml.safe_load(file)
25
26
27
class TestGetDataLoader:
    """Tests for ``deepreg.dataset.load.get_data_loader``."""

    @pytest.mark.parametrize("data_type", ["paired", "unpaired", "grouped"])
    @pytest.mark.parametrize("file_format", ["nifti", "h5"])
    def test_data_loader(self, data_type: str, file_format: str):
        """
        Test the data loader can be successfully built.

        :param data_type: name of data loader for registry
        :param file_format: name of file loader for registry
        """
        # one test config per (data loader type, file format) combination
        config = load_yaml(f"config/test/{data_type}_{file_format}.yaml")
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        expected = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
        assert isinstance(got, expected)  # type: ignore

    def test_multi_dir_data_loader(self):
        """Test an unpaired data loader configured with multiple dirs."""
        config = load_yaml("config/test/unpaired_nifti_multi_dirs.yaml")
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        assert isinstance(got, UnpairedDataLoader)

    @pytest.mark.parametrize("path", ["", None])
    def test_empty_path(self, path: Optional[str]):
        """
        Test None is returned when the train dir is empty or None.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["train"]["dir"] = path
        got = load.get_data_loader(data_config=config["dataset"], split="train")
        assert got is None

    @pytest.mark.parametrize("split", ["train", "valid", "test"])
    def test_empty_config(self, split: str):
        """
        Test None is returned when the config has no section for the split.

        :param split: train or valid or test
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        # drop the whole split section, not just its dir entry
        config["dataset"].pop(split)
        got = load.get_data_loader(data_config=config["dataset"], split=split)
        assert got is None

    @pytest.mark.parametrize(
        "path", ["config/test/paired_nifti.yaml", "config/test/paired_nifti"]
    )
    def test_dir_err(self, path: str):
        """
        Check a ValueError is raised when the train dir is not an existing directory.

        The first case is a file rather than a directory; the second is a path
        presumably absent from the repository.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["train"]["dir"] = path
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], split="train")
        assert "is not a directory or does not exist" in str(err_info.value)

    def test_mode_err(self):
        """Check a ValueError is raised when the split name is not supported."""
        config = load_yaml("config/test/paired_nifti.yaml")
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], split="example")
        assert "split must be one of ['train', 'valid', 'test']" in str(err_info.value)
94