torchio.download.extract_archive()   F
last analyzed

Complexity

Conditions 14

Size

Total Lines 26
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 14
eloc 23
nop 3
dl 0
loc 26
rs 3.6
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex functions like torchio.download.extract_archive() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods (or, within a function, branches and helper calls) that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Most of this code is from torchvision.
2
3
I will remove all this once verbosity is reduced. More info:
4
https://github.com/pytorch/vision/issues/2830
5
"""
6
7
from __future__ import annotations
8
9
import gzip
10
import hashlib
11
import os
12
import tarfile
13
import urllib
14
import zipfile
15
16
from torch.hub import tqdm
17
18
from .types import TypePath
19
20
21
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    Args:
        fpath: Path of the file to hash.
        chunk_size: Number of bytes read per iteration, so the file is
            never held in memory all at once.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            # An empty read marks the end of the file
            if block == b'':
                break
            digest.update(block)
    return digest.hexdigest()
27
28
29
def check_md5(fpath, md5, **kwargs):
    """Tell whether the file at ``fpath`` hashes to the expected ``md5``.

    Args:
        fpath: Path of the file to hash.
        md5: Expected hexadecimal MD5 digest.
        **kwargs: Forwarded to :func:`calculate_md5`.
    """
    computed = calculate_md5(fpath, **kwargs)
    return md5 == computed
31
32
33
def check_integrity(fpath, md5=None):
    """Tell whether ``fpath`` is an existing file matching ``md5``.

    Args:
        fpath: Path of the file to check.
        md5: Expected MD5 digest. If ``None``, only the existence of the
            file is checked.
    """
    if os.path.isfile(fpath):
        # Without a reference checksum, existence is all we can verify
        return True if md5 is None else check_md5(fpath, md5)
    return False
39
40
41
def gen_bar_updater():
    """Create a progress bar and return a callback that advances it.

    The returned callable takes ``(count, block_size, total_size)`` and
    moves the bar to ``count * block_size``. It is passed as the
    ``reporthook`` of ``urllib.request.urlretrieve`` by ``download_url``.
    """
    progress = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        total_unknown = progress.total is None
        if total_unknown and total_size:
            # The total only becomes known once the first report arrives
            progress.total = total_size
        downloaded = count * block_size
        # ``update`` expects an increment, not an absolute position
        increment = downloaded - progress.n
        progress.update(increment)

    return bar_update
51
52
53
# Adapted from torchvision, removing print statements
54
def download_and_extract_archive(
    url: str,
    download_root: TypePath,
    extract_root: TypePath | None = None,
    filename: TypePath | None = None,
    md5: str | None = None,
    remove_finished: bool = False,
) -> None:
    """Download an archive and extract it.

    Args:
        url: URL to download the archive from.
        download_root: Directory to place the downloaded archive in.
            A leading ``~`` is expanded.
        extract_root: Directory to extract the archive into.
            If ``None``, ``download_root`` is used.
        filename: Name to save the archive under.
            If ``None`` or empty, use the basename of the URL.
        md5: MD5 checksum of the download. If ``None``, do not check.
        remove_finished: Forwarded to :func:`extract_archive`.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive_path = os.path.join(download_root, filename)
    extract_archive(archive_path, extract_root, remove_finished)
70
71
72
def _is_tarxz(filename):
    """Tell whether ``filename`` ends with the ``.tar.xz`` extension."""
    extension = '.tar.xz'
    return filename.endswith(extension)
74
75
76
def _is_tar(filename):
    """Tell whether ``filename`` ends with the ``.tar`` extension."""
    extension = '.tar'
    return filename.endswith(extension)
78
79
80
def _is_targz(filename):
    """Tell whether ``filename`` ends with the ``.tar.gz`` extension."""
    extension = '.tar.gz'
    return filename.endswith(extension)
82
83
84
def _is_tgz(filename):
    """Tell whether ``filename`` ends with the ``.tgz`` extension."""
    extension = '.tgz'
    return filename.endswith(extension)
86
87
88
def _is_gzip(filename):
    """Tell whether ``filename`` is a plain gzip file.

    Names ending in ``.tar.gz`` are excluded: they are gzipped tarballs,
    not single gzipped files.
    """
    if not filename.endswith('.gz'):
        return False
    is_tarball = filename.endswith('.tar.gz')
    return not is_tarball
90
91
92
def _is_zip(filename):
    """Tell whether ``filename`` ends with the ``.zip`` extension."""
    extension = '.zip'
    return filename.endswith(extension)
94
95
96
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract a tar, gzip or zip archive.

    The archive format is chosen from the extension of ``from_path``:
    ``.tar``, ``.tar.gz``, ``.tgz``, ``.tar.xz``, ``.gz`` or ``.zip``.

    Args:
        from_path: Path to the archive. Strings and ``os.PathLike`` objects
            (e.g. ``pathlib.Path``) are accepted.
        to_path: Directory to extract into. If ``None``, the directory
            containing ``from_path`` is used.
        remove_finished: If ``True``, delete the archive after a
            successful extraction.

    Raises:
        ValueError: If the extension of ``from_path`` is not supported.
    """
    # The extension checks rely on ``str.endswith``, so path-like objects
    # must be converted to their string form first.
    from_path = os.fspath(from_path)
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # All tar flavours share the same extraction code; only the mode differs
    if _is_tar(from_path):
        tar_mode = 'r'
    elif _is_targz(from_path) or _is_tgz(from_path):
        tar_mode = 'r:gz'
    elif _is_tarxz(from_path):
        tar_mode = 'r:xz'
    else:
        tar_mode = None

    # NOTE(review): tar and zip members are extracted without an explicit
    # extraction filter. If archives can come from untrusted sources,
    # consider passing ``filter='data'`` to ``extractall`` on Python
    # versions that support it.
    if tar_mode is not None:
        with tarfile.open(from_path, tar_mode) as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        # A plain ``.gz`` holds a single file named after the archive stem
        stem = os.path.splitext(os.path.basename(from_path))[0]
        out_path = os.path.join(to_path, stem)
        # Open the archive first so that an unreadable archive does not
        # leave an empty output file behind.
        with gzip.GzipFile(from_path) as zip_f, open(out_path, 'wb') as out_f:
            # Stream in chunks instead of holding the whole decompressed
            # payload in memory.
            for chunk in iter(lambda: zip_f.read(1024 * 1024), b''):
                out_f.write(chunk)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)
122
123
124
# Adapted from torchvision, removing print statements
125
def download_url(
    url: str,
    root: TypePath,
    filename: TypePath | None = None,
    md5: str | None = None,
) -> None:
    """Download a file from a url and place it in root.

    The download is skipped when the target file already exists and, if
    ``md5`` is given, matches it. If downloading from an ``https`` URL
    fails, the download is retried once over plain ``http``.

    Args:
        url: URL to download file from
        root: Directory to place downloaded file in
        filename: Name to save the file under.
            If ``None``, use the basename of the URL
        md5: MD5 checksum of the download. If None, do not check

    Raises:
        OSError: If the download fails (``urllib.error.URLError`` is a
            subclass) and no ``http`` fallback applies, or if the
            fallback fails as well.
        RuntimeError: If, after downloading, the file is missing or does
            not match ``md5``.
    """
    # ``import urllib`` alone does not guarantee that the ``request`` and
    # ``error`` submodules are loaded, so import them explicitly here.
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    # Nothing to do if a valid copy is already present locally
    if check_integrity(fpath, md5):
        return

    try:
        print(f'Downloading {url} to {fpath}')  # noqa: T201
        urllib.request.urlretrieve(
            url,
            fpath,
            reporthook=gen_bar_updater(),
        )
    except (urllib.error.URLError, OSError):
        secure_prefix = 'https:'
        if not url.startswith(secure_prefix):
            raise
        # NOTE(review): falling back to plain http gives up transport
        # security; pass ``md5`` when the integrity of the file matters.
        # Only the scheme is rewritten, never later occurrences of
        # ``https:`` inside the URL (e.g. in a query string).
        url = 'http:' + url[len(secure_prefix):]
        message = (
            'Failed download. Trying https -> http instead. Downloading '
            f'{url} to {fpath}'
        )
        print(message)  # noqa: T201
        urllib.request.urlretrieve(
            url,
            fpath,
            reporthook=gen_bar_updater(),
        )

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')
175