torchio.download.extract_archive()   F
last analyzed

Complexity

Conditions 14

Size

Total Lines 26
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 14
eloc 23
nop 3
dl 0
loc 26
rs 3.6
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like torchio.download.extract_archive() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
Most of this code is from torchvision.
3
I will remove all this once verbosity is reduced.
4
More info: https://github.com/pytorch/vision/issues/2830
5
"""
6
7
import os
8
import gzip
9
import urllib
10
import zipfile
11
import hashlib
12
import tarfile
13
from typing import Optional
14
15
from torch.hub import tqdm
16
17
from .typing import TypePath
18
19
20
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    The file is opened in binary mode and consumed ``chunk_size`` bytes at a
    time, so its whole content is never held in memory at once.

    Args:
        fpath: Path of the file to hash.
        chunk_size: Number of bytes requested per read.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            # An empty bytes object signals the end of the file.
            if block == b'':
                break
            digest.update(block)
    return digest.hexdigest()
26
27
28
def check_md5(fpath, md5, **kwargs):
    """Tell whether the file at ``fpath`` has the MD5 digest ``md5``.

    Args:
        fpath: Path of the file to hash.
        md5: Expected hexadecimal digest.
        **kwargs: Forwarded unchanged to ``calculate_md5``.
    """
    actual = calculate_md5(fpath, **kwargs)
    return md5 == actual
30
31
32
def check_integrity(fpath, md5=None):
    """Tell whether ``fpath`` is an existing file with the expected checksum.

    Args:
        fpath: Path to check.
        md5: Expected MD5 digest. If ``None``, only the existence of a
            regular file at ``fpath`` is checked.

    Returns:
        ``False`` if ``fpath`` is not a file; otherwise ``True`` when no
        digest is given, or the result of ``check_md5`` when one is.
    """
    exists = os.path.isfile(fpath)
    if exists and md5 is not None:
        return check_md5(fpath, md5)
    # Either the file is missing (False) or no checksum was requested (True).
    return exists
38
39
40
def gen_bar_updater():
    """Create a progress-bar callback usable as a ``urlretrieve`` reporthook.

    A new ``tqdm`` bar with unknown total is created on every call of this
    factory. The returned function takes ``(count, block_size, total_size)``,
    sets the bar total the first time a truthy ``total_size`` is seen, and
    advances the bar so that its position equals ``count * block_size``.
    """
    progress = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        total_unknown = progress.total is None
        if total_unknown and total_size:
            progress.total = total_size
        downloaded = count * block_size
        # ``update`` is incremental, so pass only the newly received amount.
        progress.update(downloaded - progress.n)

    return bar_update
50
51
52
# Adapted from torchvision, removing print statements
53
def download_and_extract_archive(
        url: str,
        download_root: TypePath,
        extract_root: Optional[TypePath] = None,
        filename: Optional[TypePath] = None,
        md5: str = None,
        remove_finished: bool = False,
        ) -> None:
    """Download the archive at ``url`` and extract it.

    Args:
        url: URL of the archive.
        download_root: Directory in which the archive is saved. A leading
            ``~`` is expanded.
        extract_root: Directory passed to ``extract_archive`` as destination.
            If ``None``, the expanded ``download_root`` is used.
        filename: Name to save the archive under. If empty or ``None``,
            the basename of ``url`` is used.
        md5: MD5 checksum forwarded to ``download_url``.
        remove_finished: Forwarded to ``extract_archive``.
    """
    download_root = os.path.expanduser(download_root)
    filename = filename or os.path.basename(url)
    extract_root = download_root if extract_root is None else extract_root
    download_url(url, download_root, filename, md5)
    extract_archive(
        os.path.join(download_root, filename),
        extract_root,
        remove_finished,
    )
69
70
71
def _is_tarxz(filename):
72
    return filename.endswith('.tar.xz')
73
74
75
def _is_tar(filename):
76
    return filename.endswith('.tar')
77
78
79
def _is_targz(filename):
80
    return filename.endswith('.tar.gz')
81
82
83
def _is_tgz(filename):
84
    return filename.endswith('.tgz')
85
86
87
def _is_gzip(filename):
88
    return filename.endswith('.gz') and not filename.endswith('.tar.gz')
89
90
91
def _is_zip(filename):
92
    return filename.endswith('.zip')
93
94
95
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract the archive at ``from_path`` into ``to_path``.

    The format is chosen from the file name suffix: ``.tar``, ``.tar.gz``,
    ``.tgz``, ``.tar.xz`` and ``.zip`` archives are unpacked into ``to_path``;
    a plain ``.gz`` file is decompressed into a single file inside ``to_path``
    named after the archive without its ``.gz`` suffix.

    Args:
        from_path: Path of the archive. ``os.PathLike`` objects are accepted.
        to_path: Destination directory. If ``None``, the directory that
            contains ``from_path`` is used.
        remove_finished: If ``True``, delete the archive after extraction.

    Raises:
        ValueError: If the suffix of ``from_path`` is not supported.
    """
    import shutil  # local import: only needed by the gzip branch

    # The suffix helpers rely on ``str.endswith``; normalize path-like input.
    from_path = os.fspath(from_path)
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        _extract_tar(from_path, to_path, 'r')
    elif _is_targz(from_path) or _is_tgz(from_path):
        _extract_tar(from_path, to_path, 'r:gz')
    elif _is_tarxz(from_path):
        _extract_tar(from_path, to_path, 'r:xz')
    elif _is_gzip(from_path):
        stem = os.path.splitext(os.path.basename(from_path))[0]
        # Tar and zip extraction create the destination directory; do the
        # same here. An empty ``to_path`` means the current directory.
        if to_path:
            os.makedirs(to_path, exist_ok=True)
        out_path = os.path.join(to_path, stem)
        with open(out_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
            # Stream in chunks instead of loading the whole payload in memory.
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)


def _extract_tar(from_path, to_path, mode):
    """Extract all members of the tar archive ``from_path`` into ``to_path``."""
    with tarfile.open(from_path, mode) as tar:
        if hasattr(tarfile, 'data_filter'):
            # Archives may come from the network: where the interpreter
            # supports extraction filters, reject members that would be
            # written outside ``to_path`` (absolute paths, ``..``, links).
            tar.extractall(path=to_path, filter='data')
        else:
            tar.extractall(path=to_path)
121
122
123
# Adapted from torchvision, removing print statements
124
def download_url(
        url: str,
        root: TypePath,
        filename: Optional[TypePath] = None,
        md5: Optional[str] = None,
        ) -> None:
    """Download a file from a url and place it in root.

    Nothing is downloaded if a file already exists at the target path and
    passes the integrity check.

    Args:
        url: URL to download file from
        root: Directory to place downloaded file in
        filename: Name to save the file under.
            If ``None``, use the basename of the URL
        md5: MD5 checksum of the download. If None, do not check

    Raises:
        urllib.error.URLError: If the download fails and ``url`` is not an
            ``https:`` URL, or if the ``http:`` retry fails as well.
        OSError: Same conditions as above, for other I/O failures.
        RuntimeError: If the file is missing or its checksum does not match
            after downloading.
    """
    # ``import urllib`` alone does not load these submodules; import them
    # explicitly so the function does not depend on other modules' imports.
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)
    # check if file is already present locally
    if check_integrity(fpath, md5):
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, OSError):
        https_scheme = 'https:'
        if not url.startswith(https_scheme):
            # Bare ``raise`` keeps the original traceback intact.
            raise
        # NOTE(review): retrying over plain http drops transport security;
        # without ``md5`` the downloaded content is not authenticated.
        # Rewrite only the scheme, not other 'https:' occurrences in the URL.
        url = 'http:' + url[len(https_scheme):]
        message = (
            'Failed download. Trying https -> http instead.'
            ' Downloading ' + url + ' to ' + fpath
        )
        print(message)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')
170