Passed
Push — master ( 9fff8c...fc33ea )
by Fernando
01:18
created

torchio.visualization.plot_volume()   D

Complexity

Conditions 12

Size

Total Lines 67
Code Lines 59

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 59
dl 0
loc 67
rs 4.7018
c 0
b 0
f 0
cc 12
nop 8

How to fix: Long Method, Complexity, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

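Commonly applied refactorings include Extract Method: move a coherent group of statements into a new, well-named function and call it from the original one.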

Complexity

Complex functions like torchio.visualization.plot_volume() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for variables or statements that share the same prefixes or suffixes (here, for example, the `sag_*`, `cor_*` and `axi_*` variables that each handle one view).

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

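There are several approaches to avoid long parameter lists: replace a parameter with a call that computes it, introduce a parameter object that groups related arguments, or pass a whole object instead of several values extracted from it.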

1
import numpy as np
2
3
from .data.image import Image, LabelMap
4
from .data.subject import Subject
5
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
6
7
8
def import_mpl_plt():
    """Import Matplotlib on demand.

    Returns:
        Tuple ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If Matplotlib is not installed; the original import
            error is chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
15
16
17
def rotate(image, radiological=True):
    """Orient a 2D slice for display.

    The slice is rotated 90 degrees clockwise (``np.rot90`` with ``k=-1``).
    When ``radiological`` is true, the rotated slice is also mirrored
    left-right.
    """
    rotated = np.rot90(image, k=-1)
    return np.fliplr(rotated) if radiological else rotated
23
24
25
def _plot_slice(axis, slice_, aspect, xlabel, ylabel, title, show_xlabel,
                imshow_kwargs):
    """Draw one 2D slice on ``axis`` and label it (helper for plot_volume)."""
    axis.imshow(slice_, aspect=aspect, **imshow_kwargs)
    if show_xlabel:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        ):
    """Plot the central sagittal, coronal and axial slices of an image.

    The image is passed through :class:`ToCanonical`. One channel is
    selected, and the three slices through the middle index of each spatial
    axis are drawn side by side. Each subplot's aspect ratio comes from
    ``image.spacing``.

    Args:
        image: Image to plot. :class:`LabelMap` instances are drawn with
            ``interpolation='none'`` and, by default, the ``'cubehelix'``
            colormap.
        radiological: Forwarded to :func:`rotate`. If ``True``, each rotated
            slice is also mirrored left-right.
        channel: Index into ``image.data`` selecting the channel to plot.
        axes: Sequence of three Matplotlib axes (sagittal, coronal, axial).
            If ``None``, a new figure with three subplots is created.
        cmap: Colormap passed to ``imshow``, or a dict mapping label values
            to colors. A dict converts the slices to RGB with
            :func:`color_labels`. If ``None``, ``'cubehelix'`` is used for
            label maps and ``'gray'`` otherwise.
        output_path: If given, the figure containing the axes is saved here.
        show: If ``True``, ``plt.show()`` is called at the end.
        xlabels: If ``True``, the horizontal axis of each subplot is labelled.
    """
    _, plt = import_mpl_plt()
    if axes is None:
        fig, axes = plt.subplots(1, 3)
    else:
        fig = None
    sag_axis, cor_axis, axi_axis = axes

    image = ToCanonical()(image)
    data = image.data[channel]
    i, j, k = np.array(data.shape) // 2
    slices = (
        rotate(data[i, :, :], radiological=radiological),
        rotate(data[:, j, :], radiological=radiological),
        rotate(data[:, :, k], radiological=radiological),
    )

    is_label = isinstance(image, LabelMap)
    imshow_kwargs = {'origin': 'lower'}
    if isinstance(cmap, dict):
        # Slices become RGB images, so no colormap is passed to imshow
        slices = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        imshow_kwargs['cmap'] = cmap
    if is_label:
        # Avoid blending neighbouring label values when the image is resized
        imshow_kwargs['interpolation'] = 'none'

    sr, sa, ss = image.spacing
    slice_x, slice_y, slice_z = slices
    _plot_slice(sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal', xlabels,
                imshow_kwargs)
    _plot_slice(cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal', xlabels,
                imshow_kwargs)
    _plot_slice(axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial', xlabels,
                imshow_kwargs)

    plt.tight_layout()
    if output_path is not None:
        if fig is None:
            # The caller supplied the axes: save the figure they live in
            # instead of silently ignoring output_path.
            fig = sag_axis.figure
        fig.savefig(output_path)
    if show:
        plt.show()
92
93
94
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot every image of a subject, one row of three views per image.

    Each image is drawn by :func:`plot_volume` on its own row of sagittal,
    coronal and axial subplots. Subplot titles are
    ``'<image name> (<view>)'``.

    Args:
        subject: Subject whose images, as returned by
            ``get_images_dict(intensity_only=False)``, are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image.
        show: If ``True``, ``plt.show()`` is called at the end.
        output_path: If given, the figure is saved to this path.
        figsize: Figure size forwarded to ``plt.subplots``.
        clear_axes: If ``True`` and ``subject.check_consistent_spatial_shape()``
            succeeds, subplots in the same column share their x and y axes.
        **kwargs: Extra keyword arguments forwarded to :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    # Size the grid from the images that are actually plotted, so there are
    # no empty rows and the x labels land on the last plotted row.
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)

    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            subplots_kwargs['sharex'] = 'col'
            subplots_kwargs['sharey'] = 'col'
    fig, axes = plt.subplots(num_images, 3, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.reshape(-1, 3)

    axes_names = 'sagittal', 'coronal', 'axial'
    for row_index, (name, image) in enumerate(images.items()):
        row_axes = axes[row_index]
        cmap = None if cmap_dict is None else cmap_dict.get(name)
        plot_volume(
            image,
            axes=row_axes,
            show=False,
            cmap=cmap,
            xlabels=row_index == num_images - 1,  # only label the bottom row
            **kwargs,
        )
        for axis, axis_name in zip(row_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
139
140
141
def color_labels(arrays, cmap_dict):
    """Convert label arrays to RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of label arrays. Any shape is accepted.
        cmap_dict: Mapping from label value to color. A string color is
            converted with ``matplotlib.colors.to_rgb`` and scaled to 0-255.
            Any other color (e.g. an ``(r, g, b)`` tuple) is written to the
            output unchanged, so it should already be in the 0-255 range.

    Returns:
        List with one ``uint8`` array of shape ``array.shape + (3,)`` per
        input array. Elements whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once, not once per array. Matplotlib is only
    # imported when a string color actually needs converting.
    colors = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            # Round rather than truncate, so floating-point error in 255 * x
            # cannot lower a channel by one when it is stored as uint8.
            color = [round(255 * n) for n in mpl.colors.to_rgb(color)]
        colors[label] = color

    results = []
    for array in arrays:
        array = np.asarray(array)
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
154