import os
import re
import shutil
from test.unit.util import is_equal_np
from typing import Union

import nibabel as nib
import numpy as np
import pytest
import tensorflow as tf

from deepreg.dataset.loader.interface import DataLoader
from deepreg.dataset.loader.nifti_loader import load_nifti_file
from deepreg.train import build_config
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)


def test_build_dataset():
    """
    Test build_dataset by checking the output types.
    """

    # init arguments
    config_path = "config/unpaired_labeled_ddf.yaml"
    log_dir = "logs"
    exp_name = "test_build_dataset"
    ckpt_path = ""

    # load config
    config, _, _ = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=False,
        repeat=False,
    )

    # check output types
    assert isinstance(data_loader_train, DataLoader)
    assert isinstance(dataset_train, tf.data.Dataset)
    assert isinstance(steps_per_epoch_train, int)

    # remove valid data
    config["dataset"]["valid"]["dir"] = ""

    # build dataset
    data_loader_valid, dataset_valid, steps_per_epoch_valid = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=False,
    )

    assert data_loader_valid is None
    assert dataset_valid is None
    assert steps_per_epoch_valid is None


@pytest.mark.parametrize("log_dir,exp_name", [("logs", ""), ("logs", "custom")])
def test_build_log_dir(log_dir: str, exp_name: str):
    """
    Test build_log_dir with default and custom experiment names.
    """
    built_log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)
    head, tail = os.path.split(built_log_dir)
    assert head == log_dir
    if exp_name == "":
        # use default timestamp-based directory
        pattern = re.compile("[0-9]{8}-[0-9]{6}")
        assert pattern.match(tail)
    else:
        # use custom directory
        assert tail == exp_name


class TestSaveArray:
    """
    Test save_array by checking output files.
    """

    save_dir = "logs/test_util_save_array"
    arr_name = "arr"
    png_dir = os.path.join(save_dir, arr_name)
    dim_err_msg = "arr must be 3d or 4d numpy array or tf tensor"
    ch_err_msg = "4d arr must have 3 channels as last dimension"

    def setup_method(self, method):
        if os.path.exists(self.save_dir):
            shutil.rmtree(self.save_dir)

    def teardown_method(self, method):
        if os.path.exists(self.save_dir):
            shutil.rmtree(self.save_dir)

    @staticmethod
    def get_num_files_in_dir(dir_path: str, suffix: str):
        if os.path.exists(dir_path):
            return len([x for x in os.listdir(dir_path) if x.endswith(suffix)])
        return 0

    @pytest.mark.parametrize(
        "arr",
        [
            tf.random.uniform(shape=(2, 3, 4)),
            tf.random.uniform(shape=(2, 3, 4, 3)),
            np.random.rand(2, 3, 4),
            np.random.rand(2, 3, 4, 3),
        ],
    )
    def test_3d_4d(self, arr: Union[tf.Tensor, np.ndarray]):
        save_array(save_dir=self.save_dir, arr=arr, name=self.arr_name, normalize=True)
        assert self.get_num_files_in_dir(self.png_dir, suffix=".png") == 4
        assert self.get_num_files_in_dir(self.save_dir, suffix=".nii.gz") == 1

    @pytest.mark.parametrize(
        "arr,err_msg",
        [
            [tf.random.uniform(shape=(2, 3, 4, 3, 3)), dim_err_msg],
            [tf.random.uniform(shape=(2, 3, 4, 1)), ch_err_msg],
            [np.random.rand(2, 3, 4, 3, 3), dim_err_msg],
            [np.random.rand(2, 3, 4, 1), ch_err_msg],
        ],
    )
    def test_wrong_shape(self, arr: Union[tf.Tensor, np.ndarray], err_msg: str):
        with pytest.raises(ValueError) as err_info:
            save_array(
                save_dir=self.save_dir, arr=arr, name=self.arr_name, normalize=True
            )
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize("save_nifti", [True, False])
    def test_save_nifti(self, save_nifti: bool):
        arr = np.random.rand(2, 3, 4, 3)
        save_array(
            save_dir=self.save_dir,
            arr=arr,
            name=self.arr_name,
            normalize=True,
            save_nifti=save_nifti,
        )
        assert self.get_num_files_in_dir(self.save_dir, suffix=".nii.gz") == int(
            save_nifti
        )

    @pytest.mark.parametrize("save_png", [True, False])
    def test_save_png(self, save_png: bool):
        arr = np.random.rand(2, 3, 4, 3)
        save_array(
            save_dir=self.save_dir,
            arr=arr,
            name=self.arr_name,
            normalize=True,
            save_png=save_png,
        )
        assert (
            self.get_num_files_in_dir(self.png_dir, suffix=".png") == int(save_png) * 4
        )

    @pytest.mark.parametrize("overwrite", [True, False])
    def test_overwrite(self, overwrite: bool):
        arr1 = np.random.rand(2, 3, 4, 3)
        arr2 = arr1 + 1
        nifti_file_path = os.path.join(self.save_dir, self.arr_name + ".nii.gz")
        # save arr1
        os.makedirs(self.save_dir, exist_ok=True)
        nib.save(img=nib.Nifti1Image(arr1, affine=np.eye(4)), filename=nifti_file_path)
        # save arr2 with or without overwrite
        save_array(
            save_dir=self.save_dir,
            arr=arr2,
            name=self.arr_name,
            normalize=True,
            overwrite=overwrite,
        )
        arr_read = load_nifti_file(file_path=nifti_file_path)
        assert is_equal_np(arr2 if overwrite else arr1, arr_read)


def test_calculate_metrics():
    """
    Test calculate_metrics by checking output keys.
    Assuming the metrics functions are correct.
    """

    batch_size = 2
    fixed_image_shape = (4, 4, 4)  # (f_dim1, f_dim2, f_dim3)

    fixed_image = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    fixed_label = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    pred_fixed_image = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    pred_fixed_label = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    fixed_grid_ref = tf.random.uniform(shape=(1,) + fixed_image_shape + (3,))
    sample_index = 0

    # labeled and have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=fixed_label,
        pred_fixed_image=pred_fixed_image,
        pred_fixed_label=pred_fixed_label,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is not None
    assert got["label_binary_dice"] is not None
    assert got["label_tre"] is not None
    assert sorted(list(got.keys())) == sorted(
        ["image_ssd", "label_binary_dice", "label_tre"]
    )

    # labeled and do not have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=fixed_label,
        pred_fixed_image=None,
        pred_fixed_label=pred_fixed_label,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is None
    assert got["label_binary_dice"] is not None
    assert got["label_tre"] is not None

    # unlabeled and have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=None,
        pred_fixed_image=pred_fixed_image,
        pred_fixed_label=None,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is not None
    assert got["label_binary_dice"] is None
    assert got["label_tre"] is None

    # unlabeled and do not have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=None,
        pred_fixed_image=None,
        pred_fixed_label=None,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is None
    assert got["label_binary_dice"] is None
    assert got["label_tre"] is None


def test_save_metric_dict():
    """
    Test save_metric_dict by checking output files.
    """

    save_dir = "logs/test_save_metric_dict"
    metrics = [
        dict(image_ssd=0.1, label_dice=0.8, pair_index=[0], label_index=0),
        dict(image_ssd=0.2, label_dice=0.7, pair_index=[1], label_index=1),
        dict(image_ssd=0.3, label_dice=0.6, pair_index=[2], label_index=0),
    ]
    save_metric_dict(save_dir=save_dir, metrics=metrics)
    assert len([x for x in os.listdir(save_dir) if x.endswith(".csv")]) == 3
    shutil.rmtree(save_dir)
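
# Usage sketch, assuming this module lives under test/unit/ (consistent with the
# test.unit.util import above) and that pytest is run from the repository root so
# the relative "config/unpaired_labeled_ddf.yaml" and "logs/" paths resolve:
#
#     pytest test/unit/test_util.py -v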