Passed
Pull Request — main (#675)
by Yunguan
03:11
created

paired_mrus_prostate.demo_data   A

Complexity

Total Complexity 0

Size/Duplication

Total Lines 81
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 0
eloc 56
dl 0
loc 81
rs 10
c 0
b 0
f 0
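"""
Download the demo data for the paired MR-ultrasound prostate demo,
split it into partition folders, and download the pre-trained model.
"""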
4
import os
5
import shutil
6
import zipfile
7
8
from tensorflow.keras.utils import get_file
9
10
# --- Download and unpack the MR / ultrasound example dataset ---------------
# Every relative path used below is resolved against the demo folder, so the
# script must be launched from a directory where PROJECT_DIR exists.
PROJECT_DIR = "demos/paired_mrus_prostate"
os.chdir(PROJECT_DIR)

DATA_PATH = "dataset"  # final location of the data inside the demo folder
# Base name of the downloaded archive; the code below also expects the
# archive to unpack into a folder with this name (it is renamed afterwards).
ZIP_PATH = "example-data-mrus"
ORIGIN = "https://github.com/yipenghu/example-data/archive/mrus.zip"

zip_file = ZIP_PATH + ".zip"
# An absolute file name is passed so that the archive is written next to the
# demo instead of the default Keras cache (see the Keras get_file docs).
get_file(os.path.abspath(zip_file), ORIGIN)
with zipfile.ZipFile(zip_file, "r") as zf:
    zf.extractall()  # unpacks into the current (demo) directory

# Replace any dataset left over from a previous run with the fresh copy.
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH)
os.rename(ZIP_PATH, DATA_PATH)
os.remove(zip_file)  # the archive is no longer needed once unpacked

print("\nMR and ultrasound data downloaded: %s." % os.path.abspath(DATA_PATH))
28
29
# --- Split the data into num_part partitions ---------------------------------
num_part = 11

data_types = ["moving_images", "moving_labels", "fixed_images", "fixed_labels"]
# One sorted list of file names per data type, in the same order as data_types.
filenames = [sorted(os.listdir(os.path.join(DATA_PATH, fn))) for fn in data_types]
num_files = [len(x) for x in filenames]
# Every data type must hold the same number of files, otherwise the
# index-based distribution below would mix up or miss files.
if len(set(num_files)) != 1:
    raise ValueError(
        "Number of data are not the same between moving/fixed/images/labels. "
        "Please run this download script again."
    )
num_data = num_files[0]

for idx in range(num_part):  # create partition folders
    for fn in data_types:
        # makedirs also creates the intermediate "partNN" folder itself.
        os.makedirs(os.path.join(DATA_PATH, "part%02d" % idx, fn))

for idx in range(num_data):  # move all files to part folders
    # Round-robin assignment: the idx-th file of every data type lands in the
    # same partition. NOTE(review): this presumably keeps matching
    # image/label files together, assuming sorted names align across types.
    part_name = "part%02d" % (idx % num_part)
    for fn, names in zip(data_types, filenames):
        os.rename(
            os.path.join(DATA_PATH, fn, names[idx]),
            os.path.join(DATA_PATH, part_name, fn, names[idx]),
        )

for fn in data_types:  # remove the old, now empty, type folders
    shutil.rmtree(os.path.join(DATA_PATH, fn))

print("All data are partitioned into %d folders." % num_part)
63
64
# --- Download the pre-trained model -------------------------------------------
# The checkpoint is stored in a "pretrained" folder inside the dataset folder.
MODEL_PATH = os.path.join(DATA_PATH, "pretrained")
# Start from an empty folder so files from an earlier download cannot linger.
if os.path.exists(MODEL_PATH):
    shutil.rmtree(MODEL_PATH)
os.mkdir(MODEL_PATH)

# ZIP_PATH and ORIGIN are re-bound here for the model archive.
ZIP_PATH = "checkpoint"
ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/paired_mrus_prostate/20210110.zip"

zip_file = os.path.join(MODEL_PATH, ZIP_PATH + ".zip")
# An absolute file name is passed so that the archive is written into
# MODEL_PATH instead of the default Keras cache (see the Keras get_file docs).
get_file(os.path.abspath(zip_file), ORIGIN)
with zipfile.ZipFile(zip_file, "r") as zf:
    zf.extractall(path=MODEL_PATH)
os.remove(zip_file)  # keep only the unpacked checkpoint files

print(
    "Pre-trained model is downloaded and unzipped in %s." % os.path.abspath(MODEL_PATH)
)
82