Passed
Push — master ( 99e6cf...52d4d3 )
by Fernando
01:13
created

torchio.visualization.plot_volume()   F

Complexity

Conditions 14

Size

Total Lines 76
Code Lines 67

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 67
dl 0
loc 76
rs 3.6
c 0
b 0
f 0
cc 14
nop 11

How to fix   Long Method    Complexity    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

Complexity

Complex code units like torchio.visualization.plot_volume() (a function, in this case) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields, methods, or local variables that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Replace Parameter with Method Call, Introduce Parameter Object, and Preserve Whole Object.

1
import numpy as np
2
3
from .data.image import Image, LabelMap
4
from .data.subject import Subject
5
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
6
7
8
def import_mpl_plt():
    """Import matplotlib on demand and return it together with pyplot.

    The import happens inside the function so that matplotlib is only needed
    when a plotting function is actually called.

    Returns:
        Tuple ``(matplotlib, matplotlib.pyplot)`` of imported modules.

    Raises:
        ImportError: If matplotlib cannot be imported. The original error is
            chained as the cause.
    """
    try:
        import matplotlib.pyplot as pyplot_module
        import matplotlib as matplotlib_module
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    return matplotlib_module, pyplot_module
15
16
17
def rotate(image, radiological=True):
    """Reorient an array slice for visualization.

    The array is rotated with ``np.rot90(image, -1)`` and, if requested,
    mirrored left-right afterwards.

    Args:
        image: Array with at least two dimensions.
        radiological: If true, apply ``np.fliplr`` after the rotation.

    Returns:
        The reoriented array.
    """
    rotated = np.rot90(image, k=-1)
    return np.fliplr(rotated) if radiological else rotated
23
24
25
def _show_slice(axis, slice_2d, aspect, xlabel, ylabel, title, imshow_kwargs):
    """Draw one slice on ``axis`` and set its labels, direction and title.

    ``xlabel`` may be ``None``, in which case no x label is written.
    """
    axis.imshow(slice_2d, aspect=aspect, **imshow_kwargs)
    if xlabel is not None:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        ):
    """Plot the central sagittal, coronal and axial slices of an image.

    Args:
        image: Image to plot. Instances of ``LabelMap`` are drawn without
            interpolation and without intensity windowing.
        radiological: Forwarded to :func:`rotate` for every slice.
        channel: Index of the channel of ``image.data`` to plot.
        axes: Sequence of three matplotlib axes (sagittal, coronal, axial).
            If ``None``, a new figure with a 1 x 3 grid is created.
        cmap: Either a dictionary mapping label values to colors (the slices
            are then converted to RGB with :func:`color_labels`) or a value
            passed to ``imshow`` as ``cmap``. If ``None``, ``'cubehelix'`` is
            used for label maps and ``'gray'`` otherwise.
        output_path: If given, the figure is saved there. This only happens
            when the figure was created by this function, i.e. when ``axes``
            is ``None``.
        show: If true, call ``plt.show()`` at the end.
        xlabels: If true, write the x label of each axis.
        percentiles: Pair of percentiles of the selected channel used as
            ``vmin`` and ``vmax``. Use ``None`` to disable the windowing. It
            is never applied to label maps.
        figsize: Passed to ``plt.subplots`` when a new figure is created.
        reorient: If true, apply ``ToCanonical`` to the image before slicing.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    i, j, k = np.array(data.shape) // 2  # central index along each axis
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    is_label = isinstance(image, LabelMap)
    kwargs = {'origin': 'lower'}
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap

    if is_label:
        kwargs['interpolation'] = 'none'
    elif percentiles is not None:
        # Intensity windowing is only meaningful for intensity images. On a
        # label map that is mostly background both percentiles can collapse
        # to the same value, which would hide the labels.
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    # NOTE(review): the names suggest spacing along the right, anterior and
    # superior directions -- confirm against Image.spacing
    sr, sa, ss = image.spacing

    _show_slice(
        sag_axis, slice_x, ss / sa, 'A' if xlabels else None, 'S',
        'Sagittal', kwargs,
    )
    _show_slice(
        cor_axis, slice_y, ss / sr, 'R' if xlabels else None, 'S',
        'Coronal', kwargs,
    )
    _show_slice(
        axi_axis, slice_z, sa / sr, 'R' if xlabels else None, 'A',
        'Axial', kwargs,
    )

    plt.tight_layout()
    if output_path is not None and fig is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
101
102
103
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot the central slices of every image in a subject.

    Each image is drawn with :func:`plot_volume` on three axes. With one or
    two images, each image takes a row of the grid; with more, each image
    takes a column.

    Args:
        subject: Subject whose images are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image.
        show: If true, call ``plt.show()`` at the end.
        output_path: If given, the figure is saved there.
        figsize: Passed to ``plt.subplots``.
        clear_axes: If true and the spatial shapes in the subject are
            consistent, axes are shared within the grid.
        **kwargs: Additional keyword arguments for :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    # Size the grid from the images that are actually plotted, so that any
    # other entries stored in the subject cannot create empty axes
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            sharing = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = sharing
            subplots_kwargs['sharey'] = sharing
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        # x labels are only written for the last image
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
151
152
153
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays to RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color. A color is either a
            string understood by ``matplotlib.colors.to_rgb`` or a sequence
            of three values that is written to the ``uint8`` output as is.

    Returns:
        List with one ``uint8`` array of shape ``(rows, columns, 3)`` per
        input array. Pixels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve named colors once instead of once per label and per array
    palette = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            # to_rgb() returns floats in [0, 1]. Round, because assigning
            # floats to a uint8 array truncates and could lose one level.
            color = tuple(round(255 * n) for n in mpl.colors.to_rgb(color))
        palette[label] = color

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in palette.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
166