import os
from datetime import datetime
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import tensorflow as tf

import deepreg.loss.label as label_loss
from deepreg import log
from deepreg.dataset.load import get_data_loader
from deepreg.dataset.loader.interface import DataLoader
from deepreg.dataset.loader.util import normalize_array

logger = log.get(__name__)


def build_dataset(
    dataset_config: dict,
    preprocess_config: dict,
    split: str,
    training: bool,
    repeat: bool,
) -> Tuple[Optional[DataLoader], Optional[tf.data.Dataset], Optional[int]]:
    """
    Prepare the dataset for training, validation, or testing.

    :param dataset_config: configuration for the dataset
    :param preprocess_config: configuration for preprocessing
    :param split: train, valid, or test
    :param training: if true, data augmentation and shuffling will be added
    :param repeat: if true, the dataset will be repeated;
        true for train/valid datasets during model.fit

    :return:
        - (data_loader_train, dataset_train, steps_per_epoch_train)
        - (data_loader_val, dataset_val, steps_per_epoch_valid)

    This function cannot be moved into deepreg/dataset/util.py,
    as DataLoader is needed to define the output type.
    """
    assert split in ["train", "valid", "test"]
    data_loader = get_data_loader(dataset_config, split)
    if data_loader is None:
        return None, None, None

    dataset = data_loader.get_dataset_and_preprocess(
        training=training, repeat=repeat, **preprocess_config
    )
    dataset_size = data_loader.num_samples
    steps_per_epoch = max(dataset_size // preprocess_config["batch_size"], 1)
    return data_loader, dataset, steps_per_epoch
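
# Example use of build_dataset, a minimal sketch only: the config access pattern
# below (config["dataset"], config["train"]["preprocess"]) assumes the usual
# parsed-YAML layout, and the variables `config` and `model` are hypothetical.
#
#   data_loader, dataset, steps_per_epoch = build_dataset(
#       dataset_config=config["dataset"],
#       preprocess_config=config["train"]["preprocess"],
#       split="train",
#       training=True,
#       repeat=True,
#   )
#   if dataset is not None:
#       model.fit(x=dataset, steps_per_epoch=steps_per_epoch, epochs=2)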


def build_log_dir(log_dir: str, exp_name: str) -> str:
    """
    Build a log directory for the experiment.

    :param log_dir: path of the log directory.
    :param exp_name: name of the experiment.
    :return: the path of directory to save logs.
    """
    log_dir = os.path.join(
        os.path.expanduser(log_dir),
        datetime.now().strftime("%Y%m%d-%H%M%S") if exp_name == "" else exp_name,
    )
    if os.path.exists(log_dir):
        logger.warning("Log directory %s exists already.", log_dir)
    else:
        os.makedirs(log_dir)
    return log_dir
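
# Example use of build_log_dir, purely illustrative (paths are made up): with an
# explicit experiment name the directory is "<log_dir>/<exp_name>"; with an empty
# name a timestamped sub-directory such as "logs/20240101-120000" is created.
#
#   log_folder = build_log_dir(log_dir="logs", exp_name="demo_registration")
#   # -> "logs/demo_registration", created on disk if it did not exist already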


def save_array(
    save_dir: str,
    arr: Union[np.ndarray, tf.Tensor],
    name: str,
    normalize: bool,
    save_nifti: bool = True,
    save_png: bool = True,
    overwrite: bool = True,
):
    """
    Save a 3D/4D array as a nifti volume and/or a stack of png slices.

    :param save_dir: path of the directory to save into
    :param arr: 3D or 4D array to be saved
    :param name: name of the array, e.g. image, label, etc.
    :param normalize: true if the array's values have to be normalized when saving pngs,
        false means the values are already within [0, 1].
    :param save_nifti: if true, the array will be saved in nifti format
    :param save_png: if true, the array will be saved as pngs
    :param overwrite: if false, the file will not be saved in case it already exists
    """
    if isinstance(arr, tf.Tensor):
        arr = arr.numpy()
    if len(arr.shape) not in [3, 4]:
        raise ValueError(f"arr must be 3d or 4d numpy array or tf tensor, got {arr}")
    is_4d = len(arr.shape) == 4
    if is_4d:
        # a 4D array must have exactly 3 channels in the last dimension
        if arr.shape[3] != 3:
            raise ValueError(
                f"4d arr must have 3 channels as last dimension, "
                f"got arr.shape = {arr.shape}"
            )

    # save in nifti format
    if save_nifti:
        nifti_file_path = os.path.join(save_dir, name + ".nii.gz")
        if overwrite or (not os.path.exists(nifti_file_path)):
            # save only if overwriting is allowed or the file doesn't exist yet
            os.makedirs(save_dir, exist_ok=True)
            # output with Nifti1Image can be loaded by
            # - https://www.slicer.org/
            # - http://www.itksnap.org/
            # - http://ric.uthscsa.edu/mango/
            # However, outputs with Nifti2Image couldn't be loaded
            nib.save(
                img=nib.Nifti1Image(arr, affine=np.eye(4)), filename=nifti_file_path
            )

    # save as pngs, one per depth slice
    if save_png:
        png_dir = os.path.join(save_dir, name)
        dir_existed = os.path.exists(png_dir)
        if normalize:
            # normalize arr such that it has only values between 0 and 1
            arr = normalize_array(arr=arr)
        for depth_index in range(arr.shape[2]):
            png_file_path = os.path.join(png_dir, f"depth{depth_index}_{name}.png")
            if overwrite or (not os.path.exists(png_file_path)):
                if not dir_existed:
                    os.makedirs(png_dir, exist_ok=True)
                plt.imsave(
                    fname=png_file_path,
                    arr=arr[:, :, depth_index, :] if is_4d else arr[:, :, depth_index],
                    vmin=0,
                    vmax=1,
                    cmap="PiYG" if is_4d else "gray",
                )
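
# Example use of save_array, a hedged sketch with made-up data: a 3D array named
# "image" is written as "<save_dir>/image.nii.gz" and one png per depth slice is
# written under "<save_dir>/image/".
#
#   image = np.random.rand(64, 64, 40).astype(np.float32)
#   save_array(save_dir="logs/pair_0", arr=image, name="image", normalize=True)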


def calculate_metrics(
    fixed_image: tf.Tensor,
    fixed_label: Optional[tf.Tensor],
    pred_fixed_image: Optional[tf.Tensor],
    pred_fixed_label: Optional[tf.Tensor],
    fixed_grid_ref: tf.Tensor,
    sample_index: int,
) -> dict:
    """
    Calculate image/label based metrics.

    :param fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param sample_index: index of the sample within the batch
    :return: dictionary of metrics
    """

    if pred_fixed_image is not None:
        y_true = fixed_image[sample_index : (sample_index + 1), :, :, :]
        y_pred = pred_fixed_image[sample_index : (sample_index + 1), :, :, :]
        y_true = tf.expand_dims(y_true, axis=4)
        y_pred = tf.expand_dims(y_pred, axis=4)
        ssd = label_loss.SumSquaredDifference()(y_true=y_true, y_pred=y_pred).numpy()
    else:
        ssd = None

    if fixed_label is not None and pred_fixed_label is not None:
        y_true = fixed_label[sample_index : (sample_index + 1), :, :, :]
        y_pred = pred_fixed_label[sample_index : (sample_index + 1), :, :, :]
        dice = label_loss.DiceScore(binary=True)(y_true=y_true, y_pred=y_pred).numpy()
        tre = label_loss.compute_centroid_distance(
            y_true=y_true, y_pred=y_pred, grid=fixed_grid_ref
        ).numpy()[0]
    else:
        dice = None
        tre = None

    return dict(image_ssd=ssd, label_binary_dice=dice, label_tre=tre)
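
# Example use of calculate_metrics, a sketch with illustrative shapes only; the
# input tensors are placeholders for the outputs of a registration model.
#
#   metrics = calculate_metrics(
#       fixed_image=fixed_image,            # (batch, d1, d2, d3)
#       fixed_label=fixed_label,            # (batch, d1, d2, d3) or None
#       pred_fixed_image=pred_fixed_image,  # warped moving image, same shape
#       pred_fixed_label=pred_fixed_label,  # warped moving label or None
#       fixed_grid_ref=fixed_grid_ref,      # (1, d1, d2, d3, 3)
#       sample_index=0,
#   )
#   # metrics == {"image_ssd": ..., "label_binary_dice": ..., "label_tre": ...},
#   # with None entries wherever the corresponding inputs were None.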


def save_metric_dict(save_dir: str, metrics: list):
    """
    Save metrics as csv files, overall and grouped per label.

    :param save_dir: directory to save outputs
    :param metrics: list of dicts, each dict must have keys pair_index and label_index
    """
    os.makedirs(name=save_dir, exist_ok=True)

    # build dataframe
    # columns are pair_index, label_index, and the metrics
    df = pd.DataFrame(metrics)

    # save the overall dataframe
    df.to_csv(os.path.join(save_dir, "metrics.csv"), index=False)

    # calculate mean/median/std per label
    df_per_label = df.drop(["pair_index"], axis=1)
    df_per_label = df_per_label.fillna(value=np.nan)
    df_per_label = df_per_label.groupby(["label_index"])
    df_per_label = pd.concat(
        [
            df_per_label.mean().add_suffix("_mean"),
            df_per_label.median().add_suffix("_median"),
            df_per_label.std().add_suffix("_std"),
        ],
        axis=1,
        sort=True,
    )
    df_per_label = df_per_label.reindex(
        sorted(df_per_label.columns), axis=1
    )  # sort columns
    df_per_label.to_csv(
        os.path.join(save_dir, "metrics_stats_per_label.csv"), index=True
    )

    # calculate overall mean/median/std
    df_all = df.drop(["pair_index", "label_index"], axis=1)
    df_all.describe().to_csv(
        os.path.join(save_dir, "metrics_stats_overall.csv"), index=True
    )
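
# Example use of save_metric_dict, a minimal sketch with placeholder values: each
# dict carries pair_index and label_index plus arbitrary metric columns.
#
#   save_metric_dict(
#       save_dir="logs/test",
#       metrics=[
#           {"pair_index": 0, "label_index": 0, "label_binary_dice": 0.9},
#           {"pair_index": 1, "label_index": 0, "label_binary_dice": 0.8},
#       ],
#   )
#   # writes metrics.csv, metrics_stats_per_label.csv and metrics_stats_overall.csv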