Passed
Pull Request — main (#675)
by Yunguan
03:11
created

grouped_mr_heart.demo_data   A

Complexity

Total Complexity 0

Size/Duplication

Total Lines 140
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 0
eloc 99
dl 0
loc 140
rs 10
c 0
b 0
f 0
1
import os
2
import shutil
3
import zipfile
4
5
import nibabel as nib
6
from scipy import ndimage
7
from tensorflow.keras.utils import get_file
8
9
# Pixel dimension every image is resampled to: the pre-processing step
# scales each axis by (original spacing / output_pixdim).
output_pixdim = 1.5

# All relative paths used below (ZIP_PATH, DATA_PATH) are resolved against
# the demo folder, so the working directory is switched first.
PROJECT_DIR = "demos/grouped_mr_heart"
os.chdir(PROJECT_DIR)

# Remote archive holding the demo images, and the local names it is
# downloaded to / extracted into.
ORIGIN = "https://github.com/ucl-candi/datasets_deepreg_demo/archive/myops.zip"
ZIP_PATH = "myops.zip"
DATA_PATH = "dataset"

# Download the archive; an absolute file name is passed so the file ends up
# at ZIP_PATH inside the demo folder.
get_file(fname=os.path.abspath(ZIP_PATH), origin=ORIGIN)
19
20
# Extract every image of the downloaded archive into per-subject folders and
# resample it to the common pixel dimension `output_pixdim`.
#
# The archive is opened in a context manager so the handle is guaranteed to be
# closed (also when the consistency check raises) before the zip file is
# deleted at the end of this stage.
with zipfile.ZipFile(ZIP_PATH) as zf:
    # only the gzip-compressed entries are images
    filenames_all = [fn for fn in zf.namelist() if fn.split(".")[-1] == "gz"]
    # the script assumes three image files per subject
    num_data = len(filenames_all) // 3

    # check indices: file names look like "<subject index>_<sequence name>";
    # sorting makes the subject -> group numbering deterministic
    filenames_indices = sorted(
        {int(fn.split("/")[-1].split("_")[0]) for fn in filenames_all}
    )
    if len(filenames_indices) != num_data:
        raise ValueError("Missing data in image groups.")

    # start from an empty dataset folder
    if os.path.exists(DATA_PATH):
        shutil.rmtree(DATA_PATH)
    os.mkdir(DATA_PATH)

    print(
        "\nCMR data from %d subjects downloaded, being extracted and resampled..."
        % num_data
    )
    print("This may take a few minutes...")

    # extract into image groups
    images_path = os.path.join(DATA_PATH, "images")
    os.mkdir(images_path)

    # subject index -> consecutive group number, looked up once per file
    group_of_subject = {
        subject: position for position, subject in enumerate(filenames_indices)
    }

    for filename in filenames_all:
        # groups, here same as subjects
        idx, seq_name = filename.split("/")[-1].split("_")
        idx_group = group_of_subject[int(idx)]
        group_path = os.path.join(images_path, "subject" + "%03d" % idx_group)
        os.makedirs(group_path, exist_ok=True)

        # extract image
        img_path = os.path.join(group_path, seq_name)
        with zf.open(filename) as sf, open(img_path, "wb") as df:
            shutil.copyfileobj(sf, df)

        # pre-processing: zoom each axis by (original spacing / output_pixdim)
        # and overwrite the extracted file with the resampled image
        img = nib.load(img_path)
        img = nib.Nifti1Image(
            ndimage.zoom(
                img.dataobj, [pd / output_pixdim for pd in img.header.get_zooms()]
            ),
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
        )  # to a generic affine after resampling
        img.to_filename(img_path)

# the archive is closed at this point, so it can be removed safely
os.remove(ZIP_PATH)

print("Done")
69
70
# Split the extracted subject folders into train / val / test sets.
# The validation and test sizes are the truncated fractions of the number of
# subjects; every remaining subject goes to the training set.
ratio_val = 0.05
ratio_test = 0.10
num_val = int(num_data * ratio_val)
num_test = int(num_data * ratio_test)
num_train = num_data - num_val - num_test

print(
    "Splitting data into %d-%d-%d for train-val-test (%0.2f-%0.2f-%0.2f)..."
    % (
        num_train,
        num_val,
        num_test,
        1 - ratio_val - ratio_test,
        ratio_val,
        ratio_test,
    )
)

# move images to respective folders
folders = [os.path.join(DATA_PATH, dn) for dn in ["train", "val", "test"]]

for fn in folders:
    os.mkdir(fn)
    os.mkdir(os.path.join(fn, "images"))

# os.listdir returns entries in arbitrary order; sort the group names so the
# same subjects land in the same split on every run and every platform
group_names = sorted(os.listdir(images_path))

# exclusive upper bounds of the train and val index ranges; anything at or
# beyond the last bound belongs to test
split_bounds = (num_train, num_train + num_val)
for g_idx, group in enumerate(group_names):
    # number of bounds already passed == index into `folders`
    fidx = sum(g_idx >= bound for bound in split_bounds)
    shutil.move(
        os.path.join(images_path, group), os.path.join(folders[fidx], "images")
    )

# every group has been moved, so the staging folder is empty by now
os.rmdir(images_path)

print("Done. \n")
108
109
# Download the pretrained models
# The checkpoint archive is published as several numbered parts; they are
# downloaded one by one, concatenated into a single zip file, extracted into
# MODEL_PATH and then all intermediate files are removed.
MODEL_PATH = os.path.join(DATA_PATH, "pretrained")
if os.path.exists(MODEL_PATH):
    shutil.rmtree(MODEL_PATH)
os.mkdir(MODEL_PATH)

num_zipfiles = 11
zip_filepath = os.path.abspath(os.path.join(MODEL_PATH, "checkpoint.zip"))
zip_file_parts = [zip_filepath + ".%02d" % idx for idx in range(num_zipfiles)]
# kept under its own name so the dataset ORIGIN constant is not overwritten
part_url_template = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/grouped_mr_heart/20210110/part.%02d"
for zip_file_idx, zip_file in enumerate(zip_file_parts):
    get_file(zip_file, part_url_template % zip_file_idx)

# combine all the files then extract
# "wb" (not append) so a stale file can never be extended, and the parts are
# streamed instead of being read into memory as a whole
with open(zip_filepath, "wb") as f:
    for zip_file in zip_file_parts:
        with open(zip_file, "rb") as z:
            shutil.copyfileobj(z, f)
with zipfile.ZipFile(zip_filepath, "r") as zf:
    zf.extractall(path=MODEL_PATH)

# remove zip files
for zip_file in zip_file_parts:
    os.remove(zip_file)
os.remove(zip_filepath)

print(
    "pretrained model is downloaded and unzipped in %s." % os.path.abspath(MODEL_PATH)
)
141