Passed
Push — master ( 9fff8c...fc33ea )
by Fernando
01:18
created

torchio.visualization.plot_subject()   C

Complexity

Conditions 9

Size

Total Lines 45
Code Lines 41

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 41
dl 0
loc 45
rs 6.5626
c 0
b 0
f 0
cc 9
nop 7
1
import numpy as np
2
3
from .data.image import Image, LabelMap
4
from .data.subject import Subject
5
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
6
7
8
def import_mpl_plt():
    """Import matplotlib lazily and return it together with pyplot.

    Returns:
        Tuple ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If matplotlib cannot be imported. The original error is
            chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    else:
        return matplotlib, pyplot
15
16
17
def rotate(image, radiological=True):
    """Rotate a slice for visualization purposes.

    Args:
        image: Array accepted by ``np.rot90``.
        radiological: If true, the rotated slice is additionally mirrored
            left-right with ``np.fliplr``.

    Returns:
        The result of ``np.rot90(image, -1)``, flipped left-right when
        ``radiological`` is true.
    """
    rotated = np.rot90(image, k=-1)
    return np.fliplr(rotated) if radiological else rotated
23
24
25
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        ):
    """Plot the middle sagittal, coronal and axial slices of an image.

    The image is first passed through ``ToCanonical``. The central slice
    along each of the three axes of ``image.data[channel]`` is then rotated
    with ``rotate()`` and drawn with ``imshow``.

    Args:
        image: Image to plot.
        radiological: Forwarded to ``rotate()`` for every slice.
        channel: Index of the channel of ``image.data`` to show.
        axes: Sequence of three matplotlib axes (sagittal, coronal, axial).
            If ``None``, a new figure with 1 x 3 subplots is created.
        cmap: Colormap passed to ``imshow``. If it is a ``dict``, the slices
            are converted to RGB with ``color_labels()`` instead and no
            colormap is passed. If ``None``, ``'cubehelix'`` is used for
            ``LabelMap`` instances and ``'gray'`` otherwise.
        output_path: Where to save the figure. Only used when the figure was
            created by this function, i.e. when ``axes`` is ``None``.
        show: If true, ``plt.show()`` is called at the end.
        xlabels: If true, the x axis of each view is labeled.
    """
    _, plt = import_mpl_plt()
    own_figure = None
    if axes is None:
        own_figure, axes = plt.subplots(1, 3)
    sagittal_ax, coronal_ax, axial_ax = axes

    image = ToCanonical()(image)
    volume = image.data[channel]
    mid_i, mid_j, mid_k = np.array(volume.shape) // 2
    views = (
        rotate(volume[mid_i, :, :], radiological=radiological),
        rotate(volume[:, mid_j, :], radiological=radiological),
        rotate(volume[:, :, mid_k], radiological=radiological),
    )

    imshow_kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        # Labels are painted directly as RGB, so imshow needs no colormap
        views = color_labels(views, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        imshow_kwargs['cmap'] = cmap
    if is_label:
        imshow_kwargs['interpolation'] = 'none'
    sagittal, coronal, axial = views

    sr, sa, ss = image.spacing
    imshow_kwargs['origin'] = 'lower'

    # One entry per view:
    # (axis, slice, (aspect numerator, aspect denominator), xlabel, ylabel,
    # title). The aspect ratio is computed inside the loop, right before the
    # corresponding slice is drawn.
    panels = (
        (sagittal_ax, sagittal, (ss, sa), 'A', 'S', 'Sagittal'),
        (coronal_ax, coronal, (ss, sr), 'R', 'S', 'Coronal'),
        (axial_ax, axial, (sa, sr), 'R', 'A', 'Axial'),
    )
    for axis, view, (numerator, denominator), xlabel, ylabel, title in panels:
        axis.imshow(view, aspect=numerator / denominator, **imshow_kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.set_title(title)

    plt.tight_layout()
    if own_figure is not None and output_path is not None:
        own_figure.savefig(output_path)
    if show:
        plt.show()
92
93
94
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot every image of a subject, one row of three views per image.

    Args:
        subject: Subject whose images (as returned by
            ``subject.get_images_dict(intensity_only=False)``) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to ``plot_volume()`` for that image. Images without an
            entry use ``cmap=None``.
        show: If true, ``plt.show()`` is called at the end.
        output_path: If not ``None``, the figure is saved to this path.
        figsize: Passed to ``plt.subplots`` as ``figsize``.
        clear_axes: If true and ``subject.check_consistent_spatial_shape()``
            does not raise ``RuntimeError``, the subplots share their x and y
            axes within each column.
        **kwargs: Additional keyword arguments forwarded to
            ``plot_volume()``.
    """
    _, plt = import_mpl_plt()
    grid_options = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            grid_options['sharex'] = 'col'
            grid_options['sharey'] = 'col'
    fig, axes = plt.subplots(len(subject), 3, **grid_options)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.reshape(-1, 3)
    images = subject.get_images_dict(intensity_only=False)
    view_names = ('sagittal', 'coronal', 'axial')
    final_row = len(axes) - 1
    for row, (name, image) in enumerate(images.items()):
        row_axes = axes[row]
        has_custom_cmap = cmap_dict is not None and name in cmap_dict
        plot_volume(
            image,
            axes=row_axes,
            show=False,
            cmap=cmap_dict[name] if has_custom_cmap else None,
            # Only the bottom row gets x labels
            xlabels=row == final_row,
            **kwargs,
        )
        # Replace the generic titles set by plot_volume() with ones that
        # include the image name
        for axis, view_name in zip(row_axes, view_names):
            axis.set_title(f'{name} ({view_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
139
140
141
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images using a label-to-color map.

    Args:
        arrays: Iterable of 2D arrays containing label values.
        cmap_dict: Mapping from label value to color. A color given as a
            string is converted with ``matplotlib.colors.to_rgb`` and scaled
            by 255. Any other color is assigned unchanged into the ``uint8``
            output.

    Returns:
        List with one ``uint8`` array of shape ``(rows, columns, 3)`` per
        input array. Pixels whose value is not a key of ``cmap_dict`` are
        left as zeros.

    Raises:
        ImportError: If a color is given as a string and matplotlib is not
            installed.
    """
    # Resolve the colors once up front: the conversion of named colors does
    # not depend on the array being painted, so there is no need to repeat it
    # (and the matplotlib import) for every array and label.
    mpl = None
    palette = []
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                # matplotlib is only required when named colors are used
                mpl, _ = import_mpl_plt()
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        palette.append((label, color))

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in palette:
            rgb[array == label] = color
        results.append(rgb)
    return results
154