Passed
Pull Request — master (#508) · by Fernando · 01:20 · created

torchio.utils.get_first_item() · Rating: A

Complexity: Conditions 1
Size: Total Lines 2 · Code Lines 2
Duplication: Lines 0 · Ratio 0%
Importance: Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 1
import ast
import gzip
import shutil
import tempfile
from pathlib import Path
from typing import Union, Iterable, Tuple, Any, Optional, List, Sequence

from torch.utils.data._utils.collate import default_collate
import numpy as np
import nibabel as nib
import SimpleITK as sitk
from tqdm import trange

from .typing import TypeNumber, TypePath


def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """
    to_tuple(1, length=1) -> (1,)
    to_tuple(1, length=3) -> (1, 1, 1)

    If value is an iterable, length is ignored and tuple(value) is returned
    to_tuple((1,), length=1) -> (1,)
    to_tuple((1, 2), length=1) -> (1, 2)
    to_tuple([1, 2], length=3) -> (1, 2)
    """
    try:
        iter(value)
        value = tuple(value)
    except TypeError:
        value = length * (value,)
    return value


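# Usage sketch (illustrative, not part of the module): to_tuple broadcasts a
# scalar to the requested length and passes any iterable through unchanged,
# which is how arguments such as a patch size can be normalized.
#
#     >>> to_tuple(1.5, length=3)
#     (1.5, 1.5, 1.5)
#     >>> to_tuple((128, 128, 64))
#     (128, 128, 64)
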
def get_stem(
        path: Union[TypePath, List[TypePath]]
        ) -> Union[str, List[str]]:
    """
    '/home/user/image.nii.gz' -> 'image'
    """
    def _get_stem(path_string):
        return Path(path_string).name.split('.')[0]
    if isinstance(path, (str, Path)):
        return _get_stem(path)
    return [_get_stem(p) for p in path]


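# Usage sketch (illustrative paths): get_stem strips every extension, so a
# double suffix such as '.nii.gz' is removed in one call; a list of paths
# returns a list of stems.
#
#     >>> get_stem('/data/t1.nii.gz')
#     't1'
#     >>> get_stem(['/data/a.nii.gz', '/data/b.mha'])
#     ['a', 'b']
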
def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    from .data import ScalarImage, LabelMap, Subject
    output_dir = tempfile.gettempdir() if directory is None else directory
    output_dir = Path(output_dir)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        shutil.rmtree(images_dir)
        shutil.rmtree(labels_dir)

    subjects: List[Subject] = []
    if images_dir.is_dir():
        for i in trange(num_images):
            image_path = images_dir / f'image_{i}{suffix}'
            label_path = labels_dir / f'label_{i}{suffix}'
            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    else:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')  # noqa: T001
            iterable = trange(num_images)
        else:
            iterable = range(num_images)
        for i in iterable:
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            image_path = images_dir / f'image_{i}{suffix}'
            nii = nib.Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            label_path = labels_dir / f'label_{i}{suffix}'
            nii = nib.Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    return subjects


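# Usage sketch (illustrative): writing two small image/label pairs to the
# system temporary directory and checking how many subjects come back. The
# result can then be wrapped in a dataset class such as torchio's
# SubjectsDataset (assumed here, not shown).
#
#     >>> subjects = create_dummy_dataset(num_images=2, size_range=(8, 16))
#     >>> len(subjects)
#     2
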
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import
        output_path: TypePath,
        class_: str = 'ScalarImage',
        verbose: bool = False,
        ):
    from . import data
    image = getattr(data, class_)(input_path)
    subject = data.Subject(image=image)
    transformed = transform(subject)
    transformed.image.save(output_path)
    if verbose and transformed.history:
        print('Applied transform:', transformed.history[0])  # noqa: T001


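# Usage sketch (the transform and file names are illustrative assumptions):
# applying a single transform to an image on disk without building a dataset.
#
#     import torchio as tio
#     apply_transform_to_file(
#         't1.nii.gz',
#         tio.RescaleIntensity(out_min_max=(0, 1)),
#         't1_rescaled.nii.gz',
#     )
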
def guess_type(string: str) -> Any:
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        value = ast.literal_eval(string)
    except ValueError:
        result_type = str
    else:
        result_type = type(value)
    if result_type in (list, tuple):
        string = string[1:-1]  # remove brackets
        split = string.split(',')
        list_result = [guess_type(n) for n in split]
        value = tuple(list_result) if result_type is tuple else list_result
        return value
    try:
        value = result_type(string)
    except TypeError:
        value = None
    return value


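# Usage sketch (illustrative): guess_type turns command-line style strings
# into Python values, falling back to the original string when the literal
# cannot be parsed.
#
#     >>> guess_type('1')
#     1
#     >>> guess_type('1.5')
#     1.5
#     >>> guess_type('(1, 2)')
#     (1, 2)
#     >>> guess_type('hello')
#     'hello'
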
def get_torchio_cache_dir():
    return Path('~/.cache/torchio').expanduser()


def compress(input_path, output_path):
    with open(input_path, 'rb') as f_in:
        with gzip.open(output_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)


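# Usage sketch (illustrative paths): gzip-compressing an uncompressed NIfTI
# file without shelling out to the gzip command-line tool.
#
#     compress('image.nii', 'image.nii.gz')
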
def check_sequence(sequence: Sequence, name: str):
    try:
        iter(sequence)
    except TypeError:
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message)


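# Usage sketch (illustrative): validating an argument early so the caller
# gets a readable error instead of a failure deeper in the call stack.
#
#     >>> check_sequence((64, 64, 32), 'patch_size')  # passes silently
#     >>> check_sequence(64, 'patch_size')
#     Traceback (most recent call last):
#         ...
#     TypeError: "patch_size" must be a sequence, not <class 'int'>
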
def get_major_sitk_version() -> int:
    # This attribute was added in version 2
    # https://github.com/SimpleITK/SimpleITK/pull/1171
    version = getattr(sitk, '__version__', None)
    major_version = 1 if version is None else 2
    return major_version


def history_collate(batch: Sequence, collate_transforms=True):
    attr = 'history' if collate_transforms else 'applied_transforms'
    # Adapted from
    # https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py
    from .data import Subject
    first_element = batch[0]
    if isinstance(first_element, Subject):
        dictionary = {
            key: default_collate([d[key] for d in batch])
            for key in first_element
        }
        if hasattr(first_element, attr):
            dictionary.update({attr: [getattr(d, attr) for d in batch]})
        return dictionary
    # Batches whose elements are not Subject instances fall through and the
    # function implicitly returns None


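# Usage sketch (subjects_dataset is an assumed variable holding a dataset
# that yields Subject instances): passing history_collate to a PyTorch
# DataLoader so the transform history of each subject travels with the batch.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(
#         subjects_dataset,
#         batch_size=4,
#         collate_fn=history_collate,
#     )
#     batch = next(iter(loader))
#     histories = batch['history']  # one history list per subject
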
def get_subclasses(target_class: type) -> List[type]:
    subclasses = target_class.__subclasses__()
    subclasses += sum((get_subclasses(cls) for cls in subclasses), [])
    return subclasses


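# Usage sketch (toy classes for illustration): the recursion also collects
# indirect subclasses, so grandchildren of the target class are included.
#
#     >>> class A: pass
#     >>> class B(A): pass
#     >>> class C(B): pass
#     >>> [cls.__name__ for cls in get_subclasses(A)]
#     ['B', 'C']
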
def get_first_item(data_loader):
    return next(iter(data_loader))
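
# Usage sketch (loader is an assumed PyTorch DataLoader): grabbing a single
# batch, e.g. to sanity-check shapes before training.
#
#     batch = get_first_item(loader)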