Passed
Pull Request — master (#682)
by
unknown
14:34
created

torchio.visualization.plot_volume()   D

Complexity

Conditions 12

Size

Total Lines 88
Code Lines 57

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 57
dl 0
loc 88
rs 4.7672
c 0
b 0
f 0
cc 12
nop 12

How to fix   Long Method    Complexity    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

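Commonly applied refactorings include Extract Method and, if the method has many parameters or temporary variables, Replace Method with Method Object.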

Complexity

Complex code units like the function torchio.visualization.plot_volume() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach is to look for fields/methods (or, inside a function, variables and statements) that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

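There are several approaches to avoid long parameter lists, for example Replace Parameter with Method, Preserve Whole Object, and Introduce Parameter Object (grouping related parameters into a single object).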

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib at call time and return the package and pyplot.

    Returns:
        A tuple ``(mpl, plt)``: the ``matplotlib`` package and the
        ``matplotlib.pyplot`` module.

    Raises:
        ImportError: If matplotlib cannot be imported. The original error is
            chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate a 2D array for display, optionally mirroring it left-right.

    Args:
        image: Array-like accepted by ``np.rot90``.
        radiological: If True, flip the rotated array left-right.
        n: Number of 90-degree rotations passed to ``np.rot90``; the default
            ``-1`` rotates once clockwise.

    Returns:
        The rotated (and possibly flipped) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def _extract_slices(data, indices, radiological):
    """Return the rotated slices ``data[i]``, ``data[:, j]``, ``data[:, :, k]``.

    If ``indices`` is None, the central index along each of the three
    dimensions is used. Each slice is passed through ``rotate``.
    """
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    return (
        rotate(data[i, :, :], radiological=radiological),
        rotate(data[:, j, :], radiological=radiological),
        rotate(data[:, :, k], radiological=radiological),
    )


def _get_imshow_kwargs(data, is_label, cmap, percentiles):
    """Build the keyword arguments shared by the three ``imshow`` calls.

    A dict ``cmap`` means the slices were already converted to RGB, so no
    colormap is passed. Otherwise ``None`` falls back to ``'cubehelix'`` for
    label maps and ``'gray'`` for other images. Intensity limits are the
    given percentiles of the whole ``data`` volume, not just the slices.
    """
    kwargs = {}
    if not isinstance(cmap, dict):
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        # Avoid blending neighboring label values when the image is resampled
        kwargs['interpolation'] = 'none'
    kwargs['origin'] = 'lower'
    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2
    return kwargs


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. A ``LabelMap`` is drawn without interpolation
            and defaults to the ``'cubehelix'`` colormap.
        radiological: Passed to ``rotate``; if True, slices are also flipped
            left-right.
        channel: Index into ``image.data`` selecting the channel to plot.
        axes: Three matplotlib axes (sagittal, coronal, axial). If None, a
            new 1x3 figure is created.
        cmap: Colormap for ``imshow``, or a dict mapping label values to
            colors. A dict converts the slices to RGB via ``color_labels``.
        output_path: If given and the figure was created here (``axes`` is
            None), the figure is saved there with a transparent background.
            It is ignored when ``axes`` is provided.
        show: If True, call ``plt.show()``.
        xlabels: Currently unused, because all axes are hidden with
            ``axis('off')``. Kept for backward compatibility.
        percentiles: Lower and upper percentiles of the channel volume used
            as ``vmin``/``vmax``. None lets matplotlib choose.
        figsize: Passed to ``plt.subplots`` when ``axes`` is None.
        reorient: If True, apply ``ToCanonical`` before slicing.
        indices: Slice indices ``(i, j, k)``; defaults to the volume center.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes
    view_axes = (sag_axis, cor_axis, axi_axis)

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    planes = _extract_slices(data, indices, radiological)
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        # Labels are painted with explicit RGB colors instead of a colormap
        planes = color_labels(planes, cmap)
    kwargs = _get_imshow_kwargs(data, is_label, cmap, percentiles)

    # NOTE(review): spacing is unpacked as (r, a, s), which presumably matches
    # the canonical orientation when reorient=True; confirm for reorient=False.
    # Each aspect is vertical spacing / horizontal spacing of the view.
    sr, sa, ss = image.spacing
    aspects = (ss / sa, ss / sr, sa / sr)
    for axis, plane, aspect in zip(view_axes, planes, aspects):
        axis.imshow(plane, aspect=aspect, **kwargs)
    for axis in view_axes:
        axis.axis('off')

    plt.tight_layout()
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
118
119
120
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot three orthogonal views of every image in a subject.

    Images are laid out one per row when there are at most two of them, and
    one per column otherwise.

    Args:
        subject: Subject whose images (intensity and label) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` passed
            to ``plot_volume`` for that image.
        show: If True, call ``plt.show()``.
        output_path: If given, the figure is saved there.
        figsize: Passed to ``plt.subplots``.
        clear_axes: If True and ``subject.check_consistent_spatial_shape()``
            does not raise, the subplots share their x and y axes.
        **kwargs: Extra keyword arguments forwarded to ``plot_volume``.
    """
    _, plt = import_mpl_plt()
    images = subject.get_images_dict(intensity_only=False)
    # NOTE(review): len(subject) may also count non-image entries, which
    # produced empty subplots; size the grid by the images actually plotted.
    num_images = len(images)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    plt.close(fig)
169
170
171
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of label arrays (any shape).
        cmap_dict: Mapping from label value to a color. A color is either a
            matplotlib color string (converted with ``mpl.colors.to_rgb`` and
            scaled to 0-255) or a sequence of three 0-255 values.

    Returns:
        A list with one ``uint8`` array of shape ``array.shape + (3,)`` per
        input. Pixels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once, instead of once per array, and import
    # matplotlib only if a string color actually needs converting
    mpl = None
    rgb_colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        rgb_colors[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in rgb_colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
184
185
186
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Write an animated GIF whose frames are slices of a 4D tensor.

    Dimension 0 of ``tensor`` holds channels and dimensions 1-3 are spatial.
    Frames are taken along spatial dimension ``axis + 1``.

    Args:
        tensor: Tensor with channels first. A single channel gives a
            palette ('P') GIF; otherwise frames are built in 'RGB' mode
            (presumably three channels are expected; confirm with callers).
        axis: Spatial axis (0, 1 or 2) to iterate over.
        duration: Total duration of the GIF, in seconds.
        output_path: Destination file.
        loop: Forwarded to Pillow's ``save`` as ``loop``.
        optimize: Forwarded to Pillow's ``save`` as ``optimize``.
        rescale: If True, intensities are rescaled to [0, 255] first.
        reverse: If True, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as error:
        raise RuntimeError(
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        ) from error

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    is_single_channel = len(tensor) == 1

    # Put the selected spatial dimension first and the channels last, so
    # iterating over the array yields one frame per slice
    permutation = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*permutation, 0)
    if is_single_channel:
        tensor = tensor[..., 0]
    mode = 'P' if is_single_channel else 'RGB'

    num_rotations = 2 if axis == 1 else 1
    frames = [
        ImagePIL.fromarray(rotate(frame, n=num_rotations)).convert(mode)
        for frame in tensor.byte().numpy()
    ]
    num_frames = len(frames)
    if reverse:
        frames = frames[::-1]

    frame_duration_ms = duration / num_frames * 1000
    min_frame_duration_ms = 10  # i.e. at most 100 frames per second
    if frame_duration_ms < min_frame_duration_ms:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = min_frame_duration_ms
        new_duration = frame_duration_ms * num_frames / 1000
        warnings.warn(
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
241