Passed
Pull Request — master (#682)
by
unknown
01:26
created

torchio.visualization.plot_volume()   F

Complexity

Conditions 15

Size

Total Lines 83
Code Lines 73

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 73
dl 0
loc 83
rs 2.9836
c 0
b 0
f 0
cc 15
nop 12

How to fix: Long Method, Complexity, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

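Commonly applied refactorings include Extract Method: move a cohesive part of the body into a new, well-named helper function and call it from the original method.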

Complexity

Complex code units like torchio.visualization.plot_volume() often do many different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields, methods, or variables that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

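There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object (Introduce Parameter Object), or passing a whole object instead of several values extracted from it (Preserve Whole Object).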

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import Matplotlib lazily so that it stays an optional dependency.

    Returns:
        Tuple ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If Matplotlib (or pyplot) cannot be imported.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Reorient a 2D slice for display.

    Args:
        image: 2D array-like slice.
        radiological: If ``True``, mirror the slice left-right after rotating.
        n: Number of quarter turns, passed as ``k`` to :func:`numpy.rot90`.

    Returns:
        The rotated (and possibly mirrored) array; NumPy may return a view.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def _get_imshow_kwargs(data, cmap, is_label, percentiles):
    """Build the keyword arguments shared by the three ``imshow`` calls.

    When ``cmap`` is a dict, the slices have already been converted to RGB,
    so no ``cmap`` key is added.
    """
    kwargs = {'origin': 'lower'}
    if not isinstance(cmap, dict):
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        # Avoid blending neighbouring label values
        kwargs['interpolation'] = 'none'
    if percentiles is not None:
        # Intensity window computed over the whole plotted channel
        kwargs['vmin'], kwargs['vmax'] = np.percentile(data, percentiles)
    return kwargs


def _plot_slice(
        axis,
        slice_2d,
        aspect,
        xlabel,
        ylabel,
        title,
        show_xlabel,
        imshow_kwargs,
        ):
    """Draw one 2D slice on ``axis`` with labels and title, then hide the frame."""
    axis.imshow(slice_2d, aspect=aspect, **imshow_kwargs)
    if show_xlabel:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)
    # NOTE(review): axis('off') hides ticks and axis labels (titles remain),
    # so the labels set above are not rendered; kept to preserve the look.
    axis.axis('off')


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot.
        radiological: Forwarded to :func:`rotate`; if ``True``, each slice is
            mirrored left-right after rotation.
        channel: Index into ``image.data`` selecting the channel to plot.
        axes: Sequence of three Matplotlib axes (sagittal, coronal, axial).
            If ``None``, a new figure with one row of three axes is created.
        cmap: Colormap passed to ``imshow``, or a dict mapping label values
            to colors, in which case slices are converted to RGB with
            :func:`color_labels`. If ``None``, ``'cubehelix'`` is used for
            label maps and ``'gray'`` otherwise.
        output_path: If given and the figure was created by this function
            (``axes is None``), save it there with a transparent background.
        show: If ``True``, call ``plt.show()``.
        xlabels: If ``True``, set the x-axis label of each view.
        percentiles: Percentiles of the plotted channel used as ``vmin`` and
            ``vmax``; if ``None``, they are not set.
        figsize: Figure size, used only when ``axes`` is ``None``.
        reorient: If ``True``, apply :class:`ToCanonical` before slicing.
        indices: Voxel indices ``(i, j, k)`` of the slices to show. If
            ``None``, the central index along each dimension is used.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    # Unpack instead of calling methods on ``axes`` itself: it may be a NumPy
    # array of Axes (from plt.subplots), which has no ``axis`` method.
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    kwargs = _get_imshow_kwargs(data, cmap, is_label, percentiles)

    # Aspect ratios keep anisotropic voxels physically proportioned
    sr, sa, ss = image.spacing
    _plot_slice(
        sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal', xlabels, kwargs)
    _plot_slice(
        cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal', xlabels, kwargs)
    _plot_slice(
        axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial', xlabels, kwargs)

    plt.tight_layout()
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
113
114
115
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot the three orthogonal views of every image in a subject.

    Args:
        subject: Subject whose images (intensity and label maps) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image.
        show: If ``True``, call ``plt.show()``.
        output_path: If given, save the figure to this path.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If ``True`` and the subject's images have a consistent
            spatial shape, the subplots share their x and y axes.
        **kwargs: Extra keyword arguments forwarded to :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    num_images = len(subject)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share

    # plot_volume() expects one row of three axes per image, so the axes
    # array is always made 2D with shape (num_images, 3)
    if many_images:
        fig, axes = plt.subplots(3, num_images, **subplots_kwargs)
        axes = axes.T
    else:
        fig, axes = plt.subplots(num_images, 3, **subplots_kwargs)
        axes = axes.reshape(-1, 3)

    images = subject.get_images_dict(intensity_only=False)
    view_names = 'sagittal', 'coronal', 'axial'
    last_index = len(axes) - 1
    for index, (name, image) in enumerate(images.items()):
        row = axes[index]
        has_cmap = cmap_dict is not None and name in cmap_dict
        plot_volume(
            image,
            axes=row,
            show=False,
            cmap=cmap_dict[name] if has_cmap else None,
            xlabels=index == last_index,
            **kwargs,
        )
        for axis, view_name in zip(row, view_names):
            axis.set_title(f'{name} ({view_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    plt.close(fig)
164
165
166
def color_labels(arrays, cmap_dict):
    """Convert 2D label slices into RGB images using a label-to-color map.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color. A string color is
            converted with Matplotlib's ``to_rgb`` and scaled to 0-255;
            any other value is assigned directly into the ``uint8`` output.

    Returns:
        List with one ``(H, W, 3)`` ``uint8`` array per input array. Pixels
        whose label is not a key of ``cmap_dict`` stay ``(0, 0, 0)``.
    """
    # Resolve every color once, instead of once per label per slice
    rgb_colors = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            color = [255 * value for value in mpl.colors.to_rgb(color)]
        rgb_colors[label] = color

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in rgb_colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
179
180
181
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save the slices of a 4D tensor along one spatial axis as an animated GIF.

    Args:
        tensor: Tensor with channels first, followed by three spatial
            dimensions. One channel produces ``'P'`` frames; otherwise the
            frames are converted to ``'RGB'``.
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Duration of the whole GIF, in seconds.
        output_path: Destination file.
        loop: Passed to Pillow's ``save`` as ``loop``.
        optimize: Passed to Pillow's ``save`` as ``optimize``.
        rescale: If ``True``, rescale intensities to ``[0, 255]`` first.
        reverse: If ``True``, write the frames in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.

    Warns:
        UserWarning: If the requested duration implies frames shorter than
            10 ms; the frame duration is then clamped to 10 ms.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    is_single_channel = len(tensor) == 1

    # Bring the selected spatial axis to the front and channels to the end
    permutation = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*permutation, 0)

    mode = 'P' if is_single_channel else 'RGB'
    if is_single_channel:
        tensor = tensor[..., 0]

    frames_array = tensor.byte().numpy()
    num_rotations = 2 if axis == 1 else 1
    frames = [
        ImagePIL.fromarray(rotate(frame, n=num_rotations)).convert(mode)
        for frame in frames_array
    ]
    if reverse:
        frames = frames[::-1]

    num_frames = len(frames)
    frame_duration_ms = duration / num_frames * 1000
    # GIF frame delays have 10 ms resolution, so 100 fps is the maximum
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
236