test.unit.test_train   A
last analyzed

Complexity

Total Complexity 3

Size/Duplication

Total Lines 107
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 3
eloc 69
dl 0
loc 107
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
A TestBuildConfig.test_ckpt_path() 0 12 1
A TestBuildConfig.test_max_epochs() 0 13 1

1 Function

Rating   Name   Duplication   Size   Complexity  
B test_train_and_predict_main() 0 55 1
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/train.py
5
pytest style
6
"""
7
8
import os
9
import shutil
10
11
import pytest
12
13
from deepreg.predict import main as predict_main
14
from deepreg.train import build_config
15
from deepreg.train import main as train_main
16
17
18
class TestBuildConfig:
    """Tests for ``build_config`` against one fixed example configuration."""

    # in the config, epochs = save_period = 2
    config_path = "config/unpaired_labeled_ddf.yaml"
    exp_name = "test_build_config"
    log_dir = "logs"

    def _build(self, ckpt_path="", **extra_kwargs):
        """
        Call build_config with this class's config path, log dir and experiment name.

        :param ckpt_path: checkpoint path forwarded to build_config.
        :param extra_kwargs: further keyword arguments forwarded to build_config
            (e.g. max_epochs).
        :return: the result of build_config, unpacked by callers as three values.
        """
        return build_config(
            config_path=self.config_path,
            log_dir=self.log_dir,
            exp_name=self.exp_name,
            ckpt_path=ckpt_path,
            **extra_kwargs,
        )

    @pytest.mark.parametrize("ckpt_path", ["", "example.ckpt"])
    def test_ckpt_path(self, ckpt_path):
        """
        build_config runs for both an empty and a non-empty checkpoint path,
        returns a dict config, and places the log dir at <log_dir>/<exp_name>.
        """
        config, built_log_dir, _ = self._build(ckpt_path=ckpt_path)
        expected_log_dir = os.path.join(self.log_dir, self.exp_name)

        assert isinstance(config, dict)
        assert built_log_dir == expected_log_dir

    @pytest.mark.parametrize(
        ("max_epochs", "expected_epochs", "expected_save_period"),
        [(-1, 2, 2), (3, 3, 2)],
    )
    def test_max_epochs(self, max_epochs, expected_epochs, expected_save_period):
        """
        max_epochs=-1 keeps the configured epochs and save_period, while
        max_epochs=3 overrides epochs and leaves save_period at 2.
        """
        config, _, _ = self._build(max_epochs=max_epochs)
        train_section = config["train"]

        assert train_section["epochs"] == expected_epochs
        assert train_section["save_period"] == expected_save_period
50
51
52
@pytest.mark.parametrize(
    "config_paths",
    [
        ["config/unpaired_labeled_ddf.yaml"],
        ["config/unpaired_labeled_ddf.yaml", "config/test/affine.yaml"],
    ],
)
def test_train_and_predict_main(config_paths):
    """
    Test main in train and predict by checking it can run.

    Training runs first; its saved checkpoint is then used for prediction on
    the test split. Afterwards the expected output folders and files under
    ``logs/`` are checked for existence.

    Both output folders are removed before the run, so leftovers from an
    earlier (failed) run cannot make the existence checks pass spuriously.
    They are removed again in a ``finally`` clause, so a failing assertion
    does not leave them behind for the next parametrized case.

    :param config_paths: list of file paths for configuration.
    """
    train_exp_name = "test_train"
    predict_exp_name = "test_predict"
    train_dir = os.path.join("logs", train_exp_name)
    predict_dir = os.path.join("logs", predict_exp_name)
    output_dirs = (train_dir, predict_dir)

    # start from a clean state; the folders may legitimately not exist yet
    for out_dir in output_dirs:
        shutil.rmtree(out_dir, ignore_errors=True)

    try:
        train_main(
            args=["--gpu", "", "--exp_name", train_exp_name, "--config_path"]
            + config_paths
        )

        # check output folders
        for sub_dir in ("save", "train", "validation"):
            assert os.path.isdir(os.path.join(train_dir, sub_dir))
        assert os.path.isfile(os.path.join(train_dir, "config.yaml"))

        # NOTE(review): "ckpt-2" assumes training ran for 2 epochs with
        # save_period 2 under every parametrized config — confirm if the
        # configs change.
        predict_main(
            args=[
                "--gpu",
                "",
                "--ckpt_path",
                os.path.join(train_dir, "save", "ckpt-2"),
                "--split",
                "test",
                "--exp_name",
                predict_exp_name,
                "--save_nifti",
                "--save_png",
            ]
        )

        # check output folders
        split_dir = os.path.join(predict_dir, "test")
        pair_dir = os.path.join(split_dir, "pair_0_1")
        for label_index in range(3):
            assert os.path.isdir(os.path.join(pair_dir, f"label_{label_index}"))
        for csv_name in (
            "metrics.csv",
            "metrics_stats_per_label.csv",
            "metrics_stats_overall.csv",
        ):
            assert os.path.isfile(os.path.join(split_dir, csv_name))
    finally:
        # ignore_errors: a folder that was never created (e.g. train_main
        # failed early) must not mask the original failure
        for out_dir in output_dirs:
            shutil.rmtree(out_dir, ignore_errors=True)
107