Passed
Push — master ( 0bdb47...b0bac6 )
by Fernando
01:14
created

torchio.utils.check_uint_to_int()   A

Complexity

Conditions 3

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 6
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import os
2
import ast
3
import gzip
4
import shutil
5
import tempfile
6
from pathlib import Path
7
from typing import Union, Iterable, Tuple, Any, Optional, List, Sequence
8
9
from torch.utils.data._utils.collate import default_collate
10
import numpy as np
11
import nibabel as nib
12
import SimpleITK as sitk
13
from tqdm import trange
14
15
from .torchio import (
16
    INTENSITY,
17
    TypeData,
18
    TypeNumber,
19
    TypePath,
20
    REPO_URL,
21
)
22
23
24
# Negates the first two axes; used to switch between LPS and RAS
# coordinates (see nib_to_sitk and sitk_to_nib)
FLIP_XY = np.diag([-1, -1, 1])
25
26
27
def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """Convert a scalar or an iterable into a tuple.

    If ``value`` is not iterable, it is repeated ``length`` times::

        to_tuple(1, length=1) -> (1,)
        to_tuple(1, length=3) -> (1, 1, 1)

    If ``value`` is iterable, ``length`` is ignored and ``tuple(value)`` is
    returned::

        to_tuple((1,), length=1) -> (1,)
        to_tuple((1, 2), length=1) -> (1, 2)
        to_tuple([1, 2], length=3) -> (1, 2)

    Note that strings are iterable, so ``to_tuple('ab')`` returns
    ``('a', 'b')``.

    Args:
        value: A number or an iterable of numbers.
        length: Number of repetitions used when ``value`` is not iterable.
    """
    try:
        iter(value)
        value = tuple(value)
    except TypeError:
        # Not iterable: repeat the scalar
        value = length * (value,)
    return value
46
47
48
def get_stem(
        path: Union[TypePath, List[TypePath]]
        ) -> Union[str, List[str]]:
    """Return the file name of a path up to its first dot.

    A single ``str`` or ``Path`` gives a single stem. Any other input is
    iterated and a list of stems is returned.

    Example: ``'/home/user/image.nii.gz' -> 'image'``
    """
    def stem_of(item):
        # Everything before the first '.' of the final path component
        return Path(item).name.partition('.')[0]

    if not isinstance(path, (str, Path)):
        return list(map(stem_of, path))
    return stem_of(path)
59
60
61
def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    """Create random images and label maps, or load them if they exist.

    Files are stored in ``dummy_images`` and ``dummy_labels`` inside
    ``directory``. If ``dummy_images`` already exists, no files are written
    and the existing ones are loaded instead.

    Args:
        num_images: Number of subjects.
        size_range: ``(low, high)`` arguments of
            :func:`numpy.random.randint`, used to draw each of the three
            spatial dimensions (``high`` is exclusive).
        directory: Parent directory of the dataset. Defaults to the system
            temporary directory.
        suffix: File extension of the images and label maps.
        force: If ``True``, existing dummy directories are deleted first so
            that the dataset is regenerated.
        verbose: If ``True``, print a message when creating the dataset and
            display a progress bar.

    Returns:
        A list of ``Subject`` instances, each with a ``one_modality`` scalar
        image and a ``segmentation`` label map with labels 0, 1 and 2.
    """
    from .data import ScalarImage, LabelMap, Subject
    if directory is None:
        directory = tempfile.gettempdir()
    output_dir = Path(directory)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # rmtree raises if the directory does not exist (e.g. first run)
        for dataset_dir in (images_dir, labels_dir):
            if dataset_dir.is_dir():
                shutil.rmtree(dataset_dir)

    reuse_existing = images_dir.is_dir()
    if not reuse_existing:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')  # noqa: T001

    # The progress bar honors `verbose` in both branches
    iterable = trange(num_images) if verbose else range(num_images)
    subjects: List[Subject] = []
    for i in iterable:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'
        if not reuse_existing:
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            # Split the [0, 1) intensities into three label values
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            nii = nib.Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            nii = nib.Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

        subject = Subject(
            one_modality=ScalarImage(image_path),
            segmentation=LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects
120
121
122
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import
        output_path: TypePath,
        type: str = INTENSITY,  # noqa: A002
        verbose: bool = False,
        ):
    """Read an image, apply a transform to it and save the result.

    Args:
        input_path: Path to the input image.
        transform: Callable applied to a ``Subject`` that holds the image
            under the key ``image``.
        output_path: Path where the transformed image is written.
        type: Image type passed to ``Image``.
        verbose: If ``True`` and the transformed subject has a non-empty
            ``history``, print its first entry.
    """
    from . import Image, Subject
    input_image = Image(input_path, type=type)
    result = transform(Subject(image=input_image))
    result.image.save(output_path)
    if not verbose:
        return
    if result.history:
        print('Applied transform:', result.history[0])  # noqa: T001
135
136
137
def guess_type(string: str) -> Any:
    """Convert a string into the Python value it most likely represents.

    Spaces are removed first, then :func:`ast.literal_eval` is attempted.

    - Text that is not a Python literal (e.g. ``'test'``, ``'1.2.3'`` or
      ``'/path/to/file'``) is returned as a string.
    - Lists and tuples are split on commas and each element is guessed
      recursively. Elements that contain commas themselves (nested
      sequences, quoted strings with commas) are not handled correctly.
    - Quoted strings keep their quotes.
    - Any other literal is returned as evaluated.

    Examples::

        guess_type('1') -> 1
        guess_type('1.5') -> 1.5
        guess_type('None') -> None
        guess_type('False') -> False
        guess_type('(1, 3, 5)') -> (1, 3, 5)
        guess_type('test') -> 'test'
    """
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        value = ast.literal_eval(string)
    except (ValueError, SyntaxError):
        # Not a literal. literal_eval raises SyntaxError (not only
        # ValueError) for inputs such as '', '1.2.3' or '/home/user'.
        return string
    result_type = type(value)
    if result_type in (list, tuple):
        # '1,2' is a valid tuple literal without brackets
        inner = string[1:-1] if string[0] in '([' else string
        # Skip empty items so that '[]' and '(1,)' are handled
        items = [item for item in inner.split(',') if item]
        list_result = [guess_type(item) for item in items]
        return tuple(list_result) if result_type is tuple else list_result
    if result_type is str:
        # Quoted input such as "'abc'": return the text, quotes included
        return string
    # For other literals (int, float, bool, None, dict, ...), the evaluated
    # value is the answer. Re-parsing with result_type(string) would give
    # bool('False') == True and fail for dicts.
    return value
158
159
160
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the upper-left 3x3 block of an affine into rotation and spacing.

    The spacing of each axis is the Euclidean norm of the corresponding
    column of the 3x3 block. The rotation is that block with each column
    divided by its norm.

    Returns:
        A tuple ``(rotation, spacing)``.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear_part = affine[:3, :3]
    column_norms = np.sqrt(np.square(linear_part).sum(axis=0))
    return linear_part / column_norms, column_norms
168
169
170
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        squeeze: bool = False,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a tensor and a 4x4 affine matrix.

    Args:
        data: 4D array or tensor with shape ``(C, W, H, D)``.
        affine: 4x4 affine matrix. Its first two axes are sign-flipped with
            ``FLIP_XY`` (RAS <-> LPS) before the origin and direction are
            set on the image.
        squeeze: Unused; kept for backward compatibility.
        force_3d: If ``True``, an input with ``D == 1`` is not treated as 2D.
        force_4d: If ``True``, the channels axis is kept as an extra image
            dimension instead of producing vector pixels.
            NOTE(review): origin/spacing/direction are still 3D in that
            case, and the component-count assertion below fails if
            ``C > 1``.

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        shape = tuple(data.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    # Possible input shapes:
    # (1, W, H, 1): 2D, single channel
    # (C, W, H, 1): 2D, multichannel
    # (1, W, H, D): 3D, single channel
    # (C, W, H, D): 3D, multichannel
    array = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    is_multichannel = array.shape[0] > 1 and not force_4d
    is_2d = array.shape[3] == 1 and not force_3d
    if is_2d:
        array = array[..., 0]
    if not is_multichannel and not force_4d:
        array = array[0]
    # Reverse the axes: SimpleITK reads arrays as ([D,] H, W[, C]), so the
    # resulting image size is (W, H[, D])
    array = array.transpose()
    image = sitk.GetImageFromArray(array, isVector=is_multichannel)

    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    origin = np.dot(FLIP_XY, affine[:3, 3])
    direction = np.dot(FLIP_XY, rotation)
    if is_2d:  # keep only the in-plane 2x2 block for (C, W, H, 1) inputs
        direction = direction[:2, :2]
    image.SetOrigin(origin)  # should I add a 4th value if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(direction.flatten())

    # data.ndim is always 4 here (checked above)
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetSize() == data.shape[1: 1 + num_spatial_dims]
    return image
210
211
212
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a data array and a 4x4 affine matrix.

    Single-component images get a leading channels axis and 2D images a
    trailing singleton axis, giving ``(C, W, H, D)``. The direction,
    spacing and origin are combined into the affine after flipping the
    first two axes with ``FLIP_XY``.

    Args:
        image: Input SimpleITK image.
        keepdim: If ``False``, the array is also passed through
            ``ensure_4d``.

    Returns:
        Tuple ``(data, affine)``.

    Raises:
        RuntimeError: If the direction has neither 4 nor 9 elements.
    """
    components = image.GetNumberOfComponentsPerPixel()
    spatial_dims = image.GetDimension()
    # Arrays from SimpleITK are indexed ([z,] y, x[, c]); reverse the axes
    array = sitk.GetArrayFromImage(image).transpose()
    if components == 1:
        array = array[np.newaxis]  # add channels dimension
    if spatial_dims == 2:
        array = array[..., np.newaxis]  # add singleton last dimension
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=spatial_dims)
    assert array.shape[0] == components
    assert array.shape[1: 1 + spatial_dims] == image.GetSize()

    spacing = np.array(image.GetSpacing())
    origin = image.GetOrigin()
    direction = np.array(image.GetDirection())
    num_direction_values = len(direction)
    if num_direction_values == 9:
        rotation = direction.reshape(3, 3)
    elif num_direction_values == 4:
        # 2D image: embed the 2x2 direction into a 3x3 identity and use
        # unit spacing and zero origin for the third axis
        rotation = np.eye(3)
        rotation[:2, :2] = direction.reshape(2, 2)
        spacing = *spacing, 1
        origin = *origin, 0
    else:
        raise RuntimeError(f'Direction not understood: {direction}')

    affine = np.eye(4)
    affine[:3, :3] = np.dot(FLIP_XY, rotation) * spacing
    affine[:3, 3] = np.dot(FLIP_XY, origin)
    return array, affine
247
248
249
def _permute(tensor: TypeData, *axes: int) -> TypeData:
    """Reorder the axes of a NumPy array or a PyTorch tensor."""
    # NumPy arrays have no `permute` method
    if isinstance(tensor, np.ndarray):
        return tensor.transpose(axes)
    return tensor.permute(*axes)


def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> TypeData:
    """Return a 4D ``(C, W, H, D)`` version of ``tensor``.

    Works with NumPy arrays and PyTorch tensors.

    - 4D inputs are returned unchanged.
    - 5D inputs are assumed to be ``(W, H, D, 1, C)`` and become
      ``(C, W, H, D)``.
    - 2D inputs ``(W, H)`` become ``(1, W, H, 1)``.
    - 3D inputs are treated as ``(C, W, H)`` if ``num_spatial_dims == 2``
      and as ``(W, H, D)`` if ``num_spatial_dims == 3``. Otherwise a first
      or last dimension of size 3 is taken as RGB channels.

    Raises:
        ValueError: For 5D inputs whose ``shape[-2] != 1`` and for inputs
            with other numbers of dimensions.
    """
    # I wish named tensors were properly supported in PyTorch
    num_dimensions = tensor.ndim
    if num_dimensions == 4:
        pass
    elif num_dimensions == 5:  # hope (W, H, D, 1, C)
        if tensor.shape[-2] == 1:
            tensor = tensor[..., 0, :]  # (W, H, D, C)
            tensor = _permute(tensor, 3, 0, 1, 2)  # (C, W, H, D)
        else:
            raise ValueError('5D is not supported for shape[-2] > 1')
    elif num_dimensions == 2:  # assume 2D monochannel (W, H)
        tensor = tensor[np.newaxis, ..., np.newaxis]  # (1, W, H, 1)
    elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
        if num_spatial_dims == 2:
            tensor = tensor[..., np.newaxis]  # (C, W, H, 1)
        elif num_spatial_dims == 3:  # (W, H, D)
            tensor = tensor[np.newaxis]  # (1, W, H, D)
        else:  # try to guess
            shape = tensor.shape
            maybe_rgb = 3 in (shape[0], shape[-1])
            if maybe_rgb:
                if shape[-1] == 3:  # (W, H, 3)
                    tensor = _permute(tensor, 2, 0, 1)  # (3, W, H)
                tensor = tensor[..., np.newaxis]  # (3, W, H, 1)
            else:  # (W, H, D)
                tensor = tensor[np.newaxis]  # (1, W, H, D)
    else:
        message = (
            f'{num_dimensions}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert tensor.ndim == 4
    return tensor
284
285
286
def get_torchio_cache_dir():
    """Return the path ``~/.cache/torchio`` with ``~`` expanded."""
    cache_dir = Path('~', '.cache', 'torchio')
    return cache_dir.expanduser()
288
289
290
def round_up(value: float) -> int:
    """Round half towards positive infinity.

    Args:
        value: The value to round.

    Example:

        >>> round(2.5)
        2
        >>> round(3.5)
        4
        >>> round_up(2.5)
        3
        >>> round_up(3.5)
        4
        >>> round_up(-2.5)
        -2

    """
    floor = np.floor(value)
    # Compare the fractional part with 0.5 instead of computing
    # floor(value + 0.5): that addition can itself round up, e.g.
    # 0.49999999999999994 + 0.5 == 1.0, which would give 1 instead of 0.
    fraction = value - floor
    return int(floor) + int(fraction >= 0.5)
309
310
311
def compress(input_path, output_path):
    """Write a gzip-compressed copy of ``input_path`` to ``output_path``."""
    with open(input_path, 'rb') as src, gzip.open(output_path, 'wb') as dst:
        shutil.copyfileobj(src, dst)
315
316
317
def check_sequence(sequence: Sequence, name: str):
    """Raise a ``TypeError`` if ``sequence`` is not iterable.

    Args:
        sequence: Object to check. Any iterable is accepted.
        name: Name of the argument, used in the error message.

    Raises:
        TypeError: If ``iter(sequence)`` fails.
    """
    try:
        iter(sequence)
    except TypeError:
        # Report the type of the offending value, not of its name (which is
        # always str)
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message) from None
323
324
325
def get_major_sitk_version() -> int:
    """Return the major version of the installed SimpleITK (1 or 2).

    ``sitk.__version__`` was added in SimpleITK 2
    (https://github.com/SimpleITK/SimpleITK/pull/1171), so a missing or
    ``None`` attribute means version 1.
    """
    if getattr(sitk, '__version__', None) is None:
        return 1
    return 2
331
332
333
def history_collate(batch: Sequence, collate_transforms=True):
    """Collate a batch of subjects, also gathering a per-subject attribute.

    Args:
        batch: Samples produced by a dataset.
        collate_transforms: Selects the attribute gathered as a plain list:
            ``'history'`` if ``True``, ``'applied_transforms'`` otherwise.

    Returns:
        If the first element is a ``Subject``, a dictionary with each of its
        keys collated with ``default_collate``, plus the selected attribute
        as a list with one entry per element (only if the first element has
        that attribute). Otherwise, ``default_collate(batch)``.
    """
    # Adapted from
    # https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py
    from .data import Subject
    attr = 'history' if collate_transforms else 'applied_transforms'
    first_element = batch[0]
    if not isinstance(first_element, Subject):
        # Previously this fell through and returned None, silently yielding
        # empty batches; use PyTorch's default collation instead
        return default_collate(batch)
    dictionary = {
        key: default_collate([d[key] for d in batch])
        for key in first_element
    }
    if hasattr(first_element, attr):
        dictionary[attr] = [getattr(d, attr) for d in batch]
    return dictionary
347
348
349
# Adapted from torchvision, removing print statements
def download_and_extract_archive(
        url: str,
        download_root: TypePath,
        extract_root: Optional[TypePath] = None,
        filename: Optional[TypePath] = None,
        md5: Optional[str] = None,
        remove_finished: bool = False,
        ) -> None:
    """Download an archive with ``download_url`` and extract it.

    Args:
        url: URL of the archive.
        download_root: Directory where the archive is saved. ``~`` is
            expanded.
        extract_root: Directory where the archive is extracted. ``~`` is
            expanded. Defaults to ``download_root``.
        filename: Name of the downloaded file. Defaults to the basename of
            ``url``.
        md5: Expected MD5 checksum of the archive, or ``None`` to skip the
            check.
        remove_finished: Passed to torchvision's ``extract_archive``
            (presumably deletes the archive after extraction - confirm
            against the torchvision version in use).
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    else:
        # Expand '~' here too, otherwise a literal '~' directory is created
        extract_root = os.path.expanduser(extract_root)
    if not filename:
        filename = os.path.basename(url)
    download_url(url, download_root, filename, md5)
    archive = os.path.join(download_root, filename)
    from torchvision.datasets.utils import extract_archive
    extract_archive(archive, extract_root, remove_finished)
367
368
369
# Adapted from torchvision
def download_url(
        url: str,
        root: TypePath,
        filename: Optional[TypePath] = None,
        md5: Optional[str] = None,
        ) -> None:
    """Download a file from a url and place it in root.

    Nothing is downloaded if the file already exists and passes the
    integrity check. If a download from an ``https`` URL fails, it is
    retried once over ``http``.

    Args:
        url: URL to download file from
        root: Directory to place downloaded file in. ``~`` is expanded and
            the directory is created if needed.
        filename: Name to save the file under.
            If ``None``, use the basename of the URL
        md5: MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: If the file is missing or fails the integrity check
            after downloading.
    """
    # `import urllib` alone does not load the `request` and `error`
    # submodules; previously this only worked because another import had
    # already loaded them
    import urllib.error
    import urllib.request
    # NOTE(review): gen_bar_updater is not available in newer torchvision
    # releases - confirm against the pinned torchvision version
    from torchvision.datasets.utils import check_integrity, gen_bar_updater

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)
    # check if file is already present locally
    if check_integrity(fpath, md5):
        return
    try:
        print('Downloading ' + url + ' to ' + fpath)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, OSError):
        if not url.startswith('https'):
            raise
        # Only the scheme should change
        url = url.replace('https:', 'http:', 1)
        message = (
            'Failed download. Trying https -> http instead.'
            ' Downloading ' + url + ' to ' + fpath
        )
        print(message)  # noqa: T001
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')
418
419
420
def check_uint_to_int(array):
    """Cast unsigned integer arrays that PyTorch cannot ingest to signed ones.

    ``uint16`` arrays become ``int32`` and ``uint32`` arrays become
    ``int64``. Both target types can hold every value of the source type.
    The check uses the dtype kind and item size, so non-native byte orders
    (e.g. big-endian ``>u2`` data) are converted as well. Any other array is
    returned unchanged (the same object).
    """
    # This is because PyTorch won't take uint16 nor uint32
    dtype = array.dtype
    if dtype.kind == 'u' and dtype.itemsize == 2:
        return array.astype(np.int32)
    if dtype.kind == 'u' and dtype.itemsize == 4:
        return array.astype(np.int64)
    return array
427