Passed
Pull Request — master (#246)
by Fernando
01:19
created

torchio.utils.ensure_4d()   D

Complexity

Conditions 12

Size

Total Lines 50
Code Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 12
eloc 32
nop 3
dl 0
loc 50
rs 4.8
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like torchio.utils.ensure_4d() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import ast
2
import shutil
3
import tempfile
4
from pathlib import Path
5
from typing import Union, Iterable, Tuple, Any, Optional, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
from tqdm import trange
12
from .torchio import (
13
    INTENSITY,
14
    LABEL,
15
    TypeData,
16
    TypeNumber,
17
    TypePath,
18
    REPO_URL,
19
)
20
21
22
# 3x3 diagonal matrix that negates the first two axes and keeps the third
FLIP_XY = np.diag((-1, -1, 1))  # used to switch between LPS and RAS
23
24
25
def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """Return ``value`` as a tuple.

    A non-iterable ``value`` is repeated ``length`` times:

        to_tuple(1, length=1) -> (1,)
        to_tuple(1, length=3) -> (1, 1, 1)

    An iterable ``value`` is converted with ``tuple(value)`` and ``length``
    is ignored:

        to_tuple((1,), length=1) -> (1,)
        to_tuple((1, 2), length=1) -> (1, 2)
        to_tuple([1, 2], length=3) -> (1, 2)

    Args:
        value: A single item or an iterable of items.
        length: Number of repetitions used when ``value`` is not iterable.
    """
    try:
        as_tuple = tuple(value)
    except TypeError:  # not iterable, so repeat the single item
        as_tuple = length * (value,)
    return as_tuple
44
45
46
def get_stem(path: TypePath) -> str:
    """Return the file name of ``path`` up to its first dot.

    Example:
        '/home/user/image.nii.gz' -> 'image'
    """
    filename = Path(path).name
    stem, _, _ = filename.partition('.')
    return stem
52
53
54
def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    """Create, or reuse, a dataset of random images and label maps.

    Images are written to ``<directory>/dummy_images`` and label maps to
    ``<directory>/dummy_labels``. If the images directory already exists
    (and ``force`` is ``False``), nothing is written and the returned
    subjects simply point at the expected file paths; the files themselves
    are not checked.

    Args:
        num_images: Number of image/label pairs.
        size_range: ``(low, high)`` arguments passed to
            ``np.random.randint`` to draw each of the three spatial sizes.
        directory: Parent directory of the dataset. If ``None``, the
            system temporary directory is used.
        suffix: File name suffix of the generated files.
        force: If ``True``, delete any existing dummy directories first so
            that the dataset is generated again.
        verbose: If ``True``, show a progress bar (and a message when new
            files are generated).

    Returns:
        List with one ``Subject`` per image/label pair.
    """
    from .data import Image, Subject
    output_dir = Path(tempfile.gettempdir() if directory is None else directory)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        # On a first run there is nothing to delete; only remove what exists
        # instead of letting shutil.rmtree raise FileNotFoundError
        for stale_dir in (images_dir, labels_dir):
            if stale_dir.is_dir():
                shutil.rmtree(stale_dir)

    reuse_existing = images_dir.is_dir()
    if not reuse_existing:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')

    # The progress bar is shown only on request, in both code paths
    indices = trange(num_images) if verbose else range(num_images)

    subjects: List[Subject] = []
    for i in indices:
        image_path = images_dir / f'image_{i}{suffix}'
        label_path = labels_dir / f'label_{i}{suffix}'

        if not reuse_existing:
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            # Three classes obtained by thresholding the random intensities
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            nii = nib.Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            nii = nib.Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

        subject = Subject(
            one_modality=Image(image_path, INTENSITY),
            segmentation=Image(label_path, LABEL),
        )
        subjects.append(subject)
    return subjects
113
114
115
def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import (TODO)
        output_path: TypePath,
        type: str = INTENSITY,
        verbose: bool = False,
        ):
    """Apply a transform to an image file and save the result.

    The input is wrapped in a ``Subject`` under the key ``image``, passed
    to ``transform``, and the ``image`` entry of the result is saved.

    Args:
        input_path: Path of the image to transform.
        transform: Callable that takes a ``Subject`` and returns the
            transformed one.
        output_path: Path where the transformed image is saved.
        type: Image type passed to ``Image``.
        verbose: If ``True`` and the result has a non-empty ``history``,
            print its first entry.
    """
    from . import Image, ImagesDataset, Subject
    result = transform(Subject(image=Image(input_path, type)))
    result.image.save(output_path)
    if not verbose:
        return
    if result.history:
        print(result.history[0])
128
129
def guess_type(string: str) -> Any:
    """Convert text to the Python value it most plausibly represents.

    Spaces are removed first. If what remains is a valid Python literal
    (number, boolean, ``None``, quoted string, or a possibly nested
    list/tuple/dict/set of literals), the corresponding value is returned.
    Anything else is returned as the space-stripped string.

    Examples:
        guess_type('1') -> 1
        guess_type('1.5') -> 1.5
        guess_type('False') -> False
        guess_type('(0, 0, 10)') -> (0, 0, 10)
        guess_type('nearest') -> 'nearest'

    Args:
        string: Text to interpret.

    Returns:
        The parsed literal, or the space-stripped text if it is not a
        literal.
    """
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        # The parsed value is returned as is. Re-casting the text with the
        # guessed type is wrong for booleans (bool('False') is True) and
        # re-splitting containers on commas breaks nesting, '(1,)' and '[]'
        return ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        # Not a literal (e.g. a bare word or a file path): keep it as text
        return string
150
151
152
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the upper-left 3x3 block of an affine into rotation and spacing.

    Args:
        affine: Matrix whose first three rows and columns are read.

    Returns:
        Tuple ``(rotation, spacing)`` where ``spacing`` holds the Euclidean
        norm of each column of the 3x3 block and ``rotation`` is the block
        with each column divided by its norm.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear_part = affine[:3, :3]
    column_norms = np.sqrt(np.square(linear_part).sum(axis=0))
    return linear_part / column_norms, column_norms
160
161
162
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        squeeze: bool = False,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a 4D tensor and a 4x4 affine matrix.

    Args:
        data: PyTorch tensor or NumPy array with four dimensions.
        affine: Affine matrix. Its upper-left 3x3 block provides spacing
            and direction, and the first three entries of its last column
            provide the origin.
        squeeze: Not used by this function.
        force_3d: If ``True``, an input whose second dimension has size 1
            is not collapsed to a 2D image.
        force_4d: If ``True``, the first dimension is neither dropped nor
            stored as vector components.

    Raises:
        ValueError: If ``data`` does not have four dimensions.
    """
    if data.ndim != 4:
        raise ValueError(f'Input must be 4D, but has shape {tuple(data.shape)}')
    # Possible input shapes:
    # (1, 1, h, w)
    # (c, 1, h, w)
    # (1, d, h, w)
    # (c, d, h, w)
    voxels = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    as_vector = voxels.shape[0] > 1 and not force_4d
    flat = voxels.shape[1] == 1 and not force_3d
    if flat:
        voxels = voxels[:, 0, :, :]
    if not (as_vector or force_4d):
        voxels = voxels[0]
    # After transposing: (W, H, D, C) or (W, H, D)
    image = sitk.GetImageFromArray(voxels.transpose(), isVector=as_vector)

    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    flipped_origin = np.dot(FLIP_XY, affine[:3, 3])
    flipped_rotation = np.dot(FLIP_XY, rotation)
    if flat:  # ignore first dimension if 2D (1, 1, H, W)
        flipped_rotation = flipped_rotation[1:3, 1:3]
    image.SetOrigin(flipped_origin)  # should I add a 4th value if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(flipped_rotation.flatten())

    # Sanity checks on the image that was built (input is 4D at this point)
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    spatial_dims = 2 if flat else 3
    assert image.GetSize() == data.shape[-spatial_dims:]
    return image
207
208
209
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into an array and a 4x4 affine matrix.

    Args:
        image: SimpleITK image to convert.
        keepdim: If ``False``, the array is passed through
            :func:`ensure_4d` using the dimension of ``image`` as the
            number of spatial dimensions. If ``True``, only a leading
            channels dimension is added to single-component images.

    Returns:
        Tuple ``(data, affine)``. ``affine`` is a 4x4 matrix built from
        the direction, spacing and origin of ``image``, with the first two
        axes flipped by ``FLIP_XY``.

    Raises:
        ValueError: If the direction matrix of ``image`` has neither 9
            (3D) nor 4 (2D) elements.
    """
    data = sitk.GetArrayFromImage(image).transpose()
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        data = data[np.newaxis]  # add channels dimension
    input_spatial_dims = image.GetDimension()
    if not keepdim:
        data = ensure_4d(data, False, num_spatial_dims=input_spatial_dims)
    assert data.shape[0] == num_components
    assert data.shape[-input_spatial_dims:] == image.GetSize()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, 1, H, W)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[1:3, 1:3] = rotation_2d
        # Pad spacing and origin so that they match the 3x3 rotation
        spacing = 1, *spacing
        origin = 0, *origin
    else:
        # Without this branch, ``rotation`` would be undefined below
        message = (
            'Only 2D and 3D images are supported, but the direction matrix'
            f' of the input image has {len(direction)} elements'
        )
        raise ValueError(message)
    rotation = np.dot(FLIP_XY, rotation)
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
249
250
251
def ensure_4d(
        tensor: TypeData,
        channels_last: bool,
        num_spatial_dims=None,
        ) -> TypeData:
    """Reshape an image tensor or array so that it has four dimensions.

    The expected result is ``(C, D, H, W)``; 2D images get a singleton
    second dimension, i.e. ``(C, 1, H, W)``.

    Handled inputs:

    - 5D with a singleton second-to-last dimension ``(X, X, X, 1, X)``:
      that dimension is removed and the result is handled as 4D.
    - 4D: returned as is, or with the last dimension moved to the front if
      ``channels_last`` is ``True``.
    - 3D: interpreted according to ``num_spatial_dims``. If it is neither
      2 nor 3, the input is treated as a 2D multichannel image when its
      first or last dimension has size 3, and as a 3D single-channel image
      otherwise.
    - 2D: treated as a single-channel 2D image ``(H, W)``.

    Args:
        tensor: PyTorch tensor or NumPy array.
        channels_last: If ``True``, last dimension of the input represents
            channels.
        num_spatial_dims: Number of spatial dimensions (2 or 3), used to
            disambiguate 3D inputs. If ``None``, it is guessed.

    Raises:
        ValueError: If the number of dimensions is not supported.
    """
    # I wish named tensors were properly supported in PyTorch
    num_dimensions = tensor.ndim
    if num_dimensions == 5 and tensor.shape[-2] == 1:  # hope (X, X, X, 1, X)
        tensor = tensor[..., 0, :]  # (X, X, X, X)
        # Continue with the 4D logic instead of falling through to the error
        num_dimensions = 4
    if num_dimensions == 4:  # assume 3D multichannel
        if channels_last:  # (D, H, W, C)
            tensor = _channels_to_front(tensor)  # (C, D, H, W)
    elif num_dimensions == 2:  # assume 2D monochannel (H, W)
        tensor = tensor[np.newaxis, np.newaxis]  # (1, 1, H, W)
    elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
        if num_spatial_dims == 2:
            if channels_last:  # (H, W, C)
                tensor = _channels_to_front(tensor)  # (C, H, W)
            tensor = tensor[:, np.newaxis]  # (C, 1, H, W)
        elif num_spatial_dims == 3:  # (D, H, W)
            tensor = tensor[np.newaxis]  # (1, D, H, W)
        else:  # try to guess
            shape = tensor.shape
            maybe_rgb = 3 in (shape[0], shape[-1])
            if maybe_rgb:
                if shape[-1] == 3:  # (H, W, 3)
                    tensor = _channels_to_front(tensor)  # (3, H, W)
                tensor = tensor[:, np.newaxis]  # (3, 1, H, W)
            else:  # (D, H, W)
                tensor = tensor[np.newaxis]  # (1, D, H, W)
    else:
        message = (
            f'{num_dimensions}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert tensor.ndim == 4
    return tensor


def _channels_to_front(tensor):
    """Move the last dimension of a NumPy array or PyTorch tensor to the front."""
    if isinstance(tensor, np.ndarray):
        # NumPy arrays have no ``permute`` method
        return np.moveaxis(tensor, -1, 0)
    last = tensor.ndim - 1
    return tensor.permute(last, *range(last))
301
302
303
def get_torchio_cache_dir():
    """Return the path ``~/.cache/torchio`` with ``~`` expanded."""
    cache_dir = Path('~', '.cache', 'torchio')
    return cache_dir.expanduser()
305
306
307
def round_up(value: float) -> float:
    """Round half towards infinity.

    Unlike the built-in ``round``, which gives ``round(2.5) == 2`` and
    ``round(3.5) == 4``, ties always go to the larger value here:
    ``round_up(2.5) == 3`` and ``round_up(3.5) == 4``.

    Args:
        value: The value to round.

    Returns:
        The rounded value, computed as ``np.floor(value + 0.5)``.
    """
    shifted = value + 0.5
    return np.floor(shifted)
326