Passed
Pull Request — master (#226)
by Fernando
01:12
created

torchio.utils.get_transform()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import ast
2
import shutil
3
import tempfile
4
from pathlib import Path
5
from typing import Union, Iterable, Tuple, Any, Optional, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
from tqdm import trange
12
from .torchio import (
13
    INTENSITY,
14
    LABEL,
15
    TypeData,
16
    TypeNumber,
17
    TypePath,
18
)
19
20
21
# Diagonal matrix that negates the first two world axes; multiplying by it
# switches coordinates/orientations between the LPS and RAS conventions
FLIP_XY = np.diag([-1, -1, 1])
22
23
24
def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """Convert a scalar or an iterable into a tuple.

    A non-iterable ``value`` is repeated ``length`` times. If ``value`` is
    iterable, ``length`` is ignored and ``tuple(value)`` is returned (note
    that strings are iterable too, so they are split into characters).

    Args:
        value: A scalar or an iterable of scalars.
        length: Number of repetitions used when ``value`` is a scalar.

    Example:
        >>> to_tuple(1, length=1)
        (1,)
        >>> to_tuple(1, length=3)
        (1, 1, 1)
        >>> to_tuple((1,), length=1)
        (1,)
        >>> to_tuple((1, 2), length=1)
        (1, 2)
        >>> to_tuple([1, 2], length=3)
        (1, 2)
    """
    try:
        iter(value)
    except TypeError:
        # Not iterable: treat it as a scalar and repeat it
        return length * (value,)
    # Convert outside the try block so that a TypeError raised while
    # consuming an iterable is not mistaken for "value is a scalar"
    return tuple(value)
43
44
45
def get_stem(path: TypePath) -> str:
    """Return the file name of ``path`` up to (not including) its first dot.

    Unlike ``Path.stem``, all suffixes are removed, so compound extensions
    such as ``.nii.gz`` disappear entirely:

    '/home/user/image.nii.gz' -> 'image'
    """
    file_name = Path(path).name
    stem, _, _ = file_name.partition('.')
    return stem
51
52
53
def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    """Create (or reuse) a dataset of random images and label maps.

    Images are written to ``<directory>/dummy_images`` and labels to
    ``<directory>/dummy_labels``. If the images directory already exists
    (and ``force`` is ``False``), no files are written. Subjects are then
    built from the expected file names ``image_{i}{suffix}`` and
    ``label_{i}{suffix}``.

    Args:
        num_images: Number of subjects to return.
        size_range: ``(low, high)`` passed to ``np.random.randint`` to draw
            each of the three image dimensions.
        directory: Output directory. Defaults to the system temp directory.
        suffix: File extension for the written images.
        force: If ``True``, delete any existing dummy directories first so
            the dataset is regenerated.
        verbose: If ``True``, print a message and show a progress bar.

    Returns:
        List of ``Subject`` instances with ``one_modality`` (intensity) and
        ``segmentation`` (label) images.
    """
    from .data import Image, Subject
    output_dir = tempfile.gettempdir() if directory is None else directory
    output_dir = Path(output_dir)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # rmtree raises FileNotFoundError for missing directories (e.g. on
        # the very first call), so only remove what actually exists
        for dummy_dir in (images_dir, labels_dir):
            if dummy_dir.is_dir():
                shutil.rmtree(dummy_dir)

    reuse_existing = images_dir.is_dir()
    if not reuse_existing:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')

    # The progress bar honors ``verbose`` in both the create and reuse paths
    iterable = trange(num_images) if verbose else range(num_images)
    subjects: List[Subject] = []
    for i in iterable:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'
        if not reuse_existing:
            _write_dummy_pair(image_path, label_path, size_range)
        subject = Subject(
            one_modality=Image(image_path, INTENSITY),
            segmentation=Image(label_path, LABEL),
        )
        subjects.append(subject)
    return subjects


def _write_dummy_pair(image_path, label_path, size_range):
    """Write one random uint8 image and its 3-class label map as NIfTI."""
    shape = np.random.randint(*size_range, size=3)
    affine = np.eye(4)
    image = np.random.rand(*shape)
    # Threshold the uniform noise in [0, 1) into three classes: 0, 1 and 2
    label = np.ones_like(image)
    label[image < 0.33] = 0
    label[image > 0.66] = 2
    image *= 255

    nii = nib.Nifti1Image(image.astype(np.uint8), affine)
    nii.to_filename(str(image_path))

    nii = nib.Nifti1Image(label.astype(np.uint8), affine)
    nii.to_filename(str(label_path))
112
113
114
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import (TODO)
        output_path: TypePath,
        type: str = INTENSITY,
        verbose: bool = False,
        seed: Optional[int] = None,
        ):
    """Read an image from disk, transform it and save the result.

    Args:
        input_path: Path to the image to read.
        transform: Callable invoked as ``transform(subject, seed=seed)``; it
            must return an object exposing ``image`` and ``history``.
        output_path: Path where the transformed image is saved.
        type: Image type passed to ``Image`` (defaults to ``INTENSITY``).
            The name shadows the builtin but is kept for backward
            compatibility with keyword callers.
        verbose: If ``True``, print the first entry of the transform history.
        seed: Forwarded to ``transform``.
    """
    # Imported at call time, presumably for the same circular-import reason
    # noted on ``transform``; ``ImagesDataset`` was imported but never used
    from . import Image, Subject
    subject = Subject(image=Image(input_path, type))
    transformed = transform(subject, seed=seed)
    transformed.image.save(output_path)
    if verbose and transformed.history:
        print(transformed.history[0])
128
129
def guess_type(string: str) -> Any:
    """Convert a string into the Python literal it represents, if any.

    Adapted from
    https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s

    All spaces are removed first. If the result is a valid Python literal
    (number, boolean, ``None``, quoted string, list, tuple, dict, ...), the
    evaluated literal is returned, including nested containers. Otherwise
    the space-stripped string itself is returned.

    Example:
        >>> guess_type('1')
        1
        >>> guess_type('[1, 2.5]')
        [1, 2.5]
        >>> guess_type('False')
        False
        >>> guess_type('/tmp/image.nii')
        '/tmp/image.nii'
    """
    string = string.replace(' ', '')
    try:
        # literal_eval only accepts literals, so it does not execute code
        return ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        # Not a literal, e.g. a bare word or a file path: keep the string.
        # SyntaxError covers inputs such as paths or the empty string;
        # TypeError covers unhashable dict keys / set items.
        return string
150
151
152
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the linear part of an affine into a rotation and a spacing.

    Adapted from https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py

    Args:
        affine: Affine matrix; only its upper-left 3x3 block is used.

    Returns:
        Tuple ``(rotation, spacing)``. ``spacing`` holds the Euclidean norm
        of each column of the 3x3 block, and ``rotation`` is that block with
        every column divided by its norm.
    """
    linear_part = affine[:3, :3]
    column_norms = np.sqrt(np.square(linear_part).sum(axis=0))
    return linear_part / column_norms, column_norms
160
161
162
def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
    """Create a SimpleITK image from an array and a NiBabel-style affine.

    Args:
        data: Voxel array (NumPy array or PyTorch tensor). It is transposed
            before conversion, since SimpleITK indexes arrays in reverse
            axis order.
        affine: 4x4 affine matrix (NumPy array or PyTorch tensor). Its first
            two axes are flipped with ``FLIP_XY`` to switch between the RAS
            and LPS conventions.

    Returns:
        SimpleITK image with origin, spacing and direction taken from
        ``affine``.

    For 2D ``data``, the first axis of the affine is ignored for the origin,
    the spacing and the direction alike. This mirrors ``sitk_to_nib``, which
    pads 2D images by prepending an axis.
    """
    array = data.numpy() if isinstance(data, torch.Tensor) else data
    affine = affine.numpy() if isinstance(affine, torch.Tensor) else affine
    origin = np.dot(FLIP_XY, affine[:3, 3]).astype(np.float64)
    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    direction = np.dot(FLIP_XY, rotation)
    image = sitk.GetImageFromArray(array.transpose())
    if array.ndim == 2:  # ignore first dimension if 2D (1, 1, H, W)
        # Origin and spacing must describe the same two axes as the
        # direction. Passing the 3-element versions would make them refer
        # to axes 0-1 instead of 1-2.
        origin = origin[1:3]
        spacing = spacing[1:3]
        direction = direction[1:3, 1:3]
    image.SetOrigin(origin)
    image.SetSpacing(spacing)
    image.SetDirection(direction.flatten())
    return image
175
176
177
def sitk_to_nib(image: sitk.Image) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a voxel array and a 4x4 affine.

    Returns:
        Tuple ``(data, affine)``. ``data`` is the transposed voxel array.
        ``affine`` is built from the image spacing, direction and origin,
        with the first two axes flipped by ``FLIP_XY`` (LPS <-> RAS).

    A 2D image is embedded in 3D by prepending an axis with spacing 1 and
    origin 0, so the returned affine is always 4x4.

    Raises:
        ValueError: If the image is neither 2D nor 3D, i.e. its direction
            does not have 4 or 9 elements.
    """
    data = sitk.GetArrayFromImage(image).transpose()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, 1, H, W)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[1:3, 1:3] = rotation_2d
        spacing = 1, *spacing
        origin = 0, *origin
    else:
        # Without this branch ``rotation`` would be unbound below
        message = (
            f'Image direction has {len(direction)} elements;'
            ' only 2D (4) and 3D (9) images are supported'
        )
        raise ValueError(message)
    rotation = np.dot(FLIP_XY, rotation)
    # Scale each column of the rotation by the spacing of that axis
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
197
198
199
def get_torchio_cache_dir():
    """Return ``~/.cache/torchio`` with ``~`` expanded to the home directory.

    Only the path is computed; the directory is not created.
    """
    cache_dir = Path('~', '.cache', 'torchio')
    return cache_dir.expanduser()
201
202
203
def round_up(value: float) -> float:
    """Round half towards positive infinity.

    Computed as ``floor(value + 0.5)``. Unlike the builtin :func:`round`,
    which rounds halves to the nearest even integer, ``x.5`` always goes up
    (so ``-2.5`` becomes ``-2``). The result is a NumPy float, not an int.

    Args:
        value: The value to round. NumPy arrays are also accepted and are
            rounded element-wise.

    Example:

        >>> round(2.5)
        2
        >>> round(3.5)
        4
        >>> float(round_up(2.5))
        3.0
        >>> float(round_up(3.5))
        4.0
        >>> float(round_up(-2.5))
        -2.0

    """
    return np.floor(value + 0.5)
222
223
def get_transform(transform_name):
    """Look up an attribute (typically a transform class) by name.

    The name is resolved on this package's ``transforms`` module; a missing
    name raises ``AttributeError`` as ``getattr`` does.
    """
    from . import transforms as transforms_module
    found = getattr(transforms_module, transform_name)
    return found
226