torchio.visualization.plot_volume()   F
last analyzed

Complexity

Conditions 15

Size

Total Lines 78
Code Lines 69

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 69
dl 0
loc 78
rs 2.9998
c 0
b 0
f 0
cc 15
nop 12

How to fix   Long Method    Complexity    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

- Extract Method

Complexity

Complex units of code such as torchio.visualization.plot_volume() often do many different things. To break such a unit down, you need to identify a cohesive component within it. A common way to find such a component is to look for fields, methods, or variables that share the same prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

- Introduce Parameter Object
- Preserve Whole Object
- Replace Parameter with Method Call

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib on demand so it is only required for plotting.

    Returns:
        A ``(matplotlib, matplotlib.pyplot)`` tuple of modules.

    Raises:
        ImportError: If matplotlib (or pyplot) cannot be imported.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate a slice by ``n`` quarter turns for visualization purposes.

    ``np.rot90`` is applied with ``k=n``; the default ``-1`` turns the slice
    clockwise. When ``radiological`` is true, the rotated slice is also
    mirrored left-right with ``np.fliplr``.

    Returns:
        The rotated (and, if requested, mirrored) array.
    """
    turned = np.rot90(image, n)
    if not radiological:
        return turned
    return np.fliplr(turned)
28
29
30
def _plot_view(
        axis,
        image_slice,
        aspect,
        xlabel,
        ylabel,
        title,
        show_xlabel,
        imshow_kwargs,
        ):
    """Draw one orthogonal view of a volume on ``axis`` and label it."""
    axis.imshow(image_slice, aspect=aspect, **imshow_kwargs)
    if show_xlabel:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. A ``LabelMap`` gets a different default
            colormap and ``interpolation='none'``.
        radiological: If ``True``, each slice is mirrored left-right after
            being rotated (see ``rotate``).
        channel: Index into ``image.data`` of the channel to display.
        axes: Sequence of three matplotlib axes used, in order, for the
            sagittal, coronal and axial views. If ``None``, a new figure with
            a 1 x 3 grid of axes is created.
        cmap: Colormap passed to ``imshow``, or a dict mapping label values
            to colors, in which case the slices are converted to RGB with
            ``color_labels``. If ``None``, ``'cubehelix'`` is used for label
            maps and ``'gray'`` otherwise.
        output_path: If given, the figure is saved to this path. When
            ``axes`` is provided, the figure that owns those axes is saved.
        show: If ``True``, ``plt.show()`` is called at the end.
        xlabels: If ``True``, the horizontal axis of each view is labeled.
        percentiles: Pair of percentiles of the selected channel used as
            ``vmin``/``vmax``. ``None`` lets matplotlib pick the range.
        figsize: Size of the figure created when ``axes`` is ``None``.
        reorient: If ``True``, ``ToCanonical`` is applied before slicing.
        indices: ``(i, j, k)`` indices of the slices to show. Defaults to
            the center of the volume (``shape // 2``).
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        kwargs['interpolation'] = 'none'
    kwargs['origin'] = 'lower'
    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    # The spacing is unpacked as (R, A, S) to match the axis labels below;
    # this assumes RAS orientation (e.g. after ToCanonical) — confirm when
    # reorient=False. Each aspect is row spacing / column spacing of the
    # displayed slice.
    sr, sa, ss = image.spacing
    views = (
        (sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal'),
        (cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal'),
        (axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial'),
    )
    for axis, image_slice, aspect, xlabel, ylabel, title in views:
        _plot_view(
            axis,
            image_slice,
            aspect,
            xlabel,
            ylabel,
            title,
            show_xlabel=xlabels,
            imshow_kwargs=kwargs,
        )

    plt.tight_layout()
    if output_path is not None:
        if fig is None:
            # The caller supplied the axes: save the figure that owns them
            # instead of silently ignoring output_path.
            fig = sag_axis.get_figure()
        fig.savefig(output_path)
    if show:
        plt.show()
108
109
110
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot the three orthogonal views of every image in a subject.

    Args:
        subject: Subject whose images (``intensity_only=False``, so label
            maps included) are plotted.
        cmap_dict: Optional dict mapping image names to the ``cmap`` passed
            to ``plot_volume`` for that image.
        show: If ``True``, ``plt.show()`` is called before closing the figure.
        output_path: If given, the figure is saved to this path.
        figsize: Size of the created figure.
        clear_axes: If ``True`` and all images have a consistent spatial
            shape, subplots showing the same view share their x/y axes.
        **kwargs: Extra keyword arguments forwarded to ``plot_volume``.
    """
    _, plt = import_mpl_plt()
    images = subject.get_images_dict(intensity_only=False)
    # Size the grid from the images actually plotted below; len(subject)
    # may also count non-image entries, which would leave empty subplots.
    num_images = len(images)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject: do not share
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share
    # With more than two images, each image gets a column instead of a row
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None if cmap_dict is None else cmap_dict.get(name)
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    plt.close(fig)
159
160
161
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB images using a label -> color mapping.

    Args:
        arrays: Iterable of label arrays (e.g. 2D slices).
        cmap_dict: Dict mapping label values to colors. A string color is
            converted with ``matplotlib.colors.to_rgb`` and scaled to 0-255;
            any other color is assigned directly into the ``uint8`` output
            (so presumably a sequence of three 0-255 values).

    Returns:
        List with one ``uint8`` array of shape ``array.shape + (3,)`` per
        input array. Elements whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once, instead of once per array.
    colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            mpl, _ = import_mpl_plt()
            # Round instead of truncating: 255 * (k / 255) can land just
            # below k in floating point, which would darken a channel by one.
            color = [round(255 * value) for value in mpl.colors.to_rgb(color)]
        colors[label] = color
    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
174
175
176
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save the slices of a tensor along one spatial axis as an animated GIF.

    Args:
        tensor: Tensor with channels in the first dimension followed by three
            spatial dimensions. A single channel is saved in ``'P'`` mode;
            otherwise frames are ``'RGB'`` (presumably three channels are
            expected — confirm with callers).
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Duration of the whole GIF, in seconds.
        output_path: Path of the output GIF file.
        loop: Forwarded to Pillow's ``Image.save`` as the GIF ``loop`` option.
        optimize: Forwarded to Pillow's ``Image.save``.
        rescale: If ``True``, intensities are rescaled to ``(0, 255)`` with
            ``RescaleIntensity`` before being cast to bytes.
        reverse: If ``True``, frames are written in reverse order.

    Raises:
        ValueError: If ``duration`` is not positive or there are no frames.
        RuntimeError: If Pillow is not installed.
    """
    # A non-positive duration previously crashed with ZeroDivisionError
    # (duration 0) or produced a nonsensical negative-fps warning.
    if duration <= 0:
        raise ValueError(f'The GIF duration must be positive, got {duration}')
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e
    tensor = RescaleIntensity((0, 255))(tensor) if rescale else tensor
    single_channel = len(tensor) == 1

    # Move channels dimension to the end and bring selected axis to 0
    axes = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*axes, 0)

    if single_channel:
        mode = 'P'
        tensor = tensor[..., 0]
    else:
        mode = 'RGB'
    array = tensor.byte().numpy()
    n = 2 if axis == 1 else 1
    images = [
        ImagePIL.fromarray(rotate(frame, n=n)).convert(mode)
        for frame in array
    ]
    num_images = len(images)
    if num_images == 0:
        raise ValueError('The tensor has no slices along the selected axis')
    images = list(reversed(images)) if reverse else images
    frame_duration_ms = duration / num_images * 1000
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_images / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)
    images[0].save(
        output_path,
        save_all=True,
        append_images=images[1:],
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
231