test.unit.test_util.test_save_metric_dict()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 14
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 14
rs 9.95
c 0
b 0
f 0
cc 1
nop 0
1
import os
import re
import shutil
from test.unit.util import is_equal_np
from typing import Tuple, Union

import nibabel as nib
import numpy as np
import pytest
import tensorflow as tf

from deepreg.dataset.loader.interface import DataLoader
from deepreg.dataset.loader.nifti_loader import load_nifti_file
from deepreg.train import build_config
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)
22
23
24
def test_build_dataset():
    """
    Check the types returned by build_dataset.

    The train split must give a DataLoader, a tf.data.Dataset and an int
    number of steps per epoch. After the valid directory in the config is
    set to an empty string, the valid split must give None for all three.
    """
    config, _, _ = build_config(
        config_path="config/unpaired_labeled_ddf.yaml",
        log_dir="logs",
        exp_name="test_build_dataset",
        ckpt_path="",
    )

    # arguments shared by both build_dataset calls; the dicts are passed by
    # reference, so the later edit of config["dataset"] is seen by call two
    shared_kwargs = dict(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        training=False,
        repeat=False,
    )

    train_loader, train_dataset, train_steps = build_dataset(
        split="train", **shared_kwargs
    )
    assert isinstance(train_loader, DataLoader)
    assert isinstance(train_dataset, tf.data.Dataset)
    assert isinstance(train_steps, int)

    # blank out the valid data directory: nothing should be built for it
    config["dataset"]["valid"]["dir"] = ""
    valid_loader, valid_dataset, valid_steps = build_dataset(
        split="valid", **shared_kwargs
    )
    for output in (valid_loader, valid_dataset, valid_steps):
        assert output is None
69
70
71
@pytest.mark.parametrize("log_dir,exp_name", [("logs", ""), ("logs", "custom")])
def test_build_log_dir(log_dir: str, exp_name: str):
    """
    Check that build_log_dir creates its directory directly under log_dir.

    A non-empty exp_name is used verbatim as the directory name; an empty
    exp_name falls back to a timestamp-based name (8 digits, a dash, then
    6 digits).
    """
    built = build_log_dir(log_dir=log_dir, exp_name=exp_name)
    parent, leaf = os.path.split(built)
    assert parent == log_dir

    if exp_name:
        # custom directory name
        assert leaf == exp_name
        return

    # default timestamp-based directory name
    assert re.match("[0-9]{8}-[0-9]{6}", leaf)
83
84
85
class TestSaveArray:
    """
    Tests for save_array: file outputs, shape validation and overwrite handling.

    save_dir is removed before and after every test, so files written by one
    test cannot influence another.
    """

    save_dir = "logs/test_util_save_array"
    arr_name = "arr"
    # the tests expect PNG files in a sub-directory of save_dir named after arr
    png_dir = os.path.join(save_dir, arr_name)
    # expected fragments of the ValueError messages raised for invalid shapes
    dim_err_msg = "arr must be 3d or 4d numpy array or tf tensor"
    ch_err_msg = "4d arr must have 3 channels as last dimension"

    def _remove_save_dir(self):
        """Delete save_dir and everything in it, if it exists."""
        if os.path.exists(self.save_dir):
            shutil.rmtree(self.save_dir)

    def setup_method(self, method):
        """Start each test from a missing save_dir."""
        self._remove_save_dir()

    def teardown_method(self, method):
        """Remove everything the test wrote."""
        self._remove_save_dir()

    @staticmethod
    def get_num_files_in_dir(dir_path: str, suffix: str) -> int:
        """
        Count the entries of dir_path whose name ends with suffix.

        :param dir_path: directory to inspect
        :param suffix: name ending to match, e.g. ".png"
        :return: number of matching entries, 0 if dir_path does not exist
        """
        if os.path.exists(dir_path):
            return len([x for x in os.listdir(dir_path) if x.endswith(suffix)])
        return 0

    @pytest.mark.parametrize(
        "arr",
        [
            tf.random.uniform(shape=(2, 3, 4)),
            tf.random.uniform(shape=(2, 3, 4, 3)),
            np.random.rand(2, 3, 4),
            np.random.rand(2, 3, 4, 3),
        ],
    )
    def test_3d_4d(self, arr: Union[tf.Tensor, np.ndarray]):
        """Valid 3d and 4d inputs produce 4 PNG files and 1 NIfTI file."""
        save_array(save_dir=self.save_dir, arr=arr, name=self.arr_name, normalize=True)
        assert self.get_num_files_in_dir(self.png_dir, suffix=".png") == 4
        assert self.get_num_files_in_dir(self.save_dir, suffix=".nii.gz") == 1

    @pytest.mark.parametrize(
        "arr,err_msg",
        [
            [tf.random.uniform(shape=(2, 3, 4, 3, 3)), dim_err_msg],
            [tf.random.uniform(shape=(2, 3, 4, 1)), ch_err_msg],
            [np.random.rand(2, 3, 4, 3, 3), dim_err_msg],
            [np.random.rand(2, 3, 4, 1), ch_err_msg],
        ],
    )
    def test_wrong_shape(self, arr: Union[tf.Tensor, np.ndarray], err_msg: str):
        """5d inputs and 4d inputs without 3 channels raise ValueError."""
        with pytest.raises(ValueError) as err_info:
            save_array(
                save_dir=self.save_dir, arr=arr, name=self.arr_name, normalize=True
            )
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize("save_nifti", [True, False])
    def test_save_nifti(self, save_nifti: bool):
        """A NIfTI file is written only when save_nifti is True."""
        arr = np.random.rand(2, 3, 4, 3)
        save_array(
            save_dir=self.save_dir,
            arr=arr,
            name=self.arr_name,
            normalize=True,
            save_nifti=save_nifti,
        )
        assert self.get_num_files_in_dir(self.save_dir, suffix=".nii.gz") == int(
            save_nifti
        )

    @pytest.mark.parametrize("save_png", [True, False])
    def test_save_png(self, save_png: bool):
        """PNG files are written only when save_png is True."""
        arr = np.random.rand(2, 3, 4, 3)
        save_array(
            save_dir=self.save_dir,
            arr=arr,
            name=self.arr_name,
            normalize=True,
            save_png=save_png,
        )
        assert (
            self.get_num_files_in_dir(self.png_dir, suffix=".png") == int(save_png) * 4
        )

    @pytest.mark.parametrize("overwrite", [True, False])
    def test_overwrite(self, overwrite: bool):
        """An existing NIfTI file is replaced only when overwrite is True."""
        arr1 = np.random.rand(2, 3, 4, 3)
        arr2 = arr1 + 1
        nifti_file_path = os.path.join(self.save_dir, self.arr_name + ".nii.gz")
        # pre-populate the target file with arr1
        os.makedirs(self.save_dir, exist_ok=True)
        nib.save(img=nib.Nifti1Image(arr1, affine=np.eye(4)), filename=nifti_file_path)
        # save arr2 to the same path, with or without overwrite per the parameter
        save_array(
            save_dir=self.save_dir,
            arr=arr2,
            name=self.arr_name,
            normalize=True,
            overwrite=overwrite,
        )
        # the file must hold arr2 if it was overwritten, else still arr1
        arr_read = load_nifti_file(file_path=nifti_file_path)
        assert is_equal_np(arr2 if overwrite else arr1, arr_read)
182
183
184
def test_calculate_metrics():
    """
    Check which metrics calculate_metrics reports for each input combination.

    Only presence (not None) or absence (None) of each metric is checked;
    the metric functions themselves are assumed to be correct.
    """
    batch_size = 2
    image_shape = (4, 4, 4)  # (f_dim1, f_dim2, f_dim3)
    batch_shape = (batch_size,) + image_shape

    # drawn in the same order as always, so the random streams are unchanged
    fixed_image = tf.random.uniform(shape=batch_shape)
    fixed_label = tf.random.uniform(shape=batch_shape)
    pred_fixed_image = tf.random.uniform(shape=batch_shape)
    pred_fixed_label = tf.random.uniform(shape=batch_shape)
    fixed_grid_ref = tf.random.uniform(shape=(1,) + image_shape + (3,))

    # each row: (fixed_label, pred_fixed_image, pred_fixed_label,
    #            expect image_ssd, expect label metrics)
    scenarios = [
        # labeled and have pred_fixed_image
        (fixed_label, pred_fixed_image, pred_fixed_label, True, True),
        # labeled and do not have pred_fixed_image
        (fixed_label, None, pred_fixed_label, False, True),
        # unlabeled and have pred_fixed_image
        (None, pred_fixed_image, None, True, False),
        # unlabeled and do not have pred_fixed_image
        (None, None, None, False, False),
    ]

    for case_index, case in enumerate(scenarios):
        label, pred_image, pred_label, expect_image, expect_label = case
        got = calculate_metrics(
            fixed_image=fixed_image,
            fixed_label=label,
            pred_fixed_image=pred_image,
            pred_fixed_label=pred_label,
            fixed_grid_ref=fixed_grid_ref,
            sample_index=0,
        )
        assert (got["image_ssd"] is not None) == expect_image
        assert (got["label_binary_dice"] is not None) == expect_label
        assert (got["label_tre"] is not None) == expect_label
        if case_index == 0:
            # with every input available, exactly these metrics are returned
            assert sorted(got) == sorted(
                ["image_ssd", "label_binary_dice", "label_tre"]
            )
254
255
256
def test_save_metric_dict():
    """
    Test save_metric_dict by checking the number of CSV files it writes.

    save_dir is removed before the call, so leftovers from an earlier failed
    run cannot affect the count. It is removed again in a ``finally`` clause,
    so a failing assertion does not leave files behind.
    """
    save_dir = "logs/test_save_metric_dict"
    metrics = [
        dict(image_ssd=0.1, label_dice=0.8, pair_index=[0], label_index=0),
        dict(image_ssd=0.2, label_dice=0.7, pair_index=[1], label_index=1),
        dict(image_ssd=0.3, label_dice=0.6, pair_index=[2], label_index=0),
    ]

    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    try:
        save_metric_dict(save_dir=save_dir, metrics=metrics)
        num_csv_files = len([x for x in os.listdir(save_dir) if x.endswith(".csv")])
        assert num_csv_files == 3
    finally:
        # save_metric_dict may have failed before creating the directory
        if os.path.exists(save_dir):
            shutil.rmtree(save_dir)
270