Passed
Pull Request — main (#1288)
by Fernando
01:28
created

torchio.utils   F

Complexity

Total Complexity 78

Size/Duplication

Total Lines 436
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 78
eloc 286
dl 0
loc 436
rs 2.16
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex modules like torchio.utils often do many different things. To break such a module down, identify a cohesive component within it. A common way to find such a component is to look for functions, fields or methods that share the same prefixes or suffixes.

Once you have determined which functions and fields belong together, you can apply the Extract Class refactoring (for a module, moving them into a separate module). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.

1
from __future__ import annotations
2
3
import ast
4
import gzip
5
import inspect
6
import os
7
import shutil
8
import sys
9
import tempfile
10
from collections.abc import Iterable
11
from collections.abc import Sequence
12
from pathlib import Path
13
from typing import Any
14
15
import numpy as np
16
import SimpleITK as sitk
17
import torch
18
import torch.utils.data.dataloader
19
from nibabel.nifti1 import Nifti1Image
20
from torch.utils.data import DataLoader
21
from torch.utils.data._utils.collate import default_collate
22
from tqdm.auto import trange
23
24
from . import constants
25
from .types import TypeNumber
26
from .types import TypePath
27
28
# Application names used when searching for an external image viewer.
ITK_SNAP, SLICER = 'ITK-SNAP', 'Slicer'
30
31
32
def to_tuple(
    value: Any,
    length: int = 1,
) -> tuple[TypeNumber, ...]:
    """Convert a value to a tuple.

    Values rejected by ``iter()`` are repeated ``length`` times. Iterables
    are converted with ``tuple()`` and ``length`` is ignored, so, for
    example, a string becomes a tuple of its characters.

    Example:
        >>> from torchio.utils import to_tuple
        >>> to_tuple(1, length=1)
        (1,)
        >>> to_tuple(1, length=3)
        (1, 1, 1)

    If value is an iterable, n is ignored and tuple(value) is returned

    Example:
        >>> to_tuple((1,), length=1)
        (1,)
        >>> to_tuple((1, 2), length=1)
        (1, 2)
        >>> to_tuple([1, 2], length=3)
        (1, 2)
    """
    try:
        iter(value)  # raises TypeError for non-iterables
        result = tuple(value)
    except TypeError:
        result = (value,) * length
    return result
61
62
63
def get_stem(
    path: TypePath | Sequence[TypePath],
) -> str | list[str]:
    """Get the stem of a path, or the stems of a sequence of paths.

    The stem is the file name up to (not including) its first dot, so
    multi-part extensions such as ``.nii.gz`` are removed entirely.

    Example:
        >>> from torchio.utils import get_stem
        >>> get_stem('/home/user/my_image.nii.gz')
        'my_image'
    """

    def _stem_of(one_path: TypePath) -> str:
        file_name = Path(one_path).name
        stem, _, _ = file_name.partition('.')
        return stem

    if not isinstance(path, (str, os.PathLike)):
        # Anything that is not a single path is treated as a sequence of them.
        return list(map(_stem_of, path))
    return _stem_of(path)
81
82
83
def create_dummy_dataset(
    num_images: int,
    size_range: tuple[int, int],
    directory: TypePath | None = None,
    suffix: str = '.nii.gz',
    force: bool = False,
    verbose: bool = False,
):
    """Create, or reuse, a dataset of random images and label maps on disk.

    Files are stored in ``dummy_images`` and ``dummy_labels`` inside
    ``directory``. If ``dummy_images`` already exists (and ``force`` is
    ``False``), the existing files are reused and nothing is written.

    Args:
        num_images: Number of subjects to create or load.
        size_range: ``(low, high)`` bounds passed to
            :func:`numpy.random.randint` to draw each of the 3 spatial
            dimensions.
        directory: Parent directory of the dummy folders. Defaults to the
            system temporary directory.
        suffix: File extension of the written files.
        force: If ``True``, delete existing dummy folders first so the data
            is regenerated.
        verbose: If ``True``, print a message when creating data and show a
            progress bar.

    Returns:
        List of :class:`torchio.Subject` instances, each with a
        ``one_modality`` scalar image and a ``segmentation`` label map.
    """
    from .data import LabelMap
    from .data import ScalarImage
    from .data import Subject

    output_dir = Path(tempfile.gettempdir() if directory is None else directory)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # Only remove folders that exist: calling rmtree on a missing folder
        # (e.g. force=True on the first run) raised FileNotFoundError.
        for stale_dir in (images_dir, labels_dir):
            if stale_dir.exists():
                shutil.rmtree(stale_dir)

    reuse_existing = images_dir.is_dir()
    if not reuse_existing:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')  # noqa: T201

    # Honor ``verbose`` in both branches; reusing data used to always show a
    # progress bar.
    indices: Iterable[int] = trange(num_images) if verbose else range(num_images)

    subjects: list[Subject] = []
    for i in indices:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'
        if not reuse_existing:
            _write_dummy_pair(image_path, label_path, size_range)
        subject = Subject(
            one_modality=ScalarImage(image_path),
            segmentation=LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects


def _write_dummy_pair(
    image_path: Path,
    label_path: Path,
    size_range: tuple[int, int],
) -> None:
    """Write one random uint8 image and its 3-class label map as NIfTI files."""
    shape = np.random.randint(*size_range, size=3)
    affine = np.eye(4)
    image = np.random.rand(*shape)
    # Threshold the uniform noise in [0, 1) into the classes 0, 1 and 2.
    label = np.ones_like(image)
    label[image < 0.33] = 0
    label[image > 0.66] = 2
    image *= 255
    Nifti1Image(image.astype(np.uint8), affine).to_filename(str(image_path))
    Nifti1Image(label.astype(np.uint8), affine).to_filename(str(label_path))
146
147
148
def apply_transform_to_file(
    input_path: TypePath,
    transform,  # : Transform seems to create a circular import
    output_path: TypePath,
    class_: str = 'ScalarImage',
    verbose: bool = False,
):
    """Read an image from disk, transform it and write the result.

    Args:
        input_path: Path of the image to read.
        transform: Callable applied to a :class:`torchio.Subject` holding the
            image under the key ``'image'``.
        output_path: Where the transformed image is saved.
        class_: Name of the image class in ``torchio.data`` used to read the
            input, e.g. ``'ScalarImage'`` or ``'LabelMap'``.
        verbose: If ``True`` and the result has a non-empty ``history``,
            print its first entry.
    """
    from . import data

    image_class = getattr(data, class_)
    loaded_image = image_class(input_path)
    result = transform(data.Subject(image=loaded_image))
    result.image.save(output_path)
    if not verbose:
        return
    if result.history:
        print('Applied transform:', result.history[0])  # noqa: T201
163
164
165
def guess_type(string: str) -> Any:
    """Convert a string to the Python value it most likely represents.

    Spaces are removed first. Text that :func:`ast.literal_eval` cannot parse
    is returned unchanged as a ``str``. Lists and tuples are converted
    element by element with the same rules.

    Example:
        >>> guess_type('1')
        1
        >>> guess_type('False')
        False
        >>> guess_type('(1, 2.5)')
        (1, 2.5)
        >>> guess_type('image.nii.gz')
        'image.nii.gz'
    """
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        value = ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal. SyntaxError (e.g. '1.2.3', '001') used to
        # escape and crash the caller.
        return string
    if type(value) in (list, tuple):
        # Convert from each element's literal form instead of splitting the
        # raw text on commas, which broke empty, nested and comma-containing
        # containers.
        converted = [guess_type(repr(element)) for element in value]
        return tuple(converted) if type(value) is tuple else converted
    if isinstance(value, str):
        # Quoted string literals keep their quotes, as before.
        return string
    # Use the evaluated literal directly. Re-casting the text with its type
    # gave bool('False') == True and crashed on e.g. dict('{1:2}').
    return value
187
188
189
def get_torchio_cache_dir() -> Path:
    """Return TorchIO's cache directory, ``~/.cache/torchio``.

    The leading ``~`` is expanded to the user's home directory. The directory
    is not created here.
    """
    cache_dir = Path('~', '.cache', 'torchio')
    return cache_dir.expanduser()
191
192
193
def compress(
    input_path: TypePath,
    output_path: TypePath | None = None,
) -> Path:
    """Gzip-compress a file.

    Args:
        input_path: File to compress.
        output_path: Destination of the compressed file. If ``None``, the
            last suffix of ``input_path`` is replaced with ``.nii.gz``
            (e.g. ``image.nii`` becomes ``image.nii.gz``).

    Returns:
        Path to the compressed file.

    Raises:
        ValueError: If the output path resolves to the input file.
    """
    input_path = Path(input_path)
    if output_path is None:
        output_path = input_path.with_suffix('.nii.gz')
    output_path = Path(output_path)
    # Opening the input file for writing would truncate it before it is read,
    # silently destroying the data.
    if output_path.resolve() == input_path.resolve():
        message = f'Output path must differ from input path, got "{input_path}"'
        raise ValueError(message)
    with open(input_path, 'rb') as f_in, gzip.open(output_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return output_path
203
204
205
def check_sequence(sequence: Sequence, name: str) -> None:
    """Raise a :class:`TypeError` if ``sequence`` cannot be iterated.

    Args:
        sequence: Object to check.
        name: Name of the argument, used in the error message.

    Raises:
        TypeError: If ``iter(sequence)`` fails.
    """
    try:
        iter(sequence)
    except TypeError as err:
        # Report the type of the offending value; ``type(name)`` was always str.
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message) from err
211
212
213
def get_major_sitk_version() -> int:
    """Return ``1`` or ``2`` depending on the installed SimpleITK.

    ``sitk.__version__`` was only added in version 2
    (https://github.com/SimpleITK/SimpleITK/pull/1171), so a missing (or
    ``None``) attribute means version 1; anything else is reported as 2.
    """
    if getattr(sitk, '__version__', None) is None:
        return 1
    return 2
219
220
221
def history_collate(batch: Sequence, collate_transforms=True) -> dict:
    """Collate a batch of subjects while keeping their transform histories.

    Every key of the first subject is collated with ``default_collate``. If
    the first subject has the history attribute (``constants.HISTORY`` when
    ``collate_transforms`` is true, ``'applied_transforms'`` otherwise), the
    values of that attribute for all subjects are added as a plain list.

    Args:
        batch: Samples to collate.
        collate_transforms: Selects which attribute is gathered as history.

    Returns:
        The collated dictionary, or an empty dictionary if the first element
        of ``batch`` is not a :class:`torchio.Subject`.
    """
    attr = constants.HISTORY if collate_transforms else 'applied_transforms'
    # Adapted from
    # https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py
    from .data import Subject

    first_element = batch[0]
    if not isinstance(first_element, Subject):
        return {}
    collated = {}
    for key in first_element:
        collated[key] = default_collate([sample[key] for sample in batch])
    if hasattr(first_element, attr):
        collated[attr] = [getattr(sample, attr) for sample in batch]
    return collated
237
238
239
def get_subclasses(target_class: type) -> list[type]:
    """Return all subclasses of ``target_class``, recursively.

    Direct subclasses come first, followed by the subclasses of each of
    them, in order. A class reachable through several parents (diamond
    inheritance) appears once per path.
    """
    direct_subclasses = target_class.__subclasses__()
    all_subclasses = list(direct_subclasses)
    # Extend in place: ``sum(lists, [])`` copies the accumulator on every
    # addition, which is quadratic in the number of lists.
    for subclass in direct_subclasses:
        all_subclasses.extend(get_subclasses(subclass))
    return all_subclasses
243
244
245
def get_first_item(data_loader: DataLoader):
    """Return the first item yielded by a fresh iterator over ``data_loader``.

    Raises :class:`StopIteration` if the loader yields nothing.
    """
    iterator = iter(data_loader)
    first = next(iterator)
    return first
247
248
249
def get_batch_images_and_size(batch: dict) -> tuple[list[str], int]:
    """Get number of images and images names in a batch.

    A batch entry counts as an image when it is a dictionary containing the
    key ``constants.DATA``. The batch size is the length of that entry; when
    several images are present, the length from the last one is returned.

    Args:
        batch: Dictionary generated by a :class:`tio.SubjectsLoader`
            extracting data from a :class:`torchio.SubjectsDataset`.

    Returns:
        Tuple ``(names, size)`` with the image keys, in batch order, and the
        batch size.

    Raises:
        RuntimeError: If the batch does not seem to contain any dictionaries
            that seem to represent a :class:`torchio.Image`.
    """
    image_names = []
    batch_size = 0
    for key, value in batch.items():
        if not (isinstance(value, dict) and constants.DATA in value):
            continue
        batch_size = len(value[constants.DATA])
        image_names.append(key)
    if not image_names:
        raise RuntimeError('The batch does not seem to contain any images')
    return image_names, batch_size
268
269
270
def get_subjects_from_batch(batch: dict) -> list:
    """Get list of subjects from collated batch.

    For each index ``i`` of the batch, image entries (found with
    :func:`get_batch_images_and_size`) are rebuilt as a
    :class:`torchio.LabelMap` or :class:`torchio.ScalarImage` from their
    data, affine, path and type. Every other entry contributes
    ``value[i]``. If the batch has a ``constants.HISTORY`` entry, its
    transforms are re-added to each subject's history.

    Args:
        batch: Dictionary generated by a :class:`tio.SubjectsLoader`
            extracting data from a :class:`torchio.SubjectsDataset`.

    Returns:
        List of :class:`torchio.Subject`, one per batch element.
    """
    from .data import LabelMap
    from .data import ScalarImage
    from .data import Subject

    image_names, batch_size = get_batch_images_and_size(batch)

    def _image_at(image_dict: dict, index: int):
        # Rebuild one image from the collated entries of a batch element.
        tensor = image_dict[constants.DATA][index]
        affine = image_dict[constants.AFFINE][index]
        file_path = Path(image_dict[constants.PATH][index])
        is_label = image_dict[constants.TYPE][index] == constants.LABEL
        image_class = LabelMap if is_label else ScalarImage
        return image_class(tensor=tensor, affine=affine, filename=file_path.name)

    subjects = []
    for index in range(batch_size):
        subject_dict = {
            key: _image_at(value, index) if key in image_names else value[index]
            for key, value in batch.items()
        }
        subject = Subject(subject_dict)
        if constants.HISTORY in batch:
            for transform in batch[constants.HISTORY][index]:
                transform.add_transform_to_subject_history(subject)
        subjects.append(subject)
    return subjects
311
312
313
def add_images_from_batch(
    subjects: list,
    tensor: torch.Tensor,
    class_=None,
    name='prediction',
) -> None:
    """Add images to subjects in a list, typically from a network prediction.

    Each new image takes the affine of its subject's first image, and its
    ``'filename'`` entry too when that image has one. Subjects and tensor
    elements are paired with ``zip``, so any surplus on either side is
    ignored.

    Args:
        subjects: List of instances of :class:`torchio.Subject` to which images
            will be added.
        tensor: PyTorch tensor of shape :math:`(B, C, W, H, D)`, where
            :math:`B` is the batch size.
        class_: Class used to instantiate the images,
            e.g., :class:`torchio.LabelMap`.
            If ``None``, :class:`torchio.ScalarImage` will be used.
        name: Name of the images added to the subjects.
    """
    if class_ is None:
        from . import ScalarImage

        class_ = ScalarImage
    for subject, image_tensor in zip(subjects, tensor):
        reference = subject.get_first_image()
        image_kwargs = dict(tensor=image_tensor, affine=reference.affine)
        if 'filename' in reference:
            image_kwargs['filename'] = reference['filename']
        subject.add_image(class_(**image_kwargs), name)
345
346
347
def guess_external_viewer() -> Path | None:
    """Guess the path to an executable that could be used to visualize images.

    The ``SITK_SHOW_COMMAND`` environment variable takes precedence.
    Otherwise, on macOS, Windows and Linux, a platform-specific search looks
    for 1) ITK-SNAP and 2) 3D Slicer. ``None`` is returned on other
    platforms.
    """
    show_command = os.environ.get('SITK_SHOW_COMMAND')
    if show_command is not None:
        return Path(show_command)

    current_platform = sys.platform
    if current_platform == 'darwin':
        return _guess_macos_viewer()
    if current_platform == 'win32':
        return _guess_windows_viewer()
    if 'linux' in current_platform:
        return _guess_linux_viewer()
    return None
363
364
365
def _guess_macos_viewer() -> Optional[Path]:
366
    def _get_app_path(app_name: str) -> Path:
367
        app_path = '/Applications/{}.app/Contents/MacOS/{}'
368
        return Path(app_path.format(2 * (app_name,)))
369
370
    if (itk_snap_path := _get_app_path(ITK_SNAP)).is_file():
371
        return itk_snap_path
372
    elif (slicer_path := _get_app_path(SLICER)).is_file():
373
        return slicer_path
374
    else:
375
        return None
376
377
378
def _guess_windows_viewer() -> Optional[Path]:
379
    def _get_app_path(app_dirs: list[Path], bin_name: str) -> Path:
380
        app_dir = app_dirs[-1]
381
        app_path = app_dir / bin_name
382
        if app_path.is_file():
383
            return app_path
384
385
    program_files_dir = Path(os.environ['ProgramW6432'])
386
    itk_snap_dirs = list(program_files_dir.glob(f'{ITK_SNAP}*'))
387
    slicer_dirs = list(program_files_dir.glob(f'{SLICER}*'))
388
389
    if itk_snap_dirs:
390
        itk_snap_path = _get_app_path(itk_snap_dirs, 'bin/itk-snap.exe')
391
        if itk_snap_path.is_file():
392
            return itk_snap_path
393
    elif slicer_dirs:
394
        slicer_path = _get_app_path(slicer_dirs, 'slicer.exe')
395
        if slicer_path.is_file():
396
            return slicer_path
397
    else:
398
        return None
399
400
401
def _guess_linux_viewer() -> Optional[Path]:
402
    if (itk_snap_which := shutil.which('itksnap')) is not None:
403
        return Path(itk_snap_which)
404
    elif (slicer_which := shutil.which('Slicer')) is not None:
405
        return Path(slicer_which)
406
    else:
407
        return None
408
409
410
def parse_spatial_shape(shape):
    """Validate a 3D spatial shape and return it as a tuple.

    A non-iterable ``shape`` is repeated three times via :func:`to_tuple`.

    Raises:
        ValueError: If an element is smaller than 1 or has a fractional
            part, or if the shape does not have exactly 3 elements.
    """
    spatial_shape = to_tuple(shape, length=3)
    if any(n < 1 or n % 1 for n in spatial_shape):
        message = (
            'All elements in a spatial shape must be positive integers,'
            f' but the following shape was passed: {shape}'
        )
        raise ValueError(message)
    if len(spatial_shape) != 3:
        message = (
            'Spatial shapes must have 3 elements, but the following shape'
            f' was passed: {shape}'
        )
        raise ValueError(message)
    return spatial_shape
426
427
428
def normalize_path(path: TypePath):
    """Expand a leading ``~`` in ``path`` and return it resolved to an absolute path."""
    expanded = Path(path).expanduser()
    return expanded.resolve()
430
431
432
def is_iterable(object: Any) -> bool:
    """Return ``True`` if ``iter(object)`` succeeds, ``False`` on :class:`TypeError`."""
    try:
        iter(object)
    except TypeError:
        return False
    else:
        return True
438
439
440
def in_class(classes) -> bool:
    """Return whether a frame on the current call stack belongs to ``classes``.

    A frame matches when its local ``self`` is not ``None`` and its exact
    class (``instance.__class__``; subclasses do not match) is in
    ``classes``.

    Args:
        classes: A class or an iterable of classes (normalized with
            :func:`to_tuple`).
    """
    classes = to_tuple(classes)
    # Walk the frames directly: ``inspect.stack()`` builds a FrameInfo with
    # source context for every frame, which is slow.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            instance = frame.f_locals.get('self')
            if instance is not None and instance.__class__ in classes:
                return True
            frame = frame.f_back
        return False
    finally:
        # Drop the frame reference to avoid a reference cycle.
        del frame
451
452
453
def in_torch_loader() -> bool:
    """Return whether a PyTorch ``DataLoader`` iterator is on the call stack.

    Uses :func:`in_class` with PyTorch's private single-process and
    multiprocessing ``DataLoader`` iterator classes.
    """
    dataloader_module = torch.utils.data.dataloader
    loader_iterator_classes = (
        dataloader_module._SingleProcessDataLoaderIter,
        dataloader_module._MultiProcessingDataLoaderIter,
    )
    return in_class(loader_iterator_classes)
459