Passed
Pull Request — main (#675)
by Yunguan
03:15
created

unpaired_ct_abdomen.demo_data   A

Complexity

Total Complexity 0

Size/Duplication

Total Lines 115
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 0
eloc 80
dl 0
loc 115
rs 10
c 0
b 0
f 0
1
"""Download and preprocess data."""
2
import os
3
import shutil
4
import zipfile
5
6
import nibabel as nib
7
import numpy as np
8
from tensorflow.keras.utils import get_file
9
10
# All relative paths below are resolved inside the demo folder.
PROJECT_DIR = "demos/unpaired_ct_abdomen"
os.chdir(PROJECT_DIR)

ORIGIN = "https://github.com/ucl-candi/datasets_deepreg_demo/archive/abdct.zip"
ZIP_PATH = "abdct.zip"
DATA_PATH = "dataset"

# Download the dataset archive to ZIP_PATH (get_file is given the absolute path).
get_file(os.path.abspath(ZIP_PATH), ORIGIN)

# The archive stays open at module level: the extraction step reads members
# from it with zf.open().
zf = zipfile.ZipFile(ZIP_PATH)
filenames_all = [fn for fn in zf.namelist() if fn.split(".")[-1] == "gz"]

# Every .gz file is identified by the integer before the first "." of its
# base name; an image and its label are expected to share that integer.
# Sorted so the list order is deterministic (a bare set has no defined order).
filenames_indices = sorted(
    {int(fn.split("/")[-1].split(".")[0]) for fn in filenames_all}
)
num_data = len(filenames_all) // 2
# Check indices by value (``!=``); the previous ``is not`` compared object
# identity and only worked by accident for small cached ints. Requiring exactly
# two files per unique index also rejects an odd number of files.
if len(filenames_all) != 2 * len(filenames_indices):
    raise ValueError("Images and labels are not in pairs.")

print("\nAbdominal CT data downloaded with %d image-label pairs." % num_data)
30
31
# Fractions of cases held out for validation and testing; whatever remains
# (after truncating each held-out count to an int) goes to training.
ratio_val, ratio_test = 0.1, 0.15
num_val, num_test = (int(num_data * ratio) for ratio in (ratio_val, ratio_test))
num_train = num_data - num_val - num_test

split_summary = (
    num_train,
    num_val,
    num_test,
    1 - ratio_val - ratio_test,
    ratio_val,
    ratio_test,
)
print(
    "Extracting data into %d-%d-%d for train-val-test (%0.2f-%0.2f-%0.2f)..."
    % split_summary
)
41
42
# Extract the archive into DATA_PATH/{train,val,test}/{images,labels},
# wiping any previous DATA_PATH first.
folders = [os.path.join(DATA_PATH, dn) for dn in ["train", "val", "test"]]
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH)
os.mkdir(DATA_PATH)
for folder in folders:
    os.mkdir(folder)
    os.mkdir(os.path.join(folder, "images"))
    os.mkdir(os.path.join(folder, "labels"))

# Position of each case index in ascending order. Sorting explicitly makes the
# train/val/test split independent of how filenames_indices happens to be
# ordered, and the dict replaces an O(n) list.index() call per file.
rank_of_index = {
    case_index: rank for rank, case_index in enumerate(sorted(filenames_indices))
}

for filename in filenames_all:
    # images or labels; anything outside those two archive folders is skipped
    if filename.startswith("datasets_deepreg_demo-abdct/dataset/images"):
        typename = "images"
    elif filename.startswith("datasets_deepreg_demo-abdct/dataset/labels"):
        typename = "labels"
    else:
        continue
    # train, val or test, decided by the rank of the file's case index
    basename = filename.split("/")[-1]
    idx = rank_of_index[int(basename.split(".")[0])]
    if idx < num_train:  # train
        fidx = 0
    elif idx < (num_train + num_val):  # val
        fidx = 1
    else:  # test
        fidx = 2
    filename_dst = os.path.join(folders[fidx], typename, basename)
    with zf.open(filename) as sf, open(filename_dst, "wb") as df:
        shutil.copyfileobj(sf, df)
    # Re-encode each label file in place as 13 binary channels stacked on
    # axis 3 (channel k is "label value == k + 1"), hard-coded to 13 classes
    # whether or not every class is present in the file.
    if typename == "labels":
        img = nib.load(filename_dst)
        # Read the voxel data once rather than once per class.
        label_values = np.asarray(img.dataobj)
        one_hot = np.stack([label_values == i for i in range(1, 14)], axis=3)
        img1 = nib.Nifti1Image(one_hot.astype(np.int8), img.affine)
        img1.to_filename(filename_dst)

# Release the archive handle before deleting the file; an open handle can
# prevent deletion on some platforms (e.g. Windows).
zf.close()
os.remove(ZIP_PATH)

print("Done. \n")
81
82
# Fetch the three pretrained DeepReg model-zoo checkpoints
# (unpaired_ct_abdomen-{unsup,weakly,comb}.zip from
# https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/) and unzip each
# one into its own folder: dataset/pretrained/{unsup,weakly,comb}.
MODEL_PATH = os.path.join(DATA_PATH, "pretrained")
if os.path.exists(MODEL_PATH):
    shutil.rmtree(MODEL_PATH)
os.mkdir(MODEL_PATH)

model_names = ["unsup", "weakly", "comb"]
for variant in model_names:
    target_dir = os.path.join(MODEL_PATH, variant)
    os.mkdir(target_dir)

    archive_name = "unpaired_ct_abdomen-%s.zip" % variant
    archive_url = (
        "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/" + archive_name
    )
    archive_path = os.path.join(target_dir, archive_name)

    # Download next to where it will be unpacked, extract, then drop the zip.
    get_file(os.path.abspath(archive_path), archive_url)
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(path=target_dir)
    os.remove(archive_path)

print(
    "Pretrained models are downloaded and unzipped in individual folders at %s."
    % os.path.abspath(MODEL_PATH)
)
116