Completed
Push — main ( 0c57ec...f6b5bf )
by Yunguan
18s queued 13s
created

grouped_mask_prostate_longitudinal.demo_data   A

Complexity

Total Complexity 0

Size/Duplication

Total Lines 124
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 0
eloc 92
dl 0
loc 124
rs 10
c 0
b 0
f 0
1
"""
2
Download the demo data and sort them into train, val and test in h5 files
3
"""
4
import os
5
import shutil
6
import zipfile
7
8
import h5py
9
from scipy import ndimage
10
from tensorflow.keras.utils import get_file
11
12
# Demo folder; after the chdir every relative path below resolves against it.
# NOTE(review): PROJECT_DIR is itself relative, so the script presumably has to be
# launched from the repository root - confirm.
PROJECT_DIR = "demos/grouped_mask_prostate_longitudinal"
os.chdir(PROJECT_DIR)

# Where the zipped mask data comes from and where it is stored locally.
ORIGIN = "https://github.com/YipengHu/example-data/raw/master/longi-masks/data.zip"
ZIP_FILE = "data"
DATA_PATH = "dataset"

# Always start from an empty dataset folder.
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH)
os.mkdir(DATA_PATH)

# Download the archive into the dataset folder, unpack it in place, then
# delete the archive so that only its content is left.
zip_file = os.path.join(DATA_PATH, "%s.zip" % ZIP_FILE)
get_file(fname=os.path.abspath(zip_file), origin=ORIGIN)
with zipfile.ZipFile(zip_file, "r") as zf:
    zf.extractall(path=DATA_PATH)
os.remove(zip_file)

print("\nMask data downloaded: %s." % os.path.abspath(DATA_PATH))
30
31
## now read the data and convert to train/val/test
# Fractions of the groups (not of the single datasets) that go to the
# validation and test sets; the remaining groups are used for training.
ratio_val = 0.1
ratio_test = 0.2

data_filename = os.path.join(DATA_PATH, ZIP_FILE + ".h5")
# One output folder per subset, in the order train (0), val (1), test (2).
folders = [
    os.path.join(DATA_PATH, "train"),
    os.path.join(DATA_PATH, "val"),
    os.path.join(DATA_PATH, "test"),
]
fid_image, fid_label = [], []

# The context manager closes the source file even if the conversion fails.
with h5py.File(data_filename, "r") as fid_data:
    # Collect the group / observation ids of every top-level entry named
    # "group-<group id>-<observation id>"; any other entry is ignored.
    ids_group, ids_ob = [], []
    for f in fid_data:
        name_parts = fid_data[f].name.split("-")
        if len(name_parts) != 3 or name_parts[0] != "/group":
            continue
        ids_group.append(int(name_parts[1]))
        ids_ob.append(int(name_parts[2]))
    # Count only the entries that are actually converted, so that the number
    # reported below always matches the length of the id lists.
    num_data = len(ids_group)

    # NOTE(review): the subset of a group depends on its position in this
    # list, i.e. on set iteration order. Sorting would make the order
    # explicit, but confirm first that it does not change the split the
    # pre-trained model relies on.
    ids_group_unique = list(set(ids_group))
    num_group = len(ids_group_unique)
    num_val = int(num_group * ratio_val)
    num_test = int(num_group * ratio_test)
    num_train = num_group - num_val - num_test

    print("Found %d data in %d groups." % (num_data, num_group))
    print(
        "Dividing into %d-%d-%d for train-val-test (%0.2f-%0.2f-%0.2f)..."
        % (
            num_train,
            num_val,
            num_test,
            1 - ratio_val - ratio_test,
            ratio_val,
            ratio_test,
        )
    )

    # Position of each group id in ids_group_unique, built once so the loop
    # below does not have to search the list for every dataset.
    pos_of_group = {id_group: pos for pos, id_group in enumerate(ids_group_unique)}

    # write
    try:
        for fn in folders:
            os.mkdir(fn)
            fid_label.append(h5py.File(os.path.join(fn, "labels.h5"), "w"))
            fid_image.append(h5py.File(os.path.join(fn, "images.h5"), "w"))

        for id_group, id_ob in zip(ids_group, ids_ob):
            dataset_name = "group-%d-%d" % (id_group, id_ob)
            # All observations of a group share pos_group and therefore end
            # up in the same subset.
            pos_group = pos_of_group[id_group]
            if pos_group < num_train:  # train
                idf = 0
            elif pos_group < (num_train + num_val):  # val
                idf = 1
            else:  # test
                idf = 2
            data = fid_data[dataset_name]
            # The label is the data as stored in the downloaded file.
            fid_label[idf].create_dataset(
                dataset_name, shape=data.shape, dtype=data.dtype, data=data
            )
            fid_label[idf].flush()
            # The image is a float32 copy of the data smoothed with a gaussian.
            image = ndimage.gaussian_filter(data, sigma=3, output="float32")
            fid_image[idf].create_dataset(
                dataset_name, shape=image.shape, dtype=image.dtype, data=image
            )
            fid_image[idf].flush()
    finally:
        # close every output file that was opened, also when a step above failed
        for fid in fid_label + fid_image:
            fid.close()

# The downloaded file is only removed after a successful conversion.
os.remove(data_filename)

print("Done. \n")
106
107
## now download the pretrained model
# Source archive and local target folder of the pre-trained model.
ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/grouped_mask_prostate_longitudinal/20210110.zip"
ZIP_PATH = "grouped_mask_prostate_longitudinal_1"
MODEL_PATH = os.path.join(DATA_PATH, "pretrained")

# Recreate the model folder from scratch.
if os.path.exists(MODEL_PATH):
    shutil.rmtree(MODEL_PATH)
os.mkdir(MODEL_PATH)

# Download the archive into the model folder, unpack it in place, then
# delete the archive.
zip_file = os.path.join(MODEL_PATH, "%s.zip" % ZIP_PATH)
get_file(fname=os.path.abspath(zip_file), origin=ORIGIN)
with zipfile.ZipFile(zip_file, "r") as zf:
    zf.extractall(path=MODEL_PATH)
os.remove(zip_file)

print(
    "pretrained model is downloaded and unzipped in %s." % os.path.abspath(MODEL_PATH)
)
125