Passed
Pull Request — master (#682)
by
unknown
01:42
created

torchio.visualization.plot_volume()   F

Complexity

Conditions 15

Size

Total Lines 82
Code Lines 72

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 72
dl 0
loc 82
rs 2.9998
c 0
b 0
f 0
cc 15
nop 12

How to fix   Long Method    Complexity    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

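Commonly applied refactorings include Extract Method and, when many parameters or temporary variables are involved, Replace Temp with Query, Introduce Parameter Object, and Preserve Whole Object.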

Complexity

Complex functions like torchio.visualization.plot_volume() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for variables or statements that share the same prefixes or suffixes (for example, the sag_*, cor_* and axi_* names in this function).

Once you have determined the statements that belong together, you can apply the Extract Method refactoring. If the extracted component carries enough state of its own, Extract Class (or, where it makes sense as a specialization, Extract Subclass) is also a candidate.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists: Replace Parameter with Method Call, Preserve Whole Object, and Introduce Parameter Object.

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import Matplotlib lazily and return ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If Matplotlib cannot be imported. The original import
            error is chained as the cause.
    """
    try:
        import matplotlib
        from matplotlib import pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    else:
        return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate a 2D slice for display.

    Args:
        image: Array to rotate.
        radiological: If ``True``, the rotated array is additionally flipped
            left-right.
        n: Number of 90-degree rotations, forwarded to :func:`numpy.rot90`.

    Returns:
        The rotated (and possibly flipped) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def _show_slice(
        axis,
        array,
        aspect,
        xlabel,
        ylabel,
        title,
        show_xlabel,
        imshow_kwargs,
        ):
    """Draw one 2D slice on ``axis`` and apply the decorations shared by all views."""
    axis.imshow(array, aspect=aspect, **imshow_kwargs)
    if show_xlabel:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)
    # NOTE(review): switching the axis off presumably hides the labels set
    # above as well; kept because that is what the original code does.
    axis.axis('off')


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. Its ``data`` and ``spacing`` attributes are read.
        radiological: Forwarded to :func:`rotate` for every slice.
        channel: Index of the channel of ``image.data`` to display.
        axes: Iterable of exactly three Matplotlib axes (sagittal, coronal,
            axial). If ``None``, a new 1x3 figure is created.
        cmap: Either a ``dict`` mapping label values to colors (the slices
            are then converted to RGB with :func:`color_labels`) or a value
            passed as ``cmap`` to ``imshow``. If ``None``, ``'cubehelix'`` is
            used for :class:`LabelMap` instances and ``'gray'`` otherwise.
        output_path: If given, the figure is saved there. Only honoured when
            the figure was created by this function (``axes`` is ``None``).
        show: If ``True``, ``plt.show()`` is called at the end.
        xlabels: If ``True``, x labels are set on the three axes.
        percentiles: Pair of percentiles of the selected channel used as
            ``vmin`` and ``vmax``. Use ``None`` to disable. Ignored for
            :class:`LabelMap` instances, whose values are categorical.
        figsize: Forwarded to ``plt.subplots`` when a figure is created.
        reorient: If ``True``, :class:`ToCanonical` is applied to the image
            before slicing.
        indices: Three slice indices ``(i, j, k)``. Defaults to the center of
            the volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    imshow_kwargs = {'origin': 'lower'}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        imshow_kwargs['cmap'] = cmap

    if is_label:
        imshow_kwargs['interpolation'] = 'none'
    elif percentiles is not None:
        # Intensity windowing only makes sense for continuous intensities.
        # On a label map both percentiles can fall on the same label value,
        # which would collapse the displayed range.
        p1, p2 = np.percentile(data, percentiles)
        imshow_kwargs['vmin'] = p1
        imshow_kwargs['vmax'] = p2

    sr, sa, ss = image.spacing
    # (axis, slice, aspect ratio, x label, y label, title) for each view
    views = (
        (sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal'),
        (cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal'),
        (axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial'),
    )
    for axis, array, aspect, xlabel, ylabel, title in views:
        _show_slice(
            axis,
            array,
            aspect,
            xlabel,
            ylabel,
            title,
            xlabels,
            imshow_kwargs,
        )

    plt.tight_layout()
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
112
113
114
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot three views of every image in a subject on a single figure.

    Args:
        subject: Subject whose images (from
            ``get_images_dict(intensity_only=False)``) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image.
        show: If ``True``, ``plt.show()`` is called before closing the figure.
        output_path: If given, the figure is saved there.
        figsize: Forwarded to ``plt.subplots``.
        clear_axes: If ``True`` and the subject passes
            ``check_consistent_spatial_shape()``, x and y axes are shared
            between subplots.
        **kwargs: Forwarded to :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    num_images = len(subject)
    many_images = num_images > 2

    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            sharing = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = sharing
            subplots_kwargs['sharey'] = sharing

    # The array of axes must be 2D, with one row of three axes per image, so
    # that each row can be handed to plot_volume()
    if many_images:
        fig, axes = plt.subplots(3, num_images, **subplots_kwargs)
        axes = axes.T
    else:
        fig, axes = plt.subplots(num_images, 3, **subplots_kwargs)
        axes = axes.reshape(-1, 3)

    images_dict = subject.get_images_dict(intensity_only=False)
    view_names = ('sagittal', 'coronal', 'axial')
    last_index = len(axes) - 1
    for image_index, (name, image) in enumerate(images_dict.items()):
        image_axes = axes[image_index]
        if cmap_dict is None or name not in cmap_dict:
            cmap = None
        else:
            cmap = cmap_dict[name]
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=image_index == last_index,
            **kwargs,
        )
        for axis, view_name in zip(image_axes, view_names):
            axis.set_title(f'{name} ({view_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    plt.close(fig)
163
164
165
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color. A color is either a
            string understood by ``matplotlib.colors.to_rgb`` or a sequence
            of three values assigned as-is to the ``uint8`` output.

    Returns:
        A list with one ``(rows, columns, 3)`` ``uint8`` array per input
        array. Pixels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve named colors once instead of once per array. Matplotlib is only
    # imported if at least one color is given as a string.
    palette = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        palette[label] = color

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in palette.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
178
179
180
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save a 4D tensor as an animated GIF, one frame per slice along ``axis``.

    Args:
        tensor: Tensor whose first dimension holds the channels, followed by
            three spatial dimensions. A single channel is saved in mode
            ``'P'``; otherwise frames are converted to ``'RGB'``.
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Duration of the full GIF, in seconds.
        output_path: Path of the output file.
        loop: Forwarded to Pillow's ``save``.
        optimize: Forwarded to Pillow's ``save``.
        rescale: If ``True``, ``RescaleIntensity((0, 255))`` is applied first.
        reverse: If ``True``, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    num_channels = len(tensor)

    # Move channels dimension to the end and bring selected axis to 0
    spatial_order = np.roll(range(1, 4), -axis)
    volume = tensor.permute(*spatial_order, 0)

    if num_channels == 1:
        mode = 'P'
        volume = volume[..., 0]
    else:
        mode = 'RGB'

    num_rotations = 2 if axis == 1 else 1
    images = []
    for frame in volume.byte().numpy():
        rotated = rotate(frame, n=num_rotations)
        images.append(ImagePIL.fromarray(rotated).convert(mode))
    num_images = len(images)
    if reverse:
        images.reverse()

    frame_duration_ms = duration / num_images * 1000
    if frame_duration_ms < 10:
        # Frames shorter than 10 ms are clamped and the caller is warned
        # that the resulting GIF will last longer than requested
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_images / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)

    first_image, *other_images = images
    first_image.save(
        output_path,
        save_all=True,
        append_images=other_images,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
235