Passed
Push — master ( c9d9b4...13aa31 )
by Fernando
01:10
created

torchio.utils.compress()   A

Complexity

Conditions 3

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 4
nop 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import ast
2
import gzip
3
import shutil
4
import tempfile
5
from pathlib import Path
6
from typing import Union, Iterable, Tuple, Any, Optional, List, Sequence
7
8
from torch.utils.data._utils.collate import default_collate
9
import numpy as np
10
import nibabel as nib
11
import SimpleITK as sitk
12
from tqdm import trange
13
14
from .constants import INTENSITY
15
from .typing import TypeNumber, TypePath
16
17
18
def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """Convert a scalar or an iterable into a tuple.

    A non-iterable ``value`` is repeated ``length`` times::

        to_tuple(1, length=1) -> (1,)
        to_tuple(1, length=3) -> (1, 1, 1)

    If ``value`` is iterable, ``length`` is ignored and ``tuple(value)`` is
    returned::

        to_tuple((1,), length=1) -> (1,)
        to_tuple((1, 2), length=1) -> (1, 2)
        to_tuple([1, 2], length=3) -> (1, 2)

    Note that strings are iterable, so a string is split into characters.

    Args:
        value: A scalar or an iterable of scalars.
        length: Number of repetitions when ``value`` is not iterable.

    Returns:
        A tuple built from ``value``.
    """
    try:
        iter(value)
    except TypeError:
        # Not iterable: treat as a scalar and repeat it
        return length * (value,)
    # Kept outside the try so errors raised while iterating are not
    # mistaken for "value is a scalar"
    return tuple(value)
37
38
39
def get_stem(
        path: Union[TypePath, List[TypePath]]
        ) -> Union[str, List[str]]:
    """Return a file name stripped of its directory and of every extension.

    Everything from the first dot of the file name onwards is discarded, so
    multi-part extensions such as ``.nii.gz`` disappear completely::

        '/home/user/image.nii.gz' -> 'image'

    Args:
        path: A single path (``str`` or ``Path``) or an iterable of paths.

    Returns:
        The stem for a single path, or a list of stems for several paths.
    """
    def strip_all_suffixes(single_path):
        file_name = Path(single_path).name
        stem, _, _ = file_name.partition('.')
        return stem

    if not isinstance(path, (str, Path)):
        return list(map(strip_all_suffixes, path))
    return strip_all_suffixes(path)
50
51
52
def _write_dummy_pair(
        image_path: Path,
        label_path: Path,
        size_range: Tuple[int, int],
        ) -> None:
    """Write one random uint8 image and a matching 3-class label map."""
    shape = np.random.randint(*size_range, size=3)
    affine = np.eye(4)
    image = np.random.rand(*shape)
    # Threshold the uniform noise in [0, 1) into labels 0, 1 and 2
    label = np.ones_like(image)
    label[image < 0.33] = 0
    label[image > 0.66] = 2
    image *= 255

    nii = nib.Nifti1Image(image.astype(np.uint8), affine)
    nii.to_filename(str(image_path))

    nii = nib.Nifti1Image(label.astype(np.uint8), affine)
    nii.to_filename(str(label_path))


def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    """Create, or reuse from disk, a dataset of random images and labels.

    Images are stored as ``<directory>/dummy_images/image_<i><suffix>`` and
    label maps as ``<directory>/dummy_labels/label_<i><suffix>``. Files that
    already exist are reused; any missing image/label pair is generated.

    Args:
        num_images: Number of subjects to return.
        size_range: ``(low, high)`` bounds given to
            :func:`numpy.random.randint` for each of the three spatial
            dimensions (``high`` is exclusive).
        directory: Parent directory of the two output folders. Defaults to
            the system temporary directory.
        suffix: File extension of the written files.
        force: If ``True``, delete both output folders first so that every
            file is regenerated.
        verbose: If ``True``, print a message and show a progress bar.

    Returns:
        A list of ``num_images`` subjects, each with a ``one_modality``
        scalar image and a ``segmentation`` label map.
    """
    from .data import ScalarImage, LabelMap, Subject
    output_dir = Path(tempfile.gettempdir() if directory is None else directory)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # rmtree raises FileNotFoundError on a missing folder, so only
        # remove the folders that are actually there
        for folder in (images_dir, labels_dir):
            if folder.is_dir():
                shutil.rmtree(folder)

    images_dir.mkdir(exist_ok=True, parents=True)
    labels_dir.mkdir(exist_ok=True, parents=True)

    if verbose:
        print('Creating dummy dataset...')  # noqa: T001
        indices = trange(num_images)
    else:
        indices = range(num_images)

    subjects: List[Subject] = []
    for i in indices:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'
        # Reuse cached files, but (re)generate incomplete pairs, e.g. when a
        # previous call used a smaller num_images
        if not (image_path.is_file() and label_path.is_file()):
            _write_dummy_pair(image_path, label_path, size_range)
        subject = Subject(
            one_modality=ScalarImage(image_path),
            segmentation=LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects
111
112
113
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import
        output_path: TypePath,
        type: str = INTENSITY,  # noqa: A002
        verbose: bool = False,
        ):
    """Read an image, apply a transform to it and write the result.

    Args:
        input_path: Path of the image to read.
        transform: Callable applied to a ``Subject`` holding the image under
            the key ``image``.
        output_path: Path where the transformed image is saved.
        type: Image type passed to ``Image``.
        verbose: If ``True`` and the transformed subject has a non-empty
            ``history``, print its first entry.
    """
    from . import Image, Subject
    source_image = Image(input_path, type=type)
    result = transform(Subject(image=source_image))
    result.image.save(output_path)
    if not verbose:
        return
    if result.history:
        print('Applied transform:', result.history[0])  # noqa: T001
126
127
128
def guess_type(string: str) -> Any:
    """Convert a string into the Python value it most likely represents.

    All spaces are removed first, then the string is parsed with
    :func:`ast.literal_eval`. Numbers, booleans, ``None`` and (possibly
    nested) lists, tuples, dicts and sets of literals are therefore
    converted to the corresponding objects. Anything that is not a valid
    literal, such as ``'linear'`` or a file path, is returned as the
    space-free string.

    Examples:
        >>> guess_type('1')
        1
        >>> guess_type('(1, 2.5)')
        (1, 2.5)
        >>> guess_type('False')
        False
        >>> guess_type('linear')
        'linear'
    """
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        # literal_eval only builds literals; it does not execute code.
        # Its result is returned as-is: re-converting with type(value)(string)
        # is wrong for bools ('False' -> True), dicts, sets and hex ints
        return ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        # ValueError: bare names such as 'linear'
        # SyntaxError: paths, empty strings, unbalanced brackets
        # TypeError: unhashable dict keys / set elements, e.g. '{[1]: 2}'
        return string
149
150
151
def get_torchio_cache_dir():
    """Return the path of TorchIO's cache directory, ``~/.cache/torchio``.

    ``~`` is expanded to the current user's home directory. The directory is
    not created by this function.
    """
    home = Path('~').expanduser()
    return home / '.cache' / 'torchio'
153
154
155
def round_up(value: float) -> int:
    """Round to the nearest integer, sending exact halves upwards.

    Ties always go toward positive infinity, so ``-2.5`` becomes ``-2``.
    Python's built-in :func:`round` instead rounds ties to the nearest even
    integer.

    Args:
        value: The value to round.

    Example:

        >>> round(2.5)
        2
        >>> round(3.5)
        4
        >>> round_up(2.5)
        3
        >>> round_up(3.5)
        4
        >>> round_up(-2.5)
        -2

    """
    shifted = value + 0.5
    return int(np.floor(shifted))
174
175
176
def compress(input_path, output_path):
    """Write a gzip-compressed copy of ``input_path`` to ``output_path``.

    The data is copied in chunks with :func:`shutil.copyfileobj`, so the
    whole file is never held in memory. An existing ``output_path`` is
    overwritten.
    """
    with open(input_path, 'rb') as source, gzip.open(output_path, 'wb') as target:
        shutil.copyfileobj(source, target)
180
181
182
def check_sequence(sequence: Sequence, name: str):
    """Raise a ``TypeError`` if ``sequence`` is not iterable.

    Any object accepted by :func:`iter` passes the check.

    Args:
        sequence: The value to check.
        name: Name of the argument being checked, used in the error message.

    Raises:
        TypeError: If ``iter(sequence)`` fails.
    """
    try:
        iter(sequence)
    except TypeError as error:
        # Report the type of the offending value (not of its name, which is
        # always a str)
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message) from error
188
189
190
def get_major_sitk_version() -> int:
    """Return the major version number of the installed SimpleITK.

    ``sitk.__version__`` only exists since SimpleITK 2
    (https://github.com/SimpleITK/SimpleITK/pull/1171), so its absence
    means version 1. Otherwise the leading component of the version string
    is returned, so future major releases are reported correctly instead of
    always being reported as 2.
    """
    version = getattr(sitk, '__version__', None)
    if version is None:
        return 1
    try:
        return int(str(version).split('.')[0])
    except ValueError:
        # Unparseable version string: the attribute's existence alone
        # still implies version 2 or later
        return 2
196
197
198
def history_collate(batch: Sequence, collate_transforms=True):
    """Collate a batch of subjects while keeping their transform records.

    For a batch of ``Subject`` instances, each key of the first subject is
    collated with PyTorch's ``default_collate``. If the first subject also
    has the selected record attribute, the per-subject values of that
    attribute are gathered into a list under the same name. Batches of
    anything else are passed straight to ``default_collate``.

    Args:
        batch: Sequence of samples; must not be empty.
        collate_transforms: If ``True``, gather each subject's ``history``;
            otherwise gather ``applied_transforms``.

    Returns:
        The collated batch.
    """
    # Adapted from
    # https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py
    from .data import Subject
    attr = 'history' if collate_transforms else 'applied_transforms'
    first_element = batch[0]
    if not isinstance(first_element, Subject):
        # Not a batch of subjects: use the standard collation instead of
        # implicitly returning None
        return default_collate(batch)
    dictionary = {
        key: default_collate([d[key] for d in batch])
        for key in first_element
    }
    if hasattr(first_element, attr):
        dictionary[attr] = [getattr(d, attr) for d in batch]
    return dictionary
212
213
214
def get_subclasses(target_class: type) -> List[type]:
    """Return every direct and indirect subclass of ``target_class``.

    Direct subclasses come first, followed by the descendants of each of
    them in turn. A class reachable through several parents (diamond
    inheritance) is listed only once, at its first position.
    """
    subclasses = target_class.__subclasses__()
    # Iterate over a snapshot because the list is extended inside the loop
    for direct_subclass in list(subclasses):
        subclasses.extend(get_subclasses(direct_subclass))
    # dict keeps insertion order, so this removes duplicates stably
    return list(dict.fromkeys(subclasses))
218