Passed
Push — master ( ca833f...7f8818 )
by Fernando
01:06
created

torchio.download._is_zip()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""
2
Most of this code is from torchvision.
3
I will remove all this once verbosity is reduced.
4
More info: https://github.com/pytorch/vision/issues/2830
5
"""
6
7
import os
8
import gzip
9
import urllib
10
import zipfile
11
import hashlib
12
import tarfile
13
from typing import Optional
14
15
from torch.hub import tqdm
16
17
from .typing import TypePath
18
19
20
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    The file is read in pieces of at most ``chunk_size`` bytes, so it never
    has to be held in memory in one go.
    """
    hasher = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            # An empty read marks the end of the file
            if block == b'':
                break
            hasher.update(block)
    return hasher.hexdigest()
26
27
28
def check_md5(fpath, md5, **kwargs):
    """Tell whether the MD5 digest of the file at ``fpath`` equals ``md5``.

    Any extra keyword arguments are forwarded to ``calculate_md5``.
    """
    actual_digest = calculate_md5(fpath, **kwargs)
    return md5 == actual_digest
30
31
32
def check_integrity(fpath, md5=None):
    """Check that ``fpath`` is an existing file, optionally matching ``md5``.

    Returns ``False`` when ``fpath`` is not a regular file. Otherwise returns
    ``True`` if no checksum was given, or the result of ``check_md5`` when
    one was.
    """
    if os.path.isfile(fpath):
        # Without a reference checksum, existence is all that can be verified
        return True if md5 is None else check_md5(fpath, md5)
    return False
38
39
40
def gen_bar_updater():
    """Create a progress callback backed by a new ``tqdm`` bar.

    The returned function takes ``(count, block_size, total_size)`` and is
    passed by ``download_url`` as the ``reporthook`` of
    ``urllib.request.urlretrieve``.
    """
    progress = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        # The total is only set once, the first time a truthy size is reported
        total_unknown = progress.total is None
        if total_unknown and total_size:
            progress.total = total_size
        downloaded = count * block_size
        # tqdm expects an increment, so subtract what has been shown already
        increment = downloaded - progress.n
        progress.update(increment)

    return bar_update
50
51
52
# Adapted from torchvision, removing print statements
53
def download_and_extract_archive(
        url: str,
        download_root: TypePath,
        extract_root: Optional[TypePath] = None,
        filename: Optional[TypePath] = None,
        md5: str = None,
        remove_finished: bool = False,
        ) -> None:
    """Download an archive with ``download_url`` and extract it.

    Args:
        url: URL to download the archive from.
        download_root: Directory to place the downloaded archive in
            (``~`` is expanded).
        extract_root: Directory to extract into. If ``None``,
            ``download_root`` is used.
        filename: Name to save the archive under. If empty or ``None``,
            the basename of the URL is used.
        md5: MD5 checksum passed on to ``download_url``.
        remove_finished: Passed on to ``extract_archive``, which deletes
            the archive after extraction when it is true.
    """
    download_root = os.path.expanduser(download_root)
    filename = filename or os.path.basename(url)
    extract_root = download_root if extract_root is None else extract_root
    download_url(url, download_root, filename, md5)
    extract_archive(
        os.path.join(download_root, filename),
        extract_root,
        remove_finished,
    )
69
70
71
def _is_tarxz(filename):
72
    return filename.endswith('.tar.xz')
73
74
75
def _is_tar(filename):
76
    return filename.endswith('.tar')
77
78
79
def _is_targz(filename):
80
    return filename.endswith('.tar.gz')
81
82
83
def _is_tgz(filename):
84
    return filename.endswith('.tgz')
85
86
87
def _is_gzip(filename):
88
    return filename.endswith('.gz') and not filename.endswith('.tar.gz')
89
90
91
def _is_zip(filename):
92
    return filename.endswith('.zip')
93
94
95
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract the archive at ``from_path`` into ``to_path``.

    The format is chosen from the file name suffix: ``.tar``, ``.tar.gz``,
    ``.tgz``, ``.tar.xz``, ``.gz`` (a single compressed file) or ``.zip``.

    Args:
        from_path: Path to the archive, as a string or ``os.PathLike``.
        to_path: Directory to extract into. If ``None``, the directory
            containing ``from_path`` is used.
        remove_finished: If true, delete the archive once it has been
            extracted.

    Raises:
        ValueError: If the suffix of ``from_path`` is not supported.
    """
    import shutil

    # The suffix helpers rely on ``str.endswith``, so normalise path objects
    from_path = os.fspath(from_path)
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        tar_mode = 'r'
    elif _is_targz(from_path) or _is_tgz(from_path):
        tar_mode = 'r:gz'
    elif _is_tarxz(from_path):
        tar_mode = 'r:xz'
    else:
        tar_mode = None

    # NOTE(review): extractall() trusts the member names stored in the
    # archive; only use this on archives from trusted sources.
    if tar_mode is not None:
        with tarfile.open(from_path, tar_mode) as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        # A bare .gz holds one file: write it next to the others, dropping
        # only the trailing ".gz" from the name
        stem = os.path.splitext(os.path.basename(from_path))[0]
        out_path = os.path.join(to_path, stem)
        # Open the source first so a bad archive does not leave an empty
        # output file behind, and stream instead of reading it all at once
        with gzip.GzipFile(from_path) as zip_f, open(out_path, 'wb') as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)
120
121
122
# Adapted from torchvision, removing print statements
123
def download_url(
        url: str,
        root: TypePath,
        filename: Optional[TypePath] = None,
        md5: Optional[str] = None,
        ) -> None:
    """Download a file from a url and place it in root.

    If a file that passes ``check_integrity`` is already present at the
    destination, nothing is downloaded. If an ``https`` download fails, it
    is retried once over plain ``http``.

    Args:
        url: URL to download file from
        root: Directory to place downloaded file in
        filename: Name to save the file under.
            If ``None``, use the basename of the URL
        md5: MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: If, after downloading, the file is missing or does
            not match ``md5``.
    """
    # ``import urllib`` alone does not guarantee that these submodules are
    # loaded, so import them explicitly
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    # Nothing to do if a valid copy is already present locally
    if check_integrity(fpath, md5):
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, OSError):
        https_scheme = 'https:'
        if not url.startswith(https_scheme):
            # Bare raise keeps the original traceback
            raise
        # Swap only the scheme; 'https:' appearing later in the URL (for
        # example inside a query string) must stay untouched
        url = 'http:' + url[len(https_scheme):]
        message = (
            'Failed download. Trying https -> http instead.'
            ' Downloading ' + url + ' to ' + fpath
        )
        print(message)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')
169