Passed
Pull Request — master (#332)
by Fernando
03:27 queued 02:14
created

torchio.visualization.rotate()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import numpy as np
2
3
from .data.image import Image, LabelMap
4
from .data.subject import Subject
5
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
6
7
8
def import_pyplot():
    """Import and return ``matplotlib.pyplot`` on demand.

    Importing lazily keeps matplotlib an optional dependency: it is only
    needed when a plotting function is actually called.

    Returns:
        The ``matplotlib.pyplot`` module.

    Raises:
        ImportError: If matplotlib is not installed. The original import
            error is attached as ``__cause__``.
    """
    try:
        from matplotlib import pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    return pyplot
14
15
16
def rotate(image):
    """Rotate an array by 90 degrees in the plane of its first two axes.

    This is ``np.rot90`` with its defaults spelled out: a single quarter
    turn, rotating from axis 0 towards axis 1. The input is converted to
    a NumPy array by ``np.rot90`` itself.

    Args:
        image: Array-like with at least two dimensions.

    Returns:
        A rotated ``numpy.ndarray`` view of the input.
    """
    quarter_turns = 1
    return np.rot90(image, k=quarter_turns, axes=(0, 1))
18
19
20
def plot_image(
        image: Image,
        channel=0,
        axes=None,
        show=True,
        cmap=None,
        ):
    """Plot the three central orthogonal slices of one image channel.

    The image is first passed through ``ToCanonical``. The central slice
    perpendicular to each of the three spatial array axes is then drawn on
    ``axes[0]``, ``axes[1]`` and ``axes[2]`` respectively, using the image
    bounds as the plot extent.

    Args:
        image: Image to plot. For a ``LabelMap``, slices are drawn with
            ``interpolation='none'`` and the default colormap is
            ``'inferno'``; otherwise the default colormap is ``'gray'``.
        channel: Index of the channel of ``image.data`` to plot.
        axes: Sequence of at least three matplotlib axes to draw on. If
            ``None``, a new figure with a 1x3 grid of axes is created.
        show: If ``True``, call ``plt.show()`` after plotting.
        cmap: Colormap passed to ``imshow``, or a dict mapping label values
            to colors. With a dict, slices are converted to RGB with
            ``color_labels`` and no colormap is passed to ``imshow``.
    """
    plt = import_pyplot()
    if axes is None:
        figure, axes = plt.subplots(1, 3)
    else:
        # Lay out the figure that owns the given axes, which is not
        # necessarily pyplot's current figure.
        figure = axes[0].figure
    image = ToCanonical()(image)
    # After selecting a channel the data must be 3D (unpacked into i, j, k).
    data = image.data[channel]
    i, j, k = np.array(data.shape) // 2
    slice_x = rotate(data[i, :, :])
    slice_y = rotate(data[:, j, :])
    slice_z = rotate(data[:, :, k])
    kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'inferno' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        # Avoid blending neighbouring label values when rendering.
        kwargs['interpolation'] = 'none'
    # Each row of bounds.T is a 2-tuple for one axis; two of them
    # concatenate into imshow's (left, right, bottom, top) extent.
    x_extent, y_extent, z_extent = [tuple(b) for b in image.bounds.T]
    axes[0].imshow(slice_x, extent=y_extent + z_extent, **kwargs)
    axes[1].imshow(slice_y, extent=x_extent + z_extent, **kwargs)
    axes[2].imshow(slice_z, extent=x_extent + y_extent, **kwargs)
    figure.tight_layout()
    if show:
        plt.show()
55
56
57
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        ):
    """Plot the central orthogonal slices of every image in a subject.

    One row of three axes is drawn per image, in the order returned by
    ``subject.get_images_dict(intensity_only=False)``. Each axis is
    titled ``'<name> (sagittal|coronal|axial)'``.

    Args:
        subject: Subject whose images will be plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to ``plot_image`` for that image. Images missing from the
            mapping get ``cmap=None`` (``plot_image``'s default).
        show: If ``True`` (default), call ``plt.show()`` after plotting.
    """
    plt = import_pyplot()
    images = subject.get_images_dict(intensity_only=False)
    # squeeze=False keeps ``axes`` 2D even when there is a single image;
    # otherwise ``axes[row]`` would be one Axes instead of a row of three.
    figure, axes = plt.subplots(len(images), 3, squeeze=False)
    axes_names = 'sagittal', 'coronal', 'axial'
    for row_axes, (name, image) in zip(axes, images.items()):
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        plot_image(image, axes=row_axes, show=False, cmap=cmap)
        for axis, axis_name in zip(row_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    figure.tight_layout()
    if show:
        plt.show()
75
76
77
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to the color assigned to every
            pixel with that label. Colors are stored in a ``uint8`` array,
            so presumably they are 0-255 RGB triplets (confirm with callers).

    Returns:
        A list with one ``(rows, cols, 3)`` ``uint8`` array per input array.
        Pixels whose label is not in ``cmap_dict`` stay black.
    """
    def _to_rgb(label_map):
        # Unpacking enforces a 2D input.
        height, width = label_map.shape
        colored = np.zeros((height, width, 3), dtype=np.uint8)
        for label_value, color in cmap_dict.items():
            mask = label_map == label_value
            colored[mask] = color
        return colored

    return [_to_rgb(label_map) for label_map in arrays]
86