Passed
Push — master ( 0e3b0b...4497b8 )
by Fernando
01:17
created

torchio.utils.check_uint_to_int()   A

Complexity

Conditions 3

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 6
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import os
2
import ast
3
import gzip
4
import shutil
5
import tempfile
6
from pathlib import Path
7
from typing import Union, Iterable, Tuple, Any, Optional, List, Sequence
8
9
from torch.utils.data._utils.collate import default_collate
10
import numpy as np
11
import nibabel as nib
12
import SimpleITK as sitk
13
from tqdm import trange
14
15
from .constants import INTENSITY
16
from .typing import TypeNumber, TypePath
17
18
19
def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """Coerce a scalar or an iterable into a tuple.

    A non-iterable ``value`` is repeated ``length`` times. If ``value`` is
    iterable, ``length`` is ignored and ``tuple(value)`` is returned.

    Examples:
        to_tuple(1, length=1) -> (1,)
        to_tuple(1, length=3) -> (1, 1, 1)
        to_tuple((1,), length=1) -> (1,)
        to_tuple((1, 2), length=1) -> (1, 2)
        to_tuple([1, 2], length=3) -> (1, 2)

    Note that strings are iterable, so ``to_tuple('ab')`` is ``('a', 'b')``.
    """
    try:
        # iter() is only a probe: it raises TypeError for scalars
        iter(value)
        return tuple(value)
    except TypeError:
        return (value,) * length
38
39
40
def get_stem(
        path: Union[TypePath, List[TypePath]]
        ) -> Union[str, List[str]]:
    """Return the file name up to (not including) its first dot.

    '/home/user/image.nii.gz' -> 'image'

    A single ``str`` or ``Path`` gives a single stem; any other input is
    iterated as a collection of paths and gives a list of stems.
    """
    def stem_of(single_path):
        file_name = Path(single_path).name
        head, _, _ = file_name.partition('.')
        return head

    if not isinstance(path, (str, Path)):
        return [stem_of(item) for item in path]
    return stem_of(path)
51
52
53
def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    """Create (or reload) a synthetic dataset of random images and labels.

    Images are written to ``<directory>/dummy_images`` and label maps to
    ``<directory>/dummy_labels``. If the images directory already exists,
    the files are assumed to be there and are loaded instead of written.

    Args:
        num_images: Number of subjects.
        size_range: ``(low, high)`` passed to :func:`numpy.random.randint`
            for each of the three spatial dimensions (``high`` exclusive).
        directory: Parent directory. Defaults to the system temp directory.
        suffix: File extension used for images and labels.
        force: If ``True``, delete any existing dummy directories first so
            the dataset is regenerated.
        verbose: If ``True``, show a progress bar.

    Returns:
        List of ``Subject`` instances with ``one_modality`` and
        ``segmentation`` entries.
    """
    from .data import ScalarImage, LabelMap, Subject
    output_dir = tempfile.gettempdir() if directory is None else directory
    output_dir = Path(output_dir)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # On a first run the directories do not exist yet and rmtree would
        # raise FileNotFoundError, so only remove what is actually there
        for existing_dir in (images_dir, labels_dir):
            if existing_dir.is_dir():
                shutil.rmtree(existing_dir)

    def get_indices():
        # Show a progress bar only when requested, in both branches
        return trange(num_images) if verbose else range(num_images)

    subjects: List[Subject] = []
    if images_dir.is_dir():
        # Reuse the files written by a previous call
        for i in get_indices():
            image_path = images_dir / f'image_{i}{suffix}'
            label_path = labels_dir / f'label_{i}{suffix}'
            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    else:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')  # noqa: T001
        for i in get_indices():
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            # Three-class label map from thresholds on the [0, 1) intensities
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            image_path = images_dir / f'image_{i}{suffix}'
            nii = nib.Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            label_path = labels_dir / f'label_{i}{suffix}'
            nii = nib.Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    return subjects
112
113
114
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import
        output_path: TypePath,
        type: str = INTENSITY,  # noqa: A002
        verbose: bool = False,
        ):
    """Read an image, apply ``transform`` to it and save the result.

    Args:
        input_path: Path of the image to read.
        transform: Callable applied to a ``Subject`` that holds the image
            under the key ``'image'``.
        output_path: Where the transformed image is saved.
        type: Image type passed to ``Image``.
        verbose: If ``True`` and the transformed subject has a non-empty
            ``history``, print its first entry.
    """
    from . import Image, Subject
    input_image = Image(input_path, type=type)
    result = transform(Subject(image=input_image))
    result.image.save(output_path)
    if not verbose:
        return
    history = result.history
    if history:
        print('Applied transform:', history[0])  # noqa: T001
127
128
129
def guess_type(string: str) -> Any:
    """Convert a string into the Python value it most likely represents.

    All spaces are removed first. Then:

    - Text that is not a valid Python literal (e.g. ``'nearest'`` or
      ``'/path/to/file'``) is returned unchanged.
    - ``list`` and ``tuple`` literals are split on commas and each item is
      guessed recursively. Nested containers and items that contain commas
      are not supported.
    - Quoted string literals are returned unchanged, quotes included.
    - Any other literal (``int``, ``float``, ``bool``, ``None``, ``dict``,
      ...) is returned as evaluated by :func:`ast.literal_eval`.

    Examples:
        '1' -> 1
        '0.5' -> 0.5
        'False' -> False
        '(1, 2.5)' -> (1, 2.5)
        'nearest' -> 'nearest'
    """
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        value = ast.literal_eval(string)
    except (ValueError, SyntaxError):
        # Not a literal: a bare word raises ValueError, while text such as
        # a path or an empty string raises SyntaxError
        return string
    if isinstance(value, (list, tuple)):
        inner = string[1:-1]  # remove brackets
        # Skip empty pieces produced by '[]', '()' or a trailing comma
        items = [guess_type(item) for item in inner.split(',') if item]
        return tuple(items) if isinstance(value, tuple) else items
    if isinstance(value, str):
        # Quoted literals keep their original text, quotes included
        return string
    # Use the evaluated literal directly: re-parsing the text with its type
    # would give bool('False') == True and would fail for int('0x10')
    return value
150
151
152
def get_torchio_cache_dir():
    """Return ``~/.cache/torchio`` with ``~`` expanded.

    The directory is only computed here, not created.
    """
    unexpanded = Path('~') / '.cache' / 'torchio'
    return unexpanded.expanduser()
154
155
156
def round_up(value: float) -> int:
    """Round half towards positive infinity.

    Ties (fractional part exactly 0.5) are rounded up, so negative ties move
    towards zero. Unlike ``floor(value + 0.5)``, the result is not affected
    by rounding in the addition, which breaks e.g. ``0.49999999999999994``
    (sum rounds to 1.0) and odd integers above ``2 ** 52``.

    Args:
        value: The value to round.

    Example:

        >>> round(2.5)
        2
        >>> round(3.5)
        4
        >>> round_up(2.5)
        3
        >>> round_up(3.5)
        4
        >>> round_up(-2.5)
        -2

    """
    floor_value = np.floor(value)
    # Compare the fractional part with 0.5 instead of adding 0.5 first, so
    # no rounding error can push a value across the tie threshold
    fraction = value - floor_value
    if fraction >= 0.5:
        floor_value += 1
    return int(floor_value)
175
176
177
def compress(input_path, output_path):
    """Gzip-compress the file at ``input_path`` into ``output_path``.

    Data is streamed with :func:`shutil.copyfileobj`, so the input is not
    read into memory all at once.
    """
    with open(input_path, 'rb') as source, \
            gzip.open(output_path, 'wb') as target:
        shutil.copyfileobj(source, target)
181
182
183
def check_sequence(sequence: Sequence, name: str):
    """Raise ``TypeError`` if ``sequence`` is not iterable.

    Only iterability is checked (via ``iter``), so any iterable passes,
    including sets, generators and strings.

    Args:
        sequence: Object to validate.
        name: Argument name used in the error message.

    Raises:
        TypeError: If ``iter(sequence)`` fails.
    """
    try:
        iter(sequence)
    except TypeError:
        # Report the type of the offending value, not the type of its name
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message) from None
189
190
191
def get_major_sitk_version() -> int:
    """Return the major version of the installed SimpleITK, 1 or 2.

    ``sitk.__version__`` was added in SimpleITK 2
    (https://github.com/SimpleITK/SimpleITK/pull/1171), so a non-``None``
    ``__version__`` is taken to mean version 2; otherwise 1 is returned.
    """
    if getattr(sitk, '__version__', None) is not None:
        return 2
    return 1
197
198
199
def history_collate(batch: Sequence, collate_transforms=True):
    """Collate a batch of subjects, keeping their transform records.

    Adapted from
    https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py

    Args:
        batch: Sequence of samples, presumably as passed by a PyTorch
            ``DataLoader`` to its ``collate_fn``.
        collate_transforms: If ``True``, each subject's ``history`` attribute
            is gathered; otherwise its ``applied_transforms`` attribute is.

    Returns:
        If the first sample is a ``Subject``: a dictionary mapping each of its
        keys to ``default_collate`` of that key's values across the batch,
        plus the chosen attribute (if present) as a plain list with one entry
        per sample. Otherwise, ``default_collate(batch)``.
    """
    from .data import Subject
    attr = 'history' if collate_transforms else 'applied_transforms'
    first_element = batch[0]
    if not isinstance(first_element, Subject):
        # Returning None here would silently break batch iteration;
        # fall back to PyTorch's standard collation instead
        return default_collate(batch)
    dictionary = {
        key: default_collate([d[key] for d in batch])
        for key in first_element
    }
    if hasattr(first_element, attr):
        dictionary[attr] = [getattr(d, attr) for d in batch]
    return dictionary
213
214
215
# Adapted from torchvision, removing print statements
216
def download_and_extract_archive(
        url: str,
        download_root: TypePath,
        extract_root: Optional[TypePath] = None,
        filename: Optional[TypePath] = None,
        md5: str = None,
        remove_finished: bool = False,
        ) -> None:
    """Download an archive with ``download_url`` and extract it.

    Args:
        url: URL of the archive.
        download_root: Directory where the archive is saved (``~`` is
            expanded).
        extract_root: Directory to extract into. Defaults to the expanded
            ``download_root``.
        filename: Name to save the archive under. Defaults to the basename
            of ``url``.
        md5: Expected MD5 checksum, forwarded to ``download_url``.
        remove_finished: Forwarded to torchvision's ``extract_archive``;
            presumably removes the archive after extraction — confirm
            against torchvision.
    """
    root = os.path.expanduser(download_root)
    destination = root if extract_root is None else extract_root
    archive_name = filename if filename else os.path.basename(url)
    download_url(url, root, archive_name, md5)
    archive_path = os.path.join(root, archive_name)
    from torchvision.datasets.utils import extract_archive
    extract_archive(archive_path, destination, remove_finished)
233
234
235
# Adapted from torchvision, removing print statements
236
def download_url(
        url: str,
        root: TypePath,
        filename: Optional[TypePath] = None,
        md5: str = None,
        ) -> None:
    """Download a file from a url and place it in root.

    Nothing is downloaded if the file already exists and passes the
    integrity check. If the download fails and the URL starts with
    ``https``, it is retried once over ``http``.

    Args:
        url: URL to download file from
        root: Directory to place downloaded file in. Created if missing.
        filename: Name to save the file under.
            If ``None``, use the basename of the URL
        md5: MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: If the file fails the integrity check after download.
        urllib.error.URLError, OSError: If the download fails and no
            ``http`` fallback applies, or the fallback also fails.
    """
    # Import the submodules explicitly: a bare ``import urllib`` does not
    # bind ``urllib.request`` / ``urllib.error`` unless some other module
    # happens to have imported them already
    import urllib.error
    import urllib.request
    # NOTE(review): gen_bar_updater appears to be missing from newer
    # torchvision releases — confirm against the pinned version
    from torchvision.datasets.utils import check_integrity, gen_bar_updater

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)
    # check if file is already present locally
    if check_integrity(fpath, md5):
        return
    try:
        print('Downloading ' + url + ' to ' + fpath)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, OSError):
        if not url.startswith('https'):
            raise
        url = url.replace('https:', 'http:')
        message = (
            'Failed download. Trying https -> http instead.'
            ' Downloading ' + url + ' to ' + fpath
        )
        print(message)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')
284