Total Complexity | 0 |
Total Lines | 46 |
Duplicated Lines | 0 % |
Changes | 0 |
"""
Download the demo data.

Fetches the TRUS 3D dataset into ``dataset/`` and the pretrained model into
``dataset/pretrained/`` for the unpaired_us_prostate_cv demo. All paths are
relative to PROJECT_DIR, which is itself relative, so the script must be
launched from the directory that contains ``demos/``.
"""
import os
import shutil
import zipfile

from tensorflow.keras.utils import get_file

PROJECT_DIR = "demos/unpaired_us_prostate_cv"

# Dataset: a GitHub archive that unzips into DATA_REPO, then renamed to DATA_PATH.
DATA_PATH = "dataset"
DATA_REPO = "dataset_trus3d-master"
DATA_ZIP = "master.zip"
DATA_ORIGIN = "https://github.com/ucl-candi/dataset_trus3d/archive/master.zip"

# Pretrained model: unzipped directly into MODEL_PATH (inside DATA_PATH).
MODEL_PATH = os.path.join(DATA_PATH, "pretrained")
MODEL_ZIP = os.path.join(MODEL_PATH, "unpaired_us_prostate_cv_1.zip")
MODEL_ORIGIN = (
    "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/"
    "demo/unpaired_us_prostate_cv/20210110.zip"
)


def _download_and_extract(origin, zip_path, extract_dir="."):
    """Download ``origin`` to ``zip_path``, unzip into ``extract_dir``, delete the zip.

    Args:
        origin: URL of the zip archive.
        zip_path: local path (relative to the current directory) for the archive.
        extract_dir: directory the archive members are extracted into.
    """
    # get_file saves to an absolute fname as-is instead of the Keras cache dir.
    get_file(os.path.abspath(zip_path), origin)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(path=extract_dir)
    os.remove(zip_path)


os.chdir(PROJECT_DIR)

# Download the dataset and replace any previous copy of DATA_PATH.
_download_and_extract(DATA_ORIGIN, DATA_ZIP)
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH)
shutil.move(DATA_REPO, DATA_PATH)

print(f"TRUS 3d data downloaded: {os.path.abspath(DATA_PATH)}.")

# Download the pretrained models into a fresh MODEL_PATH.
if os.path.exists(MODEL_PATH):
    shutil.rmtree(MODEL_PATH)
os.mkdir(MODEL_PATH)

_download_and_extract(MODEL_ORIGIN, MODEL_ZIP, extract_dir=MODEL_PATH)

print(
    f"pretrained model is downloaded and unzipped in {os.path.abspath(MODEL_PATH)}."
)