1
|
|
|
# coding=utf-8 |
2
|
|
|
|
3
|
|
|
""" |
4
|
|
|
Module to download additional data and resources that |
5
|
|
|
are not included in releases via command line interface. |
6
|
|
|
""" |
7
|
|
|
|
8
|
|
|
import argparse |
9
|
|
|
import os |
10
|
|
|
from io import BytesIO |
11
|
|
|
from urllib.request import urlopen |
12
|
|
|
from zipfile import ZipFile |
13
|
|
|
|
14
|
|
|
from deepreg import log |
15
|
|
|
|
16
|
|
|
# Module-level logger obtained through deepreg's `log.get`, keyed on this module's name.
logger = log.get(__name__)
17
|
|
|
|
18
|
|
|
|
19
|
|
|
def download(dirs, output_dir="./", branch="main"):
    """
    Downloads the files and directories from DeepReg into
    `output_dir`, keeping only `dirs`.

    The zip archive of `branch` is fetched from the DeepReg GitHub repository.
    Only the entries located under one of `dirs` are extracted. The archive's
    top-level folder is stripped, so an entry such as
    ``<top>/demos/README.md`` is written to ``<output_dir>/demos/README.md``.

    :param dirs: the list of directories to save, relative to the repository
        root (e.g. ``["config", "data", "demos"]``). A directory matches
        itself and its contents only, not siblings sharing a name prefix
        (``data`` does not select ``data_extra``).
    :param output_dir: directory which we use as the root to save output;
        it is created, together with any missing parents, if needed.
    :param branch: The name of the branch from which we download the zip.
    :return: void
    :raises ValueError: if the downloaded archive contains no entries.
    """
    output_dir = os.path.abspath(output_dir)  # Get the output directory.
    # makedirs also creates missing parents, and exist_ok avoids failing if the
    # directory appears between an existence check and its creation.
    os.makedirs(output_dir, exist_ok=True)

    logger.info("Will download folders: %s into %s.", dirs, output_dir)

    zip_url = f"https://github.com/DeepRegNet/DeepReg/archive/{branch}.zip"
    logger.info("Downloading archive from DeepReg repository %s.", zip_url)
    # Use the response as a context manager so the connection is closed
    # as soon as the payload has been read.
    with urlopen(zip_url) as response:  # Download the zip.
        payload = response.read()
    logger.info("Downloaded archive. Extracting files.")

    with ZipFile(BytesIO(payload)) as zf:
        pathnames = zf.namelist()
        if not pathnames:
            raise ValueError(f"Archive downloaded from {zip_url} is empty.")

        # Every entry of a GitHub branch archive lives under one top-level
        # folder; zip member names always use "/" as the separator.
        head = pathnames[0].split("/", 1)[0] + "/"

        # Match on "<head><dir>/" prefixes rather than substrings, so that a
        # requested folder does not also select siblings like "<dir>_extra/".
        # An empty/"/" entry selects the whole archive, as before.
        keep_prefixes = tuple(
            head + rel + "/" if rel else head
            for rel in (d.replace(os.sep, "/").strip("/") for d in dirs)
        )

        for pathname in pathnames:
            # Skip unrequested entries and the bare top-level folder itself,
            # which would map to an empty relative path.
            if pathname == head or not pathname.startswith(keep_prefixes):
                continue

            info = zf.getinfo(pathname)
            # Remove only the leading head directory from the file path
            # (str.replace would also rewrite any later occurrence).
            info.filename = pathname[len(head):]
            zf.extract(info, output_dir)

            logger.info("Downloaded %s", info.filename)
60
|
|
|
|
61
|
|
|
|
62
|
|
|
def main(args=None):
    """
    Command-line entry point that fetches DeepReg's extra resources.

    Recognised options:

    * ``--output_dir`` / ``-d``: destination root, default ``./``.
    * ``--branch`` / ``-b``: repository branch to fetch, default ``main``.

    It then downloads the ``config``, ``data`` and ``demos`` folders of that
    branch and logs a completion message.

    :param args: list of command-line tokens to parse; when None, argparse
        reads them from ``sys.argv[1:]``.
    """
    cli = argparse.ArgumentParser()

    # (flags, keyword options) for each CLI option, registered in order so
    # the generated help text lists them the same way.
    option_specs = (
        (
            ("--output_dir", "-d"),
            dict(
                dest="output_dir",
                default="./",
                help="All directories will be downloaded to the specified directory.",
            ),
        ),
        (
            ("--branch", "-b"),
            dict(
                dest="branch",
                default="main",
                help="The name of the branch to download.",
            ),
        ),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args(args)

    wanted_dirs = ["config", "data", "demos"]
    download(wanted_dirs, parsed.output_dir, parsed.branch)

    logger.info(
        "Download complete. "
        "Please refer to the DeepReg Quick Start guide for next steps."
    )
98
|
|
|
|
99
|
|
|
|
100
|
|
|
if __name__ == "__main__":
    # Allow running this module directly as a script; the call is excluded
    # from coverage measurement.
    main()  # pragma: no cover
102
|
|
|
|