torchio.utils.add_images_from_batch()   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 32
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 15
nop 4
dl 0
loc 32
rs 9.65
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import ast
4
import gzip
5
import os
6
import shutil
7
import sys
8
import tempfile
9
from collections.abc import Iterable
10
from collections.abc import Sequence
11
from pathlib import Path
12
from typing import Any
13
14
import numpy as np
15
import SimpleITK as sitk
16
import torch
17
from nibabel.nifti1 import Nifti1Image
18
from torch.utils.data import DataLoader
19
from torch.utils.data._utils.collate import default_collate
20
from tqdm.auto import trange
21
22
from . import constants
23
from .types import TypeNumber
24
from .types import TypePath
25
26
27
def to_tuple(
    value: Any,
    length: int = 1,
) -> tuple[TypeNumber, ...]:
    """Return ``value`` as a tuple.

    Iterables (strings included) are converted element-wise with
    :func:`tuple`, in which case ``length`` is ignored. Any other object is
    repeated ``length`` times.

    Example:
        >>> from torchio.utils import to_tuple
        >>> to_tuple(1, length=1)
        (1,)
        >>> to_tuple(1, length=3)
        (1, 1, 1)
        >>> to_tuple((1,), length=1)
        (1,)
        >>> to_tuple((1, 2), length=1)
        (1, 2)
        >>> to_tuple([1, 2], length=3)
        (1, 2)

    Args:
        value: Scalar or iterable to convert.
        length: Number of repetitions used when ``value`` is not iterable.
    """
    try:
        iter(value)  # probe only: raises TypeError for non-iterables
        return tuple(value)
    except TypeError:
        return (value,) * length
56
57
58
def get_stem(
    path: TypePath | Sequence[TypePath],
) -> str | list[str]:
    """Return the file name of a path without any of its extensions.

    Everything from the first dot of the file name onwards is dropped, so a
    compound extension such as ``.nii.gz`` is removed entirely. Passing a
    sequence of paths returns a list with one stem per path.

    Example:
        >>> from torchio.utils import get_stem
        >>> get_stem('/home/user/my_image.nii.gz')
        'my_image'
    """

    def _stem_of(single_path: TypePath) -> str:
        file_name = Path(single_path).name
        return file_name.partition('.')[0]

    if not isinstance(path, (str, os.PathLike)):
        # Anything that is not a single path is treated as a sequence of paths
        return list(map(_stem_of, path))
    return _stem_of(path)
76
77
78
def create_dummy_dataset(
    num_images: int,
    size_range: tuple[int, int],
    directory: TypePath | None = None,
    suffix: str = '.nii.gz',
    force: bool = False,
    verbose: bool = False,
):
    """Create, or reuse, a dataset of random images with matching label maps.

    Images are stored in ``<directory>/dummy_images`` and label maps in
    ``<directory>/dummy_labels``. If the images directory already exists, the
    files in it are reused instead of being generated again.

    Args:
        num_images: Number of subjects to return.
        size_range: ``(low, high)`` bounds passed to
            :func:`numpy.random.randint` to draw each of the three spatial
            dimensions (``high`` is exclusive).
        directory: Parent directory of the dataset. If ``None``, the system
            temporary directory is used.
        suffix: File extension of the written files.
        force: If ``True``, existing dummy directories are deleted first so
            that the data is regenerated.
        verbose: If ``True``, show a progress bar (and a message when the
            files are being created).

    Returns:
        List of :class:`torchio.Subject` instances, each with a
        ``one_modality`` scalar image and a ``segmentation`` label map.
    """
    from .data import LabelMap
    from .data import ScalarImage
    from .data import Subject

    output_dir = Path(tempfile.gettempdir() if directory is None else directory)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # On a first run there is nothing to delete, and shutil.rmtree would
        # raise FileNotFoundError for a missing directory
        for dummy_dir in (images_dir, labels_dir):
            if dummy_dir.is_dir():
                shutil.rmtree(dummy_dir)

    reuse_existing = images_dir.is_dir()
    if not reuse_existing:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)

    # The progress bar honors ``verbose`` whether files are reused or created
    iterable: Iterable[int]
    if verbose:
        if not reuse_existing:
            print('Creating dummy dataset...')  # noqa: T201
        iterable = trange(num_images)
    else:
        iterable = range(num_images)

    subjects: list[Subject] = []
    for i in iterable:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'
        if not reuse_existing:
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            # Three classes derived from intensity thresholds: 0, 1 and 2
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            nii = Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            nii = Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

        subject = Subject(
            one_modality=ScalarImage(image_path),
            segmentation=LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects
141
142
143
def apply_transform_to_file(
    input_path: TypePath,
    transform,  # : Transform seems to create a circular import
    output_path: TypePath,
    class_: str = 'ScalarImage',
    verbose: bool = False,
):
    """Read an image from disk, apply a transform to it and save the result.

    Args:
        input_path: Path of the image to read.
        transform: Callable applied to a :class:`torchio.Subject` holding the
            image under the key ``'image'``.
        output_path: Path where the transformed image is written.
        class_: Name of the class in :mod:`torchio.data` used to load the
            image, e.g. ``'ScalarImage'`` or ``'LabelMap'``.
        verbose: If ``True`` and the transformed subject has a non-empty
            history, print its first entry.
    """
    from . import data

    image_class = getattr(data, class_)
    subject = data.Subject(image=image_class(input_path))
    transformed = transform(subject)
    transformed.image.save(output_path)
    if not verbose or not transformed.history:
        return
    print('Applied transform:', transformed.history[0])  # noqa: T201
158
159
160
def guess_type(string: str) -> Any:
    """Convert a string to the Python value it most likely represents.

    Spaces are removed first. Python literals (numbers, booleans, ``None``,
    lists, tuples, dicts...) are evaluated safely with
    :func:`ast.literal_eval`. Anything that is not a literal, such as a bare
    word or a file path, is returned as a string.

    Adapted from
    https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s

    Example:
        >>> from torchio.utils import guess_type
        >>> guess_type('1.5')
        1.5
        >>> guess_type('(1, 3, 5)')
        (1, 3, 5)
        >>> guess_type('False')
        False
        >>> guess_type('nearest')
        'nearest'
    """
    string = string.replace(' ', '')
    try:
        value = ast.literal_eval(string)
    except (ValueError, TypeError, SyntaxError):
        # Not a literal: e.g. 'nearest' raises ValueError and '/tmp/a.nii'
        # raises SyntaxError; both are kept as text
        return string
    if isinstance(value, str):
        # Quoted text such as "'abc'" is returned verbatim, quotes included
        return string
    # Return the evaluated object itself. Converting the text again with
    # type(value)(string) would turn 'False' into True, fail on '0x10' and on
    # dicts, and splitting containers on commas breaks nested or empty ones.
    return value
182
183
184
def get_torchio_cache_dir() -> Path:
    """Return the TorchIO cache directory, ``~/.cache/torchio``.

    The home directory is expanded, but the directory itself is not created.
    """
    return Path.home() / '.cache' / 'torchio'
186
187
188
def compress(
    input_path: TypePath,
    output_path: TypePath | None = None,
) -> Path:
    """Write a gzip-compressed copy of a file.

    Args:
        input_path: File to compress.
        output_path: Destination file. If ``None``, the last suffix of
            ``input_path`` is replaced with ``.nii.gz``, e.g. ``image.nii``
            becomes ``image.nii.gz``.

    Returns:
        Path of the compressed file.
    """
    if output_path is None:
        destination = Path(input_path).with_suffix('.nii.gz')
    else:
        destination = Path(output_path)
    with open(input_path, 'rb') as raw_file, gzip.open(destination, 'wb') as gz_file:
        shutil.copyfileobj(raw_file, gz_file)
    return destination
198
199
200
def check_sequence(sequence: Sequence, name: str) -> None:
    """Raise a ``TypeError`` if ``sequence`` cannot be iterated over.

    Args:
        sequence: Value to check.
        name: Name of the argument, used in the error message.

    Raises:
        TypeError: If ``iter(sequence)`` fails.
    """
    try:
        iter(sequence)
    except TypeError as err:
        # Report the type of the offending value, not the type of its name
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message) from err
206
207
208
def get_major_sitk_version() -> int:
    """Return the major version of the installed SimpleITK as ``1`` or ``2``.

    ``SimpleITK.__version__`` only exists from version 2 onwards
    (https://github.com/SimpleITK/SimpleITK/pull/1171), so its presence is
    what distinguishes the two. Any release that defines a non-``None``
    ``__version__`` is reported as ``2``.
    """
    if getattr(sitk, '__version__', None) is None:
        return 1
    return 2
214
215
216
def history_collate(batch: Sequence, collate_transforms=True) -> dict:
    """Collate a batch of subjects, also gathering a per-subject attribute.

    Adapted from
    https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py

    Args:
        batch: Sequence of samples, typically :class:`torchio.Subject`.
        collate_transforms: If ``True``, the attribute named
            ``constants.HISTORY`` is gathered; otherwise the attribute
            ``'applied_transforms'`` is.

    Returns:
        If the first sample is a :class:`torchio.Subject`, a dictionary
        mapping each of its keys to :func:`default_collate` applied to that
        key across the batch, plus the list of per-sample attribute values
        when the first sample has the attribute. Otherwise, an empty
        dictionary.
    """
    if collate_transforms:
        attribute_name = constants.HISTORY
    else:
        attribute_name = 'applied_transforms'

    from .data import Subject

    first_subject = batch[0]
    if not isinstance(first_subject, Subject):
        return {}
    collated = {
        key: default_collate([sample[key] for sample in batch])
        for key in first_subject
    }
    if hasattr(first_subject, attribute_name):
        collated[attribute_name] = [
            getattr(sample, attribute_name) for sample in batch
        ]
    return collated
232
233
234
def get_subclasses(target_class: type) -> list[type]:
    """Return every class that inherits, directly or not, from a class.

    Direct subclasses come first, followed by the descendants of each of them
    in turn. A class reachable through several paths (multiple inheritance)
    appears once per path.
    """
    direct_subclasses = target_class.__subclasses__()
    descendants = list(direct_subclasses)
    for child in direct_subclasses:
        descendants.extend(get_subclasses(child))
    return descendants
238
239
240
def get_first_item(data_loader: DataLoader):
    """Return the first batch yielded by ``data_loader``.

    Raises:
        StopIteration: If the loader yields nothing.
    """
    batches = iter(data_loader)
    return next(batches)
242
243
244
def get_batch_images_and_size(batch: dict) -> tuple[list[str], int]:
    """Get the names of the images in a batch and the batch size.

    Args:
        batch: Dictionary generated by a :class:`tio.SubjectsLoader`
            extracting data from a :class:`torchio.SubjectsDataset`.

    Returns:
        A tuple ``(names, size)``. ``names`` lists, in insertion order, the
        keys whose value is a dictionary containing ``constants.DATA``.
        ``size`` is the length of that data for the last such key.

    Raises:
        RuntimeError: If the batch does not seem to contain any dictionaries
            that seem to represent a :class:`torchio.Image`.
    """
    image_entries = [
        (key, value)
        for key, value in batch.items()
        if isinstance(value, dict) and constants.DATA in value
    ]
    # Presumably all images share the batch size; the last one is reported
    sizes = [len(value[constants.DATA]) for _, value in image_entries]
    if not image_entries:
        raise RuntimeError('The batch does not seem to contain any images')
    names = [key for key, _ in image_entries]
    return names, sizes[-1]
263
264
265
def get_subjects_from_batch(batch: dict) -> list:
    """Get list of subjects from collated batch.

    Each image entry of the batch is rebuilt as a :class:`torchio.LabelMap`
    or :class:`torchio.ScalarImage`, depending on its stored type. Every other
    entry is indexed per subject. If the batch has a ``constants.HISTORY``
    entry, the transforms listed for each subject are added back to that
    subject's history.

    Args:
        batch: Dictionary generated by a :class:`tio.SubjectsLoader`
            extracting data from a :class:`torchio.SubjectsDataset`.

    Returns:
        One :class:`torchio.Subject` per element of the batch.
    """
    from .data import LabelMap
    from .data import ScalarImage
    from .data import Subject

    image_names, batch_size = get_batch_images_and_size(batch)
    subjects = []
    for index in range(batch_size):
        subject_items = {}
        for key, value in batch.items():
            if key not in image_names:
                subject_items[key] = value[index]
                continue
            tensor = value[constants.DATA][index]
            affine = value[constants.AFFINE][index]
            file_path = Path(value[constants.PATH][index])
            is_label = value[constants.TYPE][index] == constants.LABEL
            image_class = LabelMap if is_label else ScalarImage
            subject_items[key] = image_class(
                tensor=tensor,
                affine=affine,
                filename=file_path.name,
            )

        subject = Subject(subject_items)
        if constants.HISTORY in batch:
            for transform in batch[constants.HISTORY][index]:
                transform.add_transform_to_subject_history(subject)
        subjects.append(subject)
    return subjects
306
307
308
def add_images_from_batch(
    subjects: list,
    tensor: torch.Tensor,
    class_=None,
    name='prediction',
) -> None:
    """Add images to subjects in a list, typically from a network prediction.

    The spatial metadata (affine matrix) of each new image is taken from the
    image returned by ``subject.get_first_image()``. The ``'filename'`` entry
    of that image, if present, is copied as well.

    Args:
        subjects: List of instances of :class:`torchio.Subject` to which images
            will be added.
        tensor: PyTorch tensor of shape :math:`(B, C, W, H, D)`, where
            :math:`B` is the batch size. :math:`B` must equal
            ``len(subjects)``.
        class_: Class used to instantiate the images,
            e.g., :class:`torchio.LabelMap`.
            If ``None``, :class:`torchio.ScalarImage` will be used.
        name: Name of the images added to the subjects.

    Raises:
        ValueError: If the number of subjects differs from the batch size.
    """
    # zip() below would otherwise silently drop extra subjects or predictions
    if len(subjects) != len(tensor):
        message = (
            f'The number of subjects ({len(subjects)}) does not match'
            f' the batch size of the tensor ({len(tensor)})'
        )
        raise ValueError(message)
    if class_ is None:
        from . import ScalarImage

        class_ = ScalarImage
    for subject, data in zip(subjects, tensor):
        one_image = subject.get_first_image()
        kwargs = {'tensor': data, 'affine': one_image.affine}
        # Propagate the file name of the reference image, if it has one
        if 'filename' in one_image:
            kwargs['filename'] = one_image['filename']
        image = class_(**kwargs)
        subject.add_image(image, name)
340
341
342
def guess_external_viewer() -> Path | None:
    """Guess the path to an executable that could be used to visualize images.

    The ``SITK_SHOW_COMMAND`` environment variable takes precedence.
    Otherwise, it looks for 1) ITK-SNAP and 2) 3D Slicer:

    - macOS: in the standard ``/Applications`` bundles.
    - Windows: in the Program Files directory.
    - Linux: on the ``PATH``.

    Returns:
        Path to the executable, or ``None`` if nothing was found.
    """
    if 'SITK_SHOW_COMMAND' in os.environ:
        return Path(os.environ['SITK_SHOW_COMMAND'])
    platform = sys.platform
    itk = 'ITK-SNAP'
    slicer = 'Slicer'
    if platform == 'darwin':
        # The application name appears twice in the bundle path. The previous
        # template had two positional fields but received a single tuple,
        # so str.format raised IndexError.
        app_path = '/Applications/{0}.app/Contents/MacOS/{0}'
        itk_snap_path = Path(app_path.format(itk))
        if itk_snap_path.is_file():
            return itk_snap_path
        slicer_path = Path(app_path.format(slicer))
        if slicer_path.is_file():
            return slicer_path
    elif platform == 'win32':
        # 'ProgramW6432' may be undefined (e.g. on 32-bit Windows)
        program_files = os.environ.get('ProgramW6432') or os.environ.get(
            'ProgramFiles',
        )
        if program_files is None:
            return None
        program_files_dir = Path(program_files)
        # glob() order is arbitrary; sort so that taking the last match is
        # deterministic (presumably intended to pick the newest install)
        itk_snap_dirs = sorted(program_files_dir.glob('ITK-SNAP*'))
        if itk_snap_dirs:
            itk_snap_dir = itk_snap_dirs[-1]
            itk_snap_path = itk_snap_dir / 'bin/itk-snap.exe'
            if itk_snap_path.is_file():
                return itk_snap_path
        slicer_dirs = sorted(program_files_dir.glob('Slicer*'))
        if slicer_dirs:
            slicer_dir = slicer_dirs[-1]
            slicer_path = slicer_dir / 'slicer.exe'
            if slicer_path.is_file():
                return slicer_path
    elif 'linux' in platform:
        itk_snap_which = shutil.which('itksnap')
        if itk_snap_which is not None:
            return Path(itk_snap_which)
        slicer_which = shutil.which('Slicer')
        if slicer_which is not None:
            return Path(slicer_which)
    return None
383
384
385
def parse_spatial_shape(shape):
    """Validate a spatial shape and return it as a tuple.

    Args:
        shape: A single number, repeated for the three dimensions, or an
            iterable of three numbers.

    Raises:
        ValueError: If an element is smaller than 1 or has a fractional part,
            or if the shape does not have exactly 3 elements.
    """
    spatial_shape = to_tuple(shape, length=3)
    if any(n < 1 or n % 1 for n in spatial_shape):
        message = (
            'All elements in a spatial shape must be positive integers,'
            f' but the following shape was passed: {shape}'
        )
        raise ValueError(message)
    if len(spatial_shape) != 3:
        message = (
            'Spatial shapes must have 3 elements, but the following shape'
            f' was passed: {shape}'
        )
        raise ValueError(message)
    return spatial_shape
401
402
403
def normalize_path(path: TypePath):
    """Return an absolute version of ``path`` with ``~`` expanded.

    Resolution is done with :meth:`pathlib.Path.resolve`.
    """
    expanded = Path(path).expanduser()
    return expanded.resolve()
405
406
407
def is_iterable(object: Any) -> bool:
    """Return ``True`` if ``iter()`` accepts the given object."""
    try:
        iter(object)
    except TypeError:
        return False
    return True
413