Passed
Push — master ( 9c7ee9...0b0a34 )
by Fernando
01:06
created

torchio.utils   A

Complexity

Total Complexity 30

Size/Duplication

Total Lines 267
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 187
dl 0
loc 267
rs 10
c 0
b 0
f 0
wmc 30

12 Functions

Rating   Name   Duplication   Size   Complexity  
A to_tuple() 0 19 2
A get_torchio_cache_dir() 0 2 1
B guess_type() 0 21 6
B create_dummy_dataset() 0 59 7
A round_up() 0 19 1
A sitk_to_nib() 0 20 3
A get_colin_subject() 0 23 2
A get_stem() 0 6 1
A apply_transform_to_file() 0 11 1
A nib_to_sitk() 0 13 4
A get_slicer_mrhead_subject() 0 16 1
A get_rotation_and_spacing_from_affine() 0 8 1
1
import ast
2
import shutil
3
import tempfile
4
from pathlib import Path
5
from typing import Union, Iterable, Tuple, Any, Optional, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
from tqdm import trange
12
from torchvision.datasets.utils import (
13
    download_url,
14
    download_and_extract_archive,
15
)
16
from .torchio import (
17
    INTENSITY,
18
    LABEL,
19
    TypeData,
20
    TypeNumber,
21
    TypePath,
22
)
23
24
25
# Diagonal matrix negating the first two axes; used to switch between LPS and RAS
FLIP_XY = np.diag([-1, -1, 1])
26
27
28
def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """Return ``value`` as a tuple.

    A non-iterable ``value`` is repeated ``length`` times. An iterable
    ``value`` is converted with :func:`tuple` and ``length`` is ignored.
    Strings are iterable too, so they are split into characters.

    Examples:
        >>> to_tuple(1, length=1)
        (1,)
        >>> to_tuple(1, length=3)
        (1, 1, 1)
        >>> to_tuple((1,), length=1)
        (1,)
        >>> to_tuple((1, 2), length=1)
        (1, 2)
        >>> to_tuple([1, 2], length=3)
        (1, 2)
    """
    try:
        iter(value)
        as_tuple = tuple(value)
    except TypeError:
        # Not iterable: broadcast the single value
        as_tuple = length * (value,)
    return as_tuple
47
48
49
def get_stem(path: TypePath) -> str:
    """Return the file name of ``path`` up to its first dot.

    Unlike :attr:`pathlib.Path.stem`, every suffix is removed, e.g.
    ``'/home/user/image.nii.gz'`` gives ``'image'``.
    """
    file_name = Path(path).name
    stem, _, _ = file_name.partition('.')
    return stem
55
56
57
def _write_dummy_pair(
        image_path: Path,
        label_path: Path,
        size_range: Tuple[int, int],
        ) -> None:
    """Write one random uint8 image and its 3-class label map to disk."""
    shape = np.random.randint(*size_range, size=3)
    affine = np.eye(4)
    image = np.random.rand(*shape)
    # Label 0 below 0.33, 2 above 0.66, 1 in between
    label = np.ones_like(image)
    label[image < 0.33] = 0
    label[image > 0.66] = 2
    image *= 255

    nii = nib.Nifti1Image(image.astype(np.uint8), affine)
    nii.to_filename(str(image_path))

    nii = nib.Nifti1Image(label.astype(np.uint8), affine)
    nii.to_filename(str(label_path))


def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    """Create, or reload from disk, a dataset of random images and labels.

    Files are stored as ``image_{i}{suffix}`` in ``<directory>/dummy_images``
    and ``label_{i}{suffix}`` in ``<directory>/dummy_labels``. If both
    folders already exist, their files are reused without being checked.

    Args:
        num_images: Number of subjects to create or load.
        size_range: ``(low, high)`` passed to :func:`numpy.random.randint`
            to draw each of the three dimensions of a new image.
        directory: Parent folder of the dummy folders. Defaults to the
            system temporary directory.
        suffix: File extension used for images and labels.
        force: If ``True``, delete existing dummy folders and regenerate.
        verbose: If ``True``, print a message when creating the dataset and
            show a progress bar.

    Returns:
        A list of ``num_images`` ``Subject`` instances, each with an
        ``INTENSITY`` image ``one_modality`` and a ``LABEL`` image
        ``segmentation``.
    """
    from .data import Image, Subject
    output_dir = Path(tempfile.gettempdir() if directory is None else directory)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # shutil.rmtree raises FileNotFoundError for a missing directory, so
        # forcing a fresh dataset must not require an existing one
        for old_dir in (images_dir, labels_dir):
            if old_dir.is_dir():
                shutil.rmtree(old_dir)

    # Reuse previous files only if both folders are present
    reuse_existing = images_dir.is_dir() and labels_dir.is_dir()
    if not reuse_existing:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')

    # The progress bar honours ``verbose`` whether loading or creating
    indices = trange(num_images) if verbose else range(num_images)
    subjects: List[Subject] = []
    for i in indices:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'
        if not reuse_existing:
            _write_dummy_pair(image_path, label_path, size_range)
        subject = Subject(
            one_modality=Image(image_path, INTENSITY),
            segmentation=Image(label_path, LABEL),
        )
        subjects.append(subject)
    return subjects
116
117
118
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import (TODO)
        output_path: TypePath,
        type: str = INTENSITY,
        ):
    """Read one image, apply ``transform`` to it and save the result.

    Args:
        input_path: Path of the image to read.
        transform: Transform given to ``ImagesDataset``.
        output_path: Path where the transformed image is written.
        type: Image type passed to ``Image`` (e.g. ``INTENSITY`` or ``LABEL``).
    """
    from . import Image, ImagesDataset, Subject
    dataset = ImagesDataset(
        [Subject(image=Image(input_path, type))],
        transform=transform,
    )
    # Indexing the dataset is what applies the transform
    dataset.save_sample(dataset[0], dict(image=output_path))
129
130
131
def _split_top_level(string: str) -> List[str]:
    """Split ``string`` on commas that are not inside brackets or quotes."""
    parts = []
    depth = 0
    quote = None
    escaped = False
    start = 0
    for index, char in enumerate(string):
        if quote is not None:
            # Inside a string literal: only watch for its closing quote
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = None
        elif char in '\'"':
            quote = char
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            depth -= 1
        elif char == ',' and depth == 0:
            parts.append(string[start:index])
            start = index + 1
    parts.append(string[start:])
    return parts


def guess_type(string: str) -> Any:
    """Convert a command-line style string into the value it represents.

    All spaces are removed first, then the text is parsed with
    :func:`ast.literal_eval`:

    - text that is not a Python literal (e.g. ``'test'`` or a file path) is
      returned as the space-free string;
    - a string literal is also returned as the space-free text, quotes
      included;
    - lists and tuples are rebuilt element by element with this function,
      so string elements follow the previous rule;
    - any other literal (``int``, ``float``, ``bool``, ``None``, ``dict``,
      ...) is returned as parsed.

    Examples:
        >>> guess_type('(1, 3, 5)')
        (1, 3, 5)
        >>> guess_type('False')
        False
        >>> guess_type('/path/to/image.nii')
        '/path/to/image.nii'
    """
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        value = ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        # Not a literal: e.g. bare words (ValueError), paths like '/home/x'
        # (SyntaxError) or unhashable set/dict keys (TypeError)
        return string

    if isinstance(value, str):
        return string

    if type(value) is list:
        parts = _split_top_level(string[1:-1])  # drop the enclosing brackets
        # Empty parts come from '[]' or a trailing comma
        return [guess_type(part) for part in parts if part]

    if type(value) is tuple:
        parts = _split_top_level(string)
        if len(parts) == 1:
            # The whole text is wrapped in parentheses. Around a non-empty
            # expression they only group, so '(1,2)' is the same as '1,2'
            inner = string[1:-1]
            return guess_type(inner) if inner else ()
        return tuple(guess_type(part) for part in parts if part)

    # Use the parsed value directly: calling the type on the text is wrong
    # for e.g. bool('False'), dict('{...}') or int('0x10')
    return value
152
153
154
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the upper-left 3x3 block of ``affine`` into rotation and spacing.

    Args:
        affine: Matrix whose ``[:3, :3]`` block combines rotation and zooms.

    Returns:
        Tuple ``(rotation, spacing)``. ``spacing`` holds the Euclidean norm
        of each column of the 3x3 block. ``rotation`` is that block with
        each column divided by its norm.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear_part = affine[:3, :3]
    column_norms = np.sqrt((linear_part ** 2).sum(axis=0))
    return linear_part / column_norms, column_norms
162
163
164
def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
    """Convert an array and an affine matrix into a SimpleITK image.

    The translation and rotation of ``affine`` are multiplied by
    ``FLIP_XY`` (used to switch between LPS and RAS) before being set as
    the image origin and direction.

    Args:
        data: Voxel array or tensor; tensors are converted with ``.numpy()``.
            A 2D array is treated as a single slice whose axes are the last
            two axes of the affine (the first axis is dropped).
        affine: Affine array or tensor; only ``[:3, :3]`` and ``[:3, 3]``
            are read.

    Returns:
        A ``sitk.Image`` built from the transposed array with origin,
        spacing and direction set.
    """
    array = data.numpy() if isinstance(data, torch.Tensor) else data
    affine = affine.numpy() if isinstance(affine, torch.Tensor) else affine
    origin = np.dot(FLIP_XY, affine[:3, 3]).astype(np.float64)
    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    direction = np.dot(FLIP_XY, rotation)
    image = sitk.GetImageFromArray(array.transpose())
    if array.ndim == 2:  # ignore first dimension if 2D (1, 1, H, W)
        # Drop the first axis from origin and spacing too, not only from the
        # direction. Passing the full 3-element vectors does not match the
        # 2D direction built from axes 1 and 2. This mirrors sitk_to_nib,
        # which prepends origin 0 and spacing 1 for that axis.
        origin = origin[1:3]
        spacing = spacing[1:3]
        direction = direction[1:3, 1:3]
    image.SetOrigin(origin)
    image.SetSpacing(spacing)
    image.SetDirection(direction.flatten())
    return image
177
178
179
def sitk_to_nib(image: sitk.Image) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into an array and an affine matrix.

    Mirrors ``nib_to_sitk``. The voxel array is transposed, and the
    direction and origin are multiplied by ``FLIP_XY`` (used to switch
    between LPS and RAS).

    Args:
        image: A 2D or 3D SimpleITK image.

    Returns:
        Tuple ``(data, affine)`` with the transposed voxel array and a 4x4
        affine. For 2D images, an extra leading axis with spacing 1 and
        origin 0 is added to the geometry; the array itself stays 2D.

    Raises:
        ValueError: If the direction has neither 9 (3D) nor 4 (2D) elements.
    """
    data = sitk.GetArrayFromImage(image).transpose()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, 1, H, W)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[1:3, 1:3] = rotation_2d
        spacing = 1, *spacing
        origin = 0, *origin
    else:
        # Previously this fell through with ``rotation`` unbound
        raise ValueError(
            'Only 2D and 3D images are supported, but the image direction'
            f' has {len(direction)} elements'
        )
    rotation = np.dot(FLIP_XY, rotation)
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
199
200
201
def get_torchio_cache_dir():
    """Return the folder used to cache TorchIO downloads.

    The location is ``.cache/torchio`` inside the current user's home
    directory. It is only computed here, not created.
    """
    home = Path('~').expanduser()
    return home / '.cache' / 'torchio'
203
204
205
def get_colin_subject():
    """Return the Colin27 (1998) MNI template as a ``Subject``.

    The zip archive is downloaded and extracted into
    ``<torchio cache>/colin27`` unless the three expected NIfTI files are
    already there.

    Returns:
        A ``Subject`` with ``t1`` (created without an explicit type), and
        ``head`` and ``brain`` ``LABEL`` images read from the head-mask and
        mask files.
    """
    from .data import Image, Subject
    url = 'http://packages.bic.mni.mcgill.ca/mni-models/colin27/mni_colin27_1998_nifti.zip'
    download_root = get_torchio_cache_dir() / 'colin27'
    t1, head, mask = [
        download_root / f'colin27_t1_tal_lin{suffix}.nii'
        for suffix in ('', '_headmask', '_mask')
    ]
    # Check the files rather than the folder. An interrupted download or
    # extraction leaves the folder behind without the images.
    if all(path.is_file() for path in (t1, head, mask)):
        print(f'Using cache found in {download_root}')
    else:
        filename = 'mni_colin27_1998_nifti.zip'
        download_and_extract_archive(
            url,
            download_root=download_root,
            filename=filename,
        )
    subject = Subject(
        t1=Image(t1),
        head=Image(head, type=LABEL),
        brain=Image(mask, type=LABEL),
    )
    return subject
228
229
230
def get_slicer_mrhead_subject():
    """Return the 3D Slicer ``MRHead`` sample volume as a ``Subject``.

    The NRRD file is fetched from the SlicerTestingData GitHub releases
    with torchvision's ``download_url`` and saved as
    ``<torchio cache>/slicer/MRHead.nrrd``.

    Returns:
        A ``Subject`` whose only image, ``t1``, is the downloaded volume.
    """
    from .data import Image, Subject
    base_url = 'https://github.com/Slicer/SlicerTestingData/releases/download/'
    sha256_path = 'SHA256/cc211f0dfd9a05ca3841ce1141b292898b2dd2d3f08286affadf823a7e58df93'
    cache_dir = get_torchio_cache_dir() / 'slicer'
    filename = 'MRHead.nrrd'
    download_url(base_url + sha256_path, cache_dir, filename=filename)
    return Subject(t1=Image(cache_dir / filename))
246
247
248
def round_up(value: float) -> float:
    """Round half towards positive infinity.

    The built-in :func:`round` sends halves to the nearest even integer.
    This function always rounds ties up, including for negative values.
    The result is ``np.floor(value + 0.5)``, so it is a floating-point
    value, not an ``int``.

    Args:
        value: The value to round.

    Example:
        >>> round(2.5)
        2
        >>> round(3.5)
        4
        >>> int(round_up(2.5))
        3
        >>> int(round_up(3.5))
        4
        >>> int(round_up(-2.5))
        -2
    """
    return np.floor(value + 0.5)
267