Completed
Pull Request — master (#353)
by Fernando
118:39 queued 117:31
created

torchio.utils.sitk_to_nib()   B

Complexity

Conditions 6

Size

Total Lines 35
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 33
nop 2
dl 0
loc 35
rs 8.1546
c 0
b 0
f 0
1
import ast
2
import gzip
3
import shutil
4
import tempfile
5
from pathlib import Path
6
from typing import Union, Iterable, Tuple, Any, Optional, List, Sequence
7
8
from torch.utils.data._utils.collate import default_collate
9
import numpy as np
10
import nibabel as nib
11
import SimpleITK as sitk
12
from tqdm import trange
13
14
from .torchio import (
15
    INTENSITY,
16
    TypeData,
17
    TypeNumber,
18
    TypePath,
19
    REPO_URL,
20
)
21
22
23
# Negating the first two axes switches between the LPS and RAS conventions
# (used by nib_to_sitk and sitk_to_nib on origins and direction matrices)
FLIP_XY = np.diag([-1, -1, 1])
24
25
26
def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """Convert a scalar or an iterable into a tuple.

    to_tuple(1, length=1) -> (1,)
    to_tuple(1, length=3) -> (1, 1, 1)

    If value is an iterable, length is ignored and tuple(value) is returned
    to_tuple((1,), length=1) -> (1,)
    to_tuple((1, 2), length=1) -> (1, 2)
    to_tuple([1, 2], length=3) -> (1, 2)

    Note that strings are iterable, so they are split into characters.

    Args:
        value: A scalar, or an iterable of values.
        length: Number of times a scalar is repeated.

    Returns:
        A tuple with the values.
    """
    # Only the iterability check belongs in the try: a TypeError raised
    # while consuming a real iterable must not be mistaken for "scalar"
    try:
        iter(value)
    except TypeError:
        return length * (value,)
    return tuple(value)
45
46
47
def get_stem(
        path: Union[TypePath, List[TypePath]]
        ) -> Union[str, List[str]]:
    """Return the file name without any of its extensions.

    '/home/user/image.nii.gz' -> 'image'

    A single ``str``/``Path`` gives a single stem; any other iterable of
    paths gives a list of stems, in the same order.
    """
    def _stem_of(single_path) -> str:
        file_name = Path(single_path).name
        # Everything before the first dot, so '.nii.gz' is fully removed
        return file_name.partition('.')[0]

    is_single_path = isinstance(path, (str, Path))
    if not is_single_path:
        return [_stem_of(item) for item in path]
    return _stem_of(path)
58
59
60
def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    """Create a dataset of random images and label maps, or reuse it.

    Images are written to ``<directory>/dummy_images`` and label maps to
    ``<directory>/dummy_labels``. If the images directory already exists,
    its files are assumed to come from a previous call and are loaded
    instead of being regenerated.

    Args:
        num_images: Number of subjects.
        size_range: ``(low, high)`` bounds passed to ``np.random.randint``
            to draw the size of each of the three spatial dimensions
            (``high`` is exclusive).
        directory: Parent directory. Defaults to the system temporary
            directory.
        suffix: File extension of the written files, e.g. ``'.nii.gz'``.
        force: Delete any existing dummy directories and regenerate data.
        verbose: Print a message and show a progress bar.

    Returns:
        List of subjects, each with a ``one_modality`` scalar image and a
        ``segmentation`` label map.
    """
    from .data import ScalarImage, LabelMap, Subject
    output_dir = tempfile.gettempdir() if directory is None else directory
    output_dir = Path(output_dir)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # On a first run the directories do not exist yet; rmtree would
        # raise FileNotFoundError on them
        for stale_dir in (images_dir, labels_dir):
            if stale_dir.is_dir():
                shutil.rmtree(stale_dir)

    reuse_existing = images_dir.is_dir()
    if not reuse_existing:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')  # noqa: T001

    # The progress bar is shown only if requested, also when reusing files
    iterable = trange(num_images) if verbose else range(num_images)

    subjects: List[Subject] = []
    for i in iterable:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'
        if not reuse_existing:
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            # Three classes from intensity thresholds: 0, 1 and 2
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            nii = nib.Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            nii = nib.Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

        subject = Subject(
            one_modality=ScalarImage(image_path),
            segmentation=LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects
119
120
121
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import
        output_path: TypePath,
        type: str = INTENSITY,  # noqa: A002
        verbose: bool = False,
        ):
    """Read an image from disk, transform it and write the result.

    Args:
        input_path: Path to the input image.
        transform: Callable applied to a ``Subject`` that holds the image
            under the key ``'image'``.
        output_path: Path where the transformed image is saved.
        type: Image type passed to ``Image``.
        verbose: If ``True`` and the transformed subject has a non-empty
            ``history``, print its first entry.
    """
    from . import Image, Subject
    input_image = Image(input_path, type=type)
    output_subject = transform(Subject(image=input_image))
    output_subject.image.save(output_path)
    if not verbose:
        return
    if output_subject.history:
        print('Applied transform:', output_subject.history[0])  # noqa: T001
134
135
136
def guess_type(string: str) -> Any:
    """Parse a string into the Python literal it most likely represents.

    Spaces are removed first, so ``'(1, 2, 3)'`` becomes the tuple
    ``(1, 2, 3)``. Strings that are not valid Python literals (for example
    file paths or interpolation names) are returned as they are, minus
    spaces.

    Examples:
        >>> guess_type('3')
        3
        >>> guess_type('False')
        False
        >>> guess_type('(1,)')
        (1,)
        >>> guess_type('linear')
        'linear'
    """
    # Originally adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        # literal_eval only accepts literals, so it is safe on user input.
        # Its result already has the right type, including containers
        # (also nested, empty or single-element ones) and booleans;
        # re-applying the type to the raw string would turn 'False' into
        # True and break dicts and sets.
        return ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        # ValueError: names such as 'linear'; SyntaxError: e.g. '/tmp/a'
        # or ''; TypeError: unhashable dict keys or set items
        return string
157
158
159
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the linear part of an affine into a rotation and a spacing.

    The spacing along each axis is the Euclidean norm of the matching
    column of the upper-left 3x3 block. Dividing each column by its norm
    gives the rotation (direction cosines).

    Returns:
        Tuple ``(rotation, spacing)``: a 3x3 array and a length-3 array.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear_part = affine[:3, :3]
    column_norms = np.sqrt((linear_part * linear_part).sum(axis=0))
    return linear_part / column_norms, column_norms
167
168
169
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        squeeze: bool = False,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a tensor and a 4x4 affine matrix.

    Args:
        data: 4D array or tensor with shape ``(C, W, H, D)``. If the last
            dimension has size 1, a 2D image is created unless
            ``force_3d`` is set.
        affine: 4x4 matrix whose rotation/spacing and translation are
            flipped with ``FLIP_XY`` (RAS <-> LPS) before being set on
            the image.
        squeeze: Not used by this function; kept for compatibility.
        force_3d: Keep a size-1 last dimension, creating a 3D image.
        force_4d: Do not treat the first dimension as vector components;
            it is kept as an extra image dimension instead.

    Returns:
        A SimpleITK image, vector-valued when ``data`` has more than one
        channel and ``force_4d`` is not set.

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        raise ValueError(f'Input must be 4D, but has shape {tuple(data.shape)}')
    # Possibilities
    # (1, w, h, 1)  2D, one channel
    # (c, w, h, 1)  2D, several channels
    # (1, w, h, d)  3D, one channel
    # (c, w, h, d)  3D, several channels
    array = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    is_multichannel = array.shape[0] > 1 and not force_4d
    is_2d = array.shape[3] == 1 and not force_3d
    if is_2d:
        array = array[..., 0]
    if not is_multichannel and not force_4d:
        array = array[0]
    # SimpleITK expects arrays with reversed axis order:
    # (D, H, W, C) / (H, W, C) for vector images, (D, H, W) / (H, W) otherwise
    array = array.transpose()
    image = sitk.GetImageFromArray(array, isVector=is_multichannel)

    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    origin = np.dot(FLIP_XY, affine[:3, 3])
    direction = np.dot(FLIP_XY, rotation)
    if is_2d:  # ignore first dimension if 2D (1, W, H, 1)
        direction = direction[:2, :2]
    image.SetOrigin(origin)  # should I add a 4th value if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(direction.flatten())
    # data.ndim == 4 is guaranteed by the check at the top.
    # NOTE(review): with force_4d and more than one channel the image is not
    # vector-valued, so this check presumably fails - confirm intended use
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetSize() == data.shape[1: 1 + num_spatial_dims]
    return image
209
210
211
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a voxel array and a 4x4 affine.

    The array gets a leading channels dimension; 2D images also get a
    trailing singleton dimension, so 2D and 3D images both give 4D arrays.
    The origin and direction are flipped with ``FLIP_XY`` (LPS <-> RAS).

    Args:
        image: 2D or 3D SimpleITK image, scalar or vector-valued.
        keepdim: If ``False``, the array is passed through ``ensure_4d``.

    Returns:
        Tuple ``(data, affine)``.

    Raises:
        RuntimeError: If the direction has neither 4 nor 9 elements.
    """
    num_dims = image.GetDimension()
    num_channels = image.GetNumberOfComponentsPerPixel()

    # GetArrayFromImage uses reversed axis order; transpose to undo it
    array = sitk.GetArrayFromImage(image).transpose()
    if num_channels == 1:
        array = array[np.newaxis]  # add channels dimension
    if num_dims == 2:
        array = array[..., np.newaxis]
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=num_dims)
    assert array.shape[0] == num_channels
    assert array.shape[1: 1 + num_dims] == image.GetSize()

    spacing = np.array(image.GetSpacing())
    origin = image.GetOrigin()
    direction = np.array(image.GetDirection())
    num_direction_values = len(direction)
    if num_direction_values == 9:
        rotation = direction.reshape(3, 3)
    elif num_direction_values == 4:
        # 2D image: embed its 2x2 direction in a 3x3 identity and pad the
        # spacing and origin with a third component
        rotation = np.eye(3)
        rotation[:2, :2] = direction.reshape(2, 2)
        spacing = *spacing, 1
        origin = *origin, 0
    else:
        raise RuntimeError(f'Direction not understood: {direction}')

    affine = np.eye(4)
    affine[:3, :3] = np.dot(FLIP_XY, rotation) * spacing
    affine[:3, 3] = np.dot(FLIP_XY, origin)
    return array, affine
246
247
248
def _permute_axes(tensor: TypeData, *axes: int) -> TypeData:
    """Reorder the axes of a NumPy array or a PyTorch tensor."""
    # NumPy arrays have no ``permute`` and ``torch.Tensor.transpose`` only
    # swaps two dimensions, so each type needs its own method
    if isinstance(tensor, np.ndarray):
        return tensor.transpose(axes)
    return tensor.permute(*axes)


def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> TypeData:
    """Reshape a 2D to 5D array or tensor into a 4D ``(C, W, H, D)`` one.

    Works with both NumPy arrays and PyTorch tensors.

    Args:
        tensor: Input array or tensor. Layouts handled:
            ``(W, H)``, ``(C, W, H)``, ``(W, H, 3)``, ``(W, H, D)``,
            ``(C, W, H, D)`` and ``(W, H, D, 1, C)``.
        num_spatial_dims: Used only for 3D input: ``2`` means
            ``(C, W, H)``, ``3`` means ``(W, H, D)``. If ``None``, the
            layout is guessed (a size of 3 in the first or last
            dimension is taken as RGB channels).

    Returns:
        A 4D array or tensor of the same type as the input.

    Raises:
        ValueError: For 5D input whose second-to-last dimension is not 1,
            or for input with fewer than 2 or more than 5 dimensions.
    """
    # I wish named tensors were properly supported in PyTorch
    num_dimensions = tensor.ndim
    if num_dimensions == 4:
        pass
    elif num_dimensions == 5:  # hope (W, H, D, 1, C)
        if tensor.shape[-2] == 1:
            tensor = tensor[..., 0, :]
            tensor = _permute_axes(tensor, 3, 0, 1, 2)
        else:
            raise ValueError('5D is not supported for shape[-2] > 1')
    elif num_dimensions == 2:  # assume 2D monochannel (W, H)
        tensor = tensor[np.newaxis, ..., np.newaxis]  # (1, W, H, 1)
    elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
        if num_spatial_dims == 2:
            tensor = tensor[..., np.newaxis]  # (C, W, H, 1)
        elif num_spatial_dims == 3:  # (W, H, D)
            tensor = tensor[np.newaxis]  # (1, W, H, D)
        else:  # try to guess
            shape = tensor.shape
            maybe_rgb = 3 in (shape[0], shape[-1])
            if maybe_rgb:
                if shape[-1] == 3:  # (W, H, 3)
                    tensor = _permute_axes(tensor, 2, 0, 1)  # (3, W, H)
                tensor = tensor[..., np.newaxis]  # (3, W, H, 1)
            else:  # (W, H, D)
                tensor = tensor[np.newaxis]  # (1, W, H, D)
    else:
        message = (
            f'{num_dimensions}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert tensor.ndim == 4
    return tensor
283
284
285
def get_torchio_cache_dir():
    """Return the TorchIO cache directory, ``~/.cache/torchio``.

    The ``~`` is expanded to the user's home directory. The directory is
    not created here.
    """
    home_dir = Path('~').expanduser()
    return home_dir / '.cache' / 'torchio'
287
288
289
def round_up(value: float) -> int:
    """Round to the nearest integer, rounding halves towards +infinity.

    Python's built-in ``round`` rounds halves to the nearest even number,
    which is why both functions differ for ``2.5``.

    Args:
        value: The value to round.

    Example:

        >>> round(2.5)
        2
        >>> round(3.5)
        4
        >>> round_up(2.5)
        3
        >>> round_up(3.5)
        4

    """
    shifted = value + 0.5
    return int(np.floor(shifted))
308
309
310
def compress(input_path, output_path):
    """Write a gzip-compressed copy of ``input_path`` to ``output_path``."""
    with open(input_path, 'rb') as source, gzip.open(output_path, 'wb') as sink:
        shutil.copyfileobj(source, sink)
314
315
316
def check_sequence(sequence: Sequence, name: str):
    """Raise an error if ``sequence`` cannot be iterated.

    Args:
        sequence: Object expected to be a sequence. Anything accepted by
            ``iter()`` passes, including strings.
        name: Name of the argument, used in the error message.

    Raises:
        TypeError: If ``sequence`` is not iterable.
    """
    try:
        iter(sequence)
    except TypeError:
        # Report the type of the offending value, not the type of its name
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message) from None
322
323
324
def get_major_sitk_version() -> int:
    """Return the major version of the installed SimpleITK (1 or 2).

    ``SimpleITK.__version__`` was added in version 2
    (https://github.com/SimpleITK/SimpleITK/pull/1171), so a missing (or
    ``None``) attribute is taken to mean version 1.
    """
    has_version_attribute = getattr(sitk, '__version__', None) is not None
    return 2 if has_version_attribute else 1
330
331
332
def history_collate(batch: Sequence, collate_transforms=True):
    """Collate a batch of subjects together with their transform records.

    Args:
        batch: Sequence of samples, presumably as given to a PyTorch
            ``DataLoader`` collate function (confirm with callers).
        collate_transforms: If ``True``, gather the ``history`` attribute
            of each subject; otherwise gather ``applied_transforms``.

    Returns:
        When the first element is a ``Subject``: a dictionary with every
        key of the first subject collated with ``default_collate``, plus
        (if the first subject has the chosen attribute) a list with that
        attribute for every subject in the batch.
    """
    # Adapted from
    # https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py
    from .data import Subject
    attr = 'applied_transforms' if not collate_transforms else 'history'
    first_element = batch[0]
    if isinstance(first_element, Subject):
        collated = {
            key: default_collate([sample[key] for sample in batch])
            for key in first_element
        }
        if hasattr(first_element, attr):
            collated[attr] = [getattr(sample, attr) for sample in batch]
        return collated
346