Passed
Pull Request — master (#682)
by
unknown
01:13
created

torchio.visualization.make_gif()   B

Complexity

Conditions 7

Size

Total Lines 54
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 43
dl 0
loc 54
rs 7.448
c 0
b 0
f 0
cc 7
nop 8

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

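Commonly applied refactorings include Extract Method: move a cohesive part of the method body into a new, well-named method and call it from the original one.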

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

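There are several approaches to avoid long parameter lists, for example grouping parameters that always travel together into a single object (Introduce Parameter Object), or passing a whole object instead of several values extracted from it (Preserve Whole Object).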

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib lazily so that it stays an optional dependency.

    Returns:
        A ``(matplotlib, matplotlib.pyplot)`` tuple.

    Raises:
        ImportError: If matplotlib is not installed; the original import
            error is chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Reorient a 2D slice for display.

    The slice is rotated by 90 degrees ``n`` times with ``np.rot90`` (the
    default ``n=-1`` is one clockwise turn). When ``radiological`` is True
    the result is additionally mirrored left-right with ``np.fliplr``.

    Returns:
        The rotated (and possibly mirrored) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot three orthogonal slices (sagittal, coronal, axial) of one channel.

    Args:
        image: Image to plot. If it is a ``LabelMap``, slices are drawn with
            ``interpolation='none'`` and the default colormap is
            ``'cubehelix'``.
        radiological: Forwarded to ``rotate()``; if True, each slice is also
            mirrored left-right.
        channel: Index of the channel of ``image.data`` that is plotted.
        axes: Three matplotlib axes (sagittal, coronal, axial). If None, a
            new 1x3 figure is created.
        cmap: Colormap name passed to ``imshow``, or a dict mapping label
            values to colors, in which case the slices are converted to RGB
            with ``color_labels()``. Defaults to ``'cubehelix'`` for label
            maps and ``'gray'`` otherwise.
        output_path: If given, the figure is saved there with a transparent
            background. Only possible when this function created the figure
            (``axes`` is None); otherwise a warning is emitted.
        show: If True, call ``plt.show()`` at the end.
        xlabels: If True, set the x-axis label of each panel.
        percentiles: Lower and upper percentiles of the whole channel used as
            ``vmin``/``vmax``. Use None to disable.
        figsize: Figure size used when a new figure is created.
        reorient: If True, apply ``ToCanonical`` to the image first.
        indices: Voxel indices ``(i, j, k)`` of the three slices. Defaults to
            the center of the volume.

    Note:
        Axis labels are set, but every panel is turned off with
        ``axis('off')`` at the end, so labels and ticks are not displayed.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    kwargs = {'origin': 'lower'}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        kwargs['interpolation'] = 'none'
    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    # imshow's aspect is (height / width) of one pixel: the spacing of the
    # axis shown vertically divided by the spacing of the axis shown
    # horizontally, so anisotropic voxels are displayed to scale. The axial
    # panel previously ignored its aspect ratio.
    sr, sa, ss = image.spacing
    panels = (
        # (axis, slice, aspect, x label, y label)
        (sag_axis, slice_x, ss / sa, 'A', 'S'),
        (cor_axis, slice_y, ss / sr, 'R', 'S'),
        (axi_axis, slice_z, sa / sr, 'R', 'A'),
    )
    for axis, slice_, aspect, xlabel, ylabel in panels:
        axis.imshow(slice_, aspect=aspect, **kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.axis('off')

    if output_path is not None:
        if fig is None:
            warnings.warn(
                'output_path is ignored because axes were provided and the'
                ' figure was not created by plot_volume()'
            )
        else:
            fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
113
114
115
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot sagittal, coronal and axial slices of every image in a subject.

    With more than two images, each image is drawn in its own column of a
    3 x N grid; otherwise each image gets its own row of an N x 3 grid.

    Args:
        subject: Subject whose images (intensity images and label maps) are
            plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to ``plot_volume()`` for that image.
        show: If True, call ``plt.show()`` at the end.
        output_path: If given, the figure is saved to this path.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If True and all images have the same spatial shape, the
            subplots share their x and y axes.
        **kwargs: Extra keyword arguments forwarded to ``plot_volume()``.
    """
    _, plt = import_mpl_plt()
    images_dict = subject.get_images_dict(intensity_only=False)
    # Count the plotted images only: len(subject) may also count non-image
    # entries of the subject mapping, which would create empty axes and
    # could even switch the grid layout.
    num_images = len(images_dict)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject: do not share
            pass
        else:
            sharing = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = sharing
            subplots_kwargs['sharey'] = sharing
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D (one row of 3 axes per image) so that it
    # can be indexed correctly within the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images_dict.items()):
        image_axes = axes[image_index]
        cmap = None if cmap_dict is None else cmap_dict.get(name)
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=image_index == num_images - 1,  # only for the last image
            **kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
163
164
165
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB uint8 images.

    Args:
        arrays: Iterable of label arrays. Any shape is accepted.
        cmap_dict: Mapping from label value to color. A string color is
            converted with ``matplotlib.colors.to_rgb`` and scaled to
            0-255. Any other value (e.g. an ``(r, g, b)`` tuple in 0-255) is
            assigned as is.

    Returns:
        A list with one ``uint8`` array of shape ``array.shape + (3,)`` per
        input array. Voxels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once instead of once per array and label. Only
    # import matplotlib if a string color actually needs converting.
    colors = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            rgb_float = mpl.colors.to_rgb(color)
            # Round: assigning floats such as 127.99999 into a uint8 array
            # truncates them (to 127).
            color = [int(round(255 * channel)) for channel in rgb_float]
        colors[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
178
179
180
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save an animated GIF that sweeps through the slices along one axis.

    Args:
        tensor: Channels-first tensor with three spatial dimensions, i.e.
            ``(C, D1, D2, D3)``. With ``C == 1`` frames are stored as
            palette images; otherwise they are converted to ``'RGB'``
            (presumably ``C == 3`` — other channel counts are not handled).
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Duration of the whole GIF, in seconds.
        output_path: Path of the GIF file to write.
        loop: Forwarded to Pillow as the GIF ``loop`` option.
        optimize: Forwarded to Pillow's ``save``.
        rescale: If True, intensities are rescaled to [0, 255] before the
            cast to uint8. If False, values are cast with ``tensor.byte()``
            as they are.
        reverse: If True, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
        ValueError: If ``duration`` is not positive or there are no slices
            along ``axis``.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e
    # A non-positive duration used to crash with ZeroDivisionError when the
    # frame rate was computed
    if duration <= 0:
        raise ValueError(f'The duration must be positive, but is {duration}')
    images = _tensor_to_gif_frames(tensor, axis, rescale, ImagePIL)
    if not images:
        raise ValueError(f'The tensor has no slices along axis {axis}')
    if reverse:
        images.reverse()
    frame_duration_ms = _get_gif_frame_duration_ms(duration, len(images))
    images[0].save(
        output_path,
        save_all=True,
        append_images=images[1:],
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )


def _tensor_to_gif_frames(tensor, axis, rescale, image_module):
    """Convert a channels-first 4D tensor into a list of PIL frames."""
    tensor = RescaleIntensity((0, 255))(tensor) if rescale else tensor
    single_channel = len(tensor) == 1

    # Move channels dimension to the end and bring selected axis to 0
    axes = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*axes, 0)

    if single_channel:
        mode = 'P'
        tensor = tensor[..., 0]
    else:
        mode = 'RGB'
    array = tensor.byte().numpy()
    # Display rotation per sweep axis: 180 degrees for axis 1, 90 otherwise
    n = 2 if axis == 1 else 1
    return [
        image_module.fromarray(rotate(frame, n=n)).convert(mode)
        for frame in array
    ]


def _get_gif_frame_duration_ms(duration, num_frames):
    """Return the per-frame delay in ms, clamped to the 10 ms GIF minimum."""
    frame_duration_ms = duration / num_frames * 1000
    # GIF delays are stored in hundredths of a second, so 10 ms (100 fps)
    # is the shortest delay the format can represent
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)
    return frame_duration_ms
235