Passed
Push — master ( 99e6cf...52d4d3 )
by Fernando
01:13
created

torchio.visualization.plot_subject()   D

Complexity

Conditions 13

Size

Total Lines 48
Code Lines 44

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 44
dl 0
loc 48
rs 4.2
c 0
b 0
f 0
cc 13
nop 7

How to fix   Complexity   

Complexity

Complex functions like torchio.visualization.plot_subject() often do several different things. To break such a function down, identify a cohesive piece of logic within it. A common approach is to look for statements that operate on the same local variables, or for names that share the same prefix or suffix.

Once you have determined which parts belong together, you can apply the Extract Method refactoring to move them into a helper function. For classes, the equivalent refactoring is Extract Class. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.

1
import numpy as np
2
3
from .data.image import Image, LabelMap
4
from .data.subject import Subject
5
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
6
7
8
def import_mpl_plt():
    """Import matplotlib lazily so it stays an optional dependency.

    Returns:
        A ``(matplotlib, matplotlib.pyplot)`` tuple.

    Raises:
        ImportError: If matplotlib is not installed. The original import
            error is chained as the cause.
    """
    try:
        import matplotlib
        from matplotlib import pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    return matplotlib, pyplot
15
16
17
def rotate(image, radiological=True):
    """Orient a 2D slice for display.

    The slice is first turned a quarter turn with ``np.rot90(..., k=-1)``.
    When ``radiological`` is true, the result is then mirrored left-right
    with ``np.fliplr``.

    Args:
        image: 2D array-like slice.
        radiological: Whether to apply the left-right mirror.

    Returns:
        The reoriented NumPy array.
    """
    rotated = np.rot90(image, k=-1)
    return np.fliplr(rotated) if radiological else rotated
23
24
25
def _plot_slice(axis, array, aspect, title, xlabel, ylabel, show_xlabel,
                imshow_kwargs):
    """Draw one 2D slice on ``axis`` and annotate it with orientation letters."""
    axis.imshow(array, aspect=aspect, **imshow_kwargs)
    if show_xlabel:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        ):
    """Plot the central sagittal, coronal and axial slices of an image.

    Args:
        image: Image to plot.
        radiological: Passed to ``rotate()``. If True, each slice is also
            mirrored left-right after rotation.
        channel: Index of the channel of ``image.data`` to display.
        axes: Sequence of three matplotlib axes (sagittal, coronal, axial).
            If None, a new 1x3 figure is created.
        cmap: Colormap passed to ``imshow``. Alternatively, a dict mapping
            label values to colors, in which case slices are converted to RGB
            with ``color_labels()``. If None, ``'cubehelix'`` is used for
            ``LabelMap`` instances and ``'gray'`` otherwise.
        output_path: If not None, the figure is saved to this path. This also
            works when ``axes`` are supplied by the caller.
        show: Whether to call ``plt.show()`` at the end.
        xlabels: Whether to draw the x-axis orientation letters.
        percentiles: Lower and upper percentiles of the displayed channel.
            They are used as ``vmin``/``vmax``; pass None to disable.
        figsize: Figure size, used only when a new figure is created.
        reorient: Whether to apply ``ToCanonical`` before slicing.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    # Central slice along each of the three spatial dimensions
    i, j, k = np.array(data.shape) // 2
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        # Avoid interpolating between discrete label values
        kwargs['interpolation'] = 'none'

    sr, sa, ss = image.spacing
    kwargs['origin'] = 'lower'

    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    # Aspect ratios come from the voxel spacing so anisotropic voxels display
    # with correct physical proportions
    _plot_slice(sag_axis, slice_x, ss / sa, 'Sagittal', 'A', 'S', xlabels, kwargs)
    _plot_slice(cor_axis, slice_y, ss / sr, 'Coronal', 'R', 'S', xlabels, kwargs)
    _plot_slice(axi_axis, slice_z, sa / sr, 'Axial', 'R', 'A', xlabels, kwargs)

    plt.tight_layout()
    if output_path is not None:
        # If the caller supplied the axes, save the figure they belong to
        # instead of silently ignoring output_path
        figure = fig if fig is not None else sag_axis.figure
        figure.savefig(output_path)
    if show:
        plt.show()
101
102
103
def _subplots_kwargs(subject, figsize, clear_axes, many_images):
    """Build the ``plt.subplots`` keyword arguments used by plot_subject()."""
    subplots_kwargs = {'figsize': figsize}
    if not clear_axes:
        return subplots_kwargs
    try:
        subject.check_consistent_spatial_shape()
    except RuntimeError:  # different shapes in subject: do not share axes
        return subplots_kwargs
    # Axes showing the same view share limits. Views are rows when there are
    # many images (one column per image), and columns otherwise.
    share = 'row' if many_images else 'col'
    subplots_kwargs['sharex'] = share
    subplots_kwargs['sharey'] = share
    return subplots_kwargs


def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot the central slices of every image in a subject on one figure.

    Args:
        subject: Subject whose images (intensity and label maps) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to ``plot_volume()``.
        show: Whether to call ``plt.show()`` at the end.
        output_path: If not None, the figure is saved to this path.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If True and all images have the same spatial shape,
            axes displaying the same view share their x and y limits.
        **kwargs: Extra keyword arguments forwarded to ``plot_volume()``.

    Raises:
        ValueError: If the subject contains no images.
    """
    _, plt = import_mpl_plt()
    # Size the grid from the same collection the loop iterates, so the number
    # of rows/columns and the "last image" check always agree
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)
    if num_images == 0:
        raise ValueError('The subject does not contain any images to plot')
    many_images = num_images > 2
    subplots_kwargs = _subplots_kwargs(
        subject, figsize, clear_axes, many_images)
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None if cmap_dict is None else cmap_dict.get(name)
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            # Only the last image gets orientation letters on the x axis
            xlabels=image_index == num_images - 1,
            **kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
151
152
153
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of NumPy arrays containing label values.
        cmap_dict: Mapping from label value to a color. A color is either an
            RGB triplet already in the 0-255 range or a matplotlib color
            string (e.g. ``'red'`` or ``'#ff0000'``).

    Returns:
        List of ``uint8`` arrays, each with the shape of the corresponding
        input plus a trailing RGB axis. Voxels whose label is not in
        ``cmap_dict`` are black.
    """
    # Resolve every color once, instead of once per array. matplotlib is
    # imported only if a string color actually needs converting.
    lookup = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            # Round rather than truncate: 255 * (k / 255) may come out as
            # k - epsilon and would otherwise be stored as k - 1
            color = [round(255 * n) for n in mpl.colors.to_rgb(color)]
        lookup[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in lookup.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
166