torchio.utils.to_tuple()   rating: A

Complexity
    Conditions: 2

Size
    Total lines: 19
    Code lines: 9

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
------   -----
eloc     9
dl       0
loc      19
rs       9.95
c        0
b        0
f        0
cc       2
nop      2

import ast
import gzip
import shutil
import tempfile
from pathlib import Path
from typing import Union, Tuple, Any, Optional, List, Sequence, Dict

import torch
from torch.utils.data._utils.collate import default_collate
import numpy as np
import nibabel as nib
import SimpleITK as sitk
from tqdm import trange

from . import constants
from .typing import TypeNumber, TypePath


def to_tuple(
        value: Any,
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """
    to_tuple(1, length=1) -> (1,)
    to_tuple(1, length=3) -> (1, 1, 1)

    If value is an iterable, length is ignored and tuple(value) is returned
    to_tuple((1,), length=1) -> (1,)
    to_tuple((1, 2), length=1) -> (1, 2)
    to_tuple([1, 2], length=3) -> (1, 2)
    """
    try:
        iter(value)
        value = tuple(value)
    except TypeError:
        value = length * (value,)
    return value

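# Illustrative usage (not part of the original module): to_tuple is typically
# used to normalize arguments that may be passed either as a single number or
# as a sequence, such as a patch size:
#
#     >>> to_tuple(4, length=3)
#     (4, 4, 4)
#     >>> to_tuple((2, 3), length=5)  # length is ignored for iterables
#     (2, 3)

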
def get_stem(
        path: Union[TypePath, Sequence[TypePath]]
        ) -> Union[str, List[str]]:
    """
    '/home/user/image.nii.gz' -> 'image'
    """
    def _get_stem(path_string):
        return Path(path_string).name.split('.')[0]
    if isinstance(path, (str, Path)):
        return _get_stem(path)
    return [_get_stem(p) for p in path]

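# Illustrative usage (not part of the original module; the paths are made up):
#
#     >>> get_stem('/data/subject_01/t1.nii.gz')
#     't1'
#     >>> get_stem(['/data/a.nii.gz', '/data/b.nii'])
#     ['a', 'b']

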
def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    from .data import ScalarImage, LabelMap, Subject
    output_dir = tempfile.gettempdir() if directory is None else directory
    output_dir = Path(output_dir)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        shutil.rmtree(images_dir)
        shutil.rmtree(labels_dir)

    subjects: List[Subject] = []
    if images_dir.is_dir():
        for i in trange(num_images):
            image_path = images_dir / f'image_{i}{suffix}'
            label_path = labels_dir / f'label_{i}{suffix}'
            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    else:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')  # noqa: T001
            iterable = trange(num_images)
        else:
            iterable = range(num_images)
        for i in iterable:
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            image_path = images_dir / f'image_{i}{suffix}'
            nii = nib.Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            label_path = labels_dir / f'label_{i}{suffix}'
            nii = nib.Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    return subjects

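# Illustrative usage (not part of the original module): build a small random
# dataset for tests or tutorials. Files go to the system temporary directory
# unless a directory is given.
#
#     >>> subjects = create_dummy_dataset(num_images=4, size_range=(10, 20))
#     >>> len(subjects)
#     4
#
# Each subject holds a random 'one_modality' ScalarImage and a matching
# three-class 'segmentation' LabelMap saved as NIfTI files.

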
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import
        output_path: TypePath,
        class_: str = 'ScalarImage',
        verbose: bool = False,
        ):
    from . import data
    image = getattr(data, class_)(input_path)
    subject = data.Subject(image=image)
    transformed = transform(subject)
    transformed.image.save(output_path)
    if verbose and transformed.history:
        print('Applied transform:', transformed.history[0])  # noqa: T001

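# Illustrative usage (not part of the original module; the file names and the
# chosen transform are only examples):
#
#     >>> import torchio as tio
#     >>> apply_transform_to_file(
#     ...     't1.nii.gz',
#     ...     tio.RandomAffine(),
#     ...     't1_transformed.nii.gz',
#     ... )

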
def guess_type(string: str) -> Any:
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    result_type: Any
    try:
        value = ast.literal_eval(string)
    except ValueError:
        result_type = str
    else:
        result_type = type(value)
    if result_type in (list, tuple):
        string = string[1:-1]  # remove brackets
        split = string.split(',')
        list_result = [guess_type(n) for n in split]
        value = tuple(list_result) if result_type is tuple else list_result
        return value
    try:
        value = result_type(string)
    except TypeError:
        value = None
    return value

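# Illustrative usage (not part of the original module): guess_type is handy
# when values arrive as strings, e.g. from a command-line interface:
#
#     >>> guess_type('1')
#     1
#     >>> guess_type('1.5')
#     1.5
#     >>> guess_type('(0.5, 0.5, 2)')
#     (0.5, 0.5, 2)
#     >>> guess_type('hello')
#     'hello'

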
def get_torchio_cache_dir() -> Path:
    return Path('~/.cache/torchio').expanduser()


def compress(
        input_path: TypePath,
        output_path: Optional[TypePath] = None,
        ) -> Path:
    if output_path is None:
        output_path = Path(input_path).with_suffix('.nii.gz')
    with open(input_path, 'rb') as f_in:
        with gzip.open(output_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    return Path(output_path)

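# Illustrative usage (not part of the original module; the path is made up):
#
#     >>> compress('t1.nii')  # writes and returns Path('t1.nii.gz')
#
# When output_path is omitted, the '.nii' suffix of the input is replaced
# with '.nii.gz'.

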
def check_sequence(sequence: Sequence, name: str) -> None:
    try:
        iter(sequence)
    except TypeError:
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message)

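# Illustrative usage (not part of the original module): validate an argument
# before iterating over it. The name is only used in the error message.
#
#     >>> check_sequence((128, 128, 64), 'patch_size')  # passes silently
#     >>> check_sequence(128, 'patch_size')
#     Traceback (most recent call last):
#         ...
#     TypeError: "patch_size" must be a sequence, not <class 'int'>

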
def get_major_sitk_version() -> int:
    # This attribute was added in version 2
    # https://github.com/SimpleITK/SimpleITK/pull/1171
    version = getattr(sitk, '__version__', None)
    major_version = 1 if version is None else 2
    return major_version


def history_collate(batch: Sequence, collate_transforms=True) -> Dict:
    attr = constants.HISTORY if collate_transforms else 'applied_transforms'
    # Adapted from
    # https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py
    from .data import Subject
    first_element = batch[0]
    if isinstance(first_element, Subject):
        dictionary = {
            key: default_collate([d[key] for d in batch])
            for key in first_element
        }
        if hasattr(first_element, attr):
            dictionary.update({attr: [getattr(d, attr) for d in batch]})
    else:
        dictionary = {}
    return dictionary

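# Sketch of the intended use (not part of the original module, and assuming
# subjects_dataset is a torchio SubjectsDataset whose keys can be collated by
# default_collate): pass the function as the collate_fn of a DataLoader so
# the per-subject transform history is kept as a plain list per batch.
#
#     >>> loader = torch.utils.data.DataLoader(
#     ...     subjects_dataset,
#     ...     batch_size=4,
#     ...     collate_fn=history_collate,
#     ... )
#     >>> batch = next(iter(loader))
#     >>> batch[constants.HISTORY]  # one history per subject in the batch

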
def get_subclasses(target_class: type) -> List[type]:
    subclasses = target_class.__subclasses__()
    # Recurse so that subclasses of subclasses are also collected
    subclasses += sum((get_subclasses(cls) for cls in subclasses), [])
    return subclasses

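# Illustrative usage (not part of the original module): collect all direct and
# indirect subclasses of a class, e.g. every transform shipped with the
# library, assuming torchio.transforms.Transform is importable:
#
#     >>> from torchio.transforms import Transform
#     >>> all_transform_classes = get_subclasses(Transform)

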
def get_first_item(data_loader: torch.utils.data.DataLoader):
    return next(iter(data_loader))


def get_batch_images_and_size(batch: Dict) -> Tuple[List[str], int]:
    """Get the names of the images in the batch and the batch size.

    Args:
        batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
            extracting data from a :class:`torchio.SubjectsDataset`.

    Raises:
        RuntimeError: If the batch does not contain any dictionaries that
            seem to represent a :class:`torchio.Image`.
    """
    names = []
    for image_name, image_dict in batch.items():
        if constants.DATA in image_dict:  # assume it is a TorchIO Image
            size = len(image_dict[constants.DATA])
            names.append(image_name)
    if not names:
        raise RuntimeError('The batch does not seem to contain any images')
    # size is guaranteed to be bound here: names is non-empty, so the branch
    # that assigns it ran at least once
    return names, size


def get_subjects_from_batch(batch: Dict) -> List:
    """Get list of subjects from collated batch.

    Args:
        batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
            extracting data from a :class:`torchio.SubjectsDataset`.
    """
    from .data import ScalarImage, LabelMap, Subject
    subjects = []
    image_names, batch_size = get_batch_images_and_size(batch)
    for i in range(batch_size):
        subject_dict = {}
        for image_name in image_names:
            image_dict = batch[image_name]
            data = image_dict[constants.DATA][i]
            affine = image_dict[constants.AFFINE][i]
            path = Path(image_dict[constants.PATH][i])
            # Index the collated type entry per sample; comparing the whole
            # collated list to the LABEL constant would always be False
            is_label = image_dict[constants.TYPE][i] == constants.LABEL
            klass = LabelMap if is_label else ScalarImage
            image = klass(tensor=data, affine=affine, filename=path.name)
            subject_dict[image_name] = image
        subject = Subject(subject_dict)
        subjects.append(subject)
    return subjects

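# Illustrative usage (not part of the original module; assumes
# subjects_dataset is a torchio SubjectsDataset with at least four subjects):
# rebuild Subject instances from a collated batch, e.g. to inspect or save
# the images after inference.
#
#     >>> loader = torch.utils.data.DataLoader(subjects_dataset, batch_size=4)
#     >>> batch = get_first_item(loader)
#     >>> subjects = get_subjects_from_batch(batch)
#     >>> len(subjects)
#     4

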
def add_images_from_batch(
        subjects: List,
        tensor: torch.Tensor,
        class_=None,
        name='prediction',
        ) -> None:
    """Add images to subjects in a list, typically from a network prediction.

    The spatial metadata (affine matrices) will be extracted from one of the
    images of each subject.

    Args:
        subjects: List of instances of :class:`torchio.Subject` to which images
            will be added.
        tensor: PyTorch tensor of shape :math:`(B, C, W, H, D)`, where
            :math:`B` is the batch size.
        class_: Class used to instantiate the images,
            e.g., :class:`torchio.LabelMap`.
            If ``None``, :class:`torchio.ScalarImage` will be used.
        name: Name of the images added to the subjects.
    """
    if class_ is None:
        from . import ScalarImage
        class_ = ScalarImage
    for subject, data in zip(subjects, tensor):
        one_image = subject.get_first_image()
        kwargs = {'tensor': data, 'affine': one_image.affine}
        if 'filename' in one_image:
            kwargs['filename'] = one_image['filename']
        image = class_(**kwargs)
        subject.add_image(image, name)

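# Illustrative usage (not part of the original module): attach a network
# prediction to every subject rebuilt from a batch. The model, the 't1' key
# and the batch itself are assumptions for this sketch.
#
#     >>> import torchio as tio
#     >>> subjects = get_subjects_from_batch(batch)
#     >>> with torch.no_grad():
#     ...     prediction = model(batch['t1'][constants.DATA])
#     >>> add_images_from_batch(subjects, prediction.cpu(), tio.LabelMap)
#     >>> subjects[0]['prediction']  # the new image added to each subject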