Passed
Pull Request — master (#682)
by
unknown
01:16
created

torchio.visualization.make_gif()   B

Complexity

Conditions 7

Size

Total Lines 54
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 43
dl 0
loc 54
rs 7.448
c 0
b 0
f 0
cc 7
nop 8

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

- Extract Method

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

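There are several approaches to avoid long parameter lists:

- Introduce Parameter Object: group parameters that travel together into a single object.
- Preserve Whole Object: pass the object the values come from instead of its individual fields.
- Replace Parameter with Method Call: let the method obtain a value itself instead of receiving it.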

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib lazily so it is only required when plotting.

    Returns:
        Tuple ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If matplotlib is not installed, chained to the original
            import error.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate a 2D slice for visualization purposes.

    Args:
        image: Array to rotate (anything ``np.rot90`` accepts).
        radiological: If True, the rotated slice is also mirrored left-right
            with ``np.fliplr``.
        n: Number of quarter turns passed to ``np.rot90``; negative values
            rotate in the opposite direction.

    Returns:
        The rotated (and optionally mirrored) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot.
        radiological: If True, each slice is mirrored left-right after
            rotation (see ``rotate``).
        channel: Index into ``image.data`` selecting the channel to display.
        axes: Sequence of three matplotlib axes, used in the order sagittal,
            coronal, axial. If None, a new figure with three subplots is
            created.
        cmap: Colormap forwarded to ``imshow``, or a dict mapping label values
            to colors, in which case slices are converted to RGB with
            ``color_labels``. Defaults to ``'cubehelix'`` for ``LabelMap``
            instances and ``'gray'`` otherwise.
        output_path: If given, the figure is saved there (with a transparent
            background). Only used when this function created the figure,
            i.e. when ``axes`` is None.
        show: Whether to call ``plt.show()`` at the end.
        xlabels: Whether to set the x-axis labels of the three views.
        percentiles: Pair of percentiles of the displayed channel used as
            ``vmin``/``vmax``; None to let matplotlib choose the limits.
        figsize: Figure size used when a new figure is created.
        reorient: If True, the image is first passed through ``ToCanonical``.
        indices: Voxel indices ``(i, j, k)`` of the slices to display.
            Defaults to the center of the volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        # Do not blend neighboring label values when resampling for display
        kwargs['interpolation'] = 'none'

    sr, sa, ss = image.spacing
    kwargs['origin'] = 'lower'

    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    # One entry per view: (axis, slice, aspect ratio derived from the voxel
    # spacing, x label, y label, title)
    views = (
        (sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal'),
        (cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal'),
        (axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial'),
    )
    for axis, view_slice, aspect, xlabel, ylabel, title in views:
        axis.imshow(view_slice, aspect=aspect, **kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.set_title(title)

    plt.tight_layout()
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
110
111
112
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot three orthogonal views of every image in a subject.

    With more than two images, each image occupies one column of a 3-row
    grid; otherwise each image occupies one row of a 3-column grid. Every
    image is drawn with ``plot_volume``, and x-axis labels are only drawn for
    the last image.

    Args:
        subject: Subject whose images, as returned by
            ``subject.get_images_dict(intensity_only=False)``, are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            forwarded to ``plot_volume``.
        show: Whether to call ``plt.show()`` before the figure is closed.
        output_path: If given, the figure is saved to this path.
        figsize: Figure size forwarded to ``plt.subplots``.
        clear_axes: If True and ``subject.check_consistent_spatial_shape()``
            does not raise, the subplots share their x and y axes.
        **kwargs: Extra keyword arguments forwarded to ``plot_volume``.
    """
    _, plt = import_mpl_plt()
    num_images = len(subject)
    images_as_columns = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            shared = 'row' if images_as_columns else 'col'
            subplots_kwargs['sharex'] = shared
            subplots_kwargs['sharey'] = shared

    # The array of axes must be 2D, with one row per image, so that it can be
    # indexed correctly within the plot_volume() function
    if images_as_columns:
        fig, axes = plt.subplots(3, num_images, **subplots_kwargs)
        axes = axes.T
    else:
        fig, axes = plt.subplots(num_images, 3, **subplots_kwargs)
        axes = axes.reshape(-1, 3)

    view_names = ('sagittal', 'coronal', 'axial')
    last_index = len(axes) - 1
    images = subject.get_images_dict(intensity_only=False)
    for index, (name, image) in enumerate(images.items()):
        image_axes = axes[index]
        has_cmap = cmap_dict is not None and name in cmap_dict
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap_dict[name] if has_cmap else None,
            xlabels=index == last_index,
            **kwargs,
        )
        for view_axis, view_name in zip(image_axes, view_names):
            view_axis.set_title(f'{name} ({view_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    plt.close(fig)
161
162
163
def color_labels(arrays, cmap_dict):
    """Convert arrays of label values into RGB images.

    Args:
        arrays: Iterable of arrays of label values (2D slices in this module).
        cmap_dict: Mapping from label value to a color. A color is either a
            matplotlib color string (e.g. ``'red'`` or ``'#808080'``),
            converted to 0-255 RGB, or an RGB sequence that is written
            directly into the ``uint8`` output (so presumably already 0-255).

    Returns:
        List with one ``uint8`` array of shape ``array.shape + (3,)`` per
        input array. Pixels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once instead of once per array
    rgb_colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            mpl, _ = import_mpl_plt()
            # Round rather than let the uint8 cast truncate, so hex colors
            # map back to their exact byte values
            color = [round(255 * c) for c in mpl.colors.to_rgb(color)]
        rgb_colors[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in rgb_colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
176
177
178
def _import_pil_image():
    """Return Pillow's ``Image`` module, raising a helpful error if missing."""
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e
    return ImagePIL


def _get_frame_duration_ms(duration: float, num_frames: int) -> float:
    """Return the per-frame duration in ms, raised to at least 10 ms.

    Emits a warning when the requested total ``duration`` (in seconds) would
    need a frame rate above 100 fps.
    """
    frame_duration_ms = duration / num_frames * 1000
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)
    return frame_duration_ms


def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save the slices of a channels-first 4D tensor as an animated GIF.

    Args:
        tensor: 4D tensor whose first dimension is channels. A single channel
            produces palette (``'P'``) frames; otherwise frames are ``'RGB'``.
        axis: Spatial axis (0, 1 or 2, counted after the channel dimension)
            along which the frames are taken.
        duration: Total duration of the GIF in seconds.
        output_path: Path of the GIF file to write.
        loop: Forwarded to Pillow's ``save`` as ``loop``.
        optimize: Forwarded to Pillow's ``save`` as ``optimize``.
        rescale: If True, intensities are first rescaled to [0, 255] with
            ``RescaleIntensity``.
        reverse: If True, frames are written in reverse order.

    Raises:
        ValueError: If ``duration`` is not positive.
        RuntimeError: If Pillow is not installed.
    """
    if duration <= 0:
        raise ValueError(
            f'The GIF duration must be positive, but got {duration}',
        )
    ImagePIL = _import_pil_image()
    tensor = RescaleIntensity((0, 255))(tensor) if rescale else tensor
    single_channel = len(tensor) == 1

    # Move channels dimension to the end and bring selected axis to 0
    axes = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*axes, 0)

    if single_channel:
        mode = 'P'
        tensor = tensor[..., 0]
    else:
        mode = 'RGB'
    array = tensor.byte().numpy()
    # Quarter turns applied by rotate(); frames along axis 1 get a half turn
    # (presumably so that every view is displayed upright — confirm)
    n = 2 if axis == 1 else 1
    images = [ImagePIL.fromarray(rotate(i, n=n)).convert(mode) for i in array]
    if reverse:
        images.reverse()
    frame_duration_ms = _get_frame_duration_ms(duration, len(images))
    images[0].save(
        output_path,
        save_all=True,
        append_images=images[1:],
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
233