Passed
Push — master ( 6deb01...ca4890 )
by Fernando
01:56 queued 40s
created

torchio.visualization.make_gif()   B

Complexity

Conditions 7

Size

Total Lines 54
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 43
dl 0
loc 54
rs 7.448
c 0
b 0
f 0
cc 7
nop 8

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

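Commonly applied refactorings include:

- Extract Method
- If many parameters or temporary variables are present:
  - Replace Temp with Query
  - Introduce Parameter Object
  - Preserve Whole Object
  - Replace Method with Method Object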

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

- Replace Parameter with Query
- Introduce Parameter Object
- Preserve Whole Object

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib lazily so plotting stays an optional dependency.

    Returns:
        A ``(mpl, plt)`` tuple holding the ``matplotlib`` package and the
        ``matplotlib.pyplot`` module.

    Raises:
        ImportError: If matplotlib cannot be imported; the original error is
            chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate a slice for display and optionally mirror it left-right.

    Args:
        image: Array-like with at least two dimensions (a 2D slice, or a
            2D frame with a trailing channel dimension).
        radiological: If ``True``, the rotated slice is additionally flipped
            left-right with ``np.fliplr``.
        n: Number of 90-degree rotations passed to ``np.rot90``; negative
            values rotate in the opposite direction.

    Returns:
        The rotated (and, if requested, mirrored) NumPy array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def _plot_slice(
        axis,
        image_slice,
        aspect,
        xlabel,
        ylabel,
        title,
        show_xlabel,
        imshow_kwargs,
        ):
    """Draw one 2D slice on a matplotlib axis and label it (plot_volume helper)."""
    axis.imshow(image_slice, aspect=aspect, **imshow_kwargs)
    if show_xlabel:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    # NOTE(review): the x axis is inverted after drawing; this interacts with
    # the flips done by rotate(), so keep the two consistent if changing either
    axis.invert_xaxis()
    axis.set_title(title)


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. If it is a :class:`LabelMap`, the default
            colormap is ``'cubehelix'``, interpolation is disabled and no
            percentile windowing is applied.
        radiological: Forwarded to :func:`rotate`; if ``True``, each slice is
            flipped left-right after rotation.
        channel: Index of the channel of ``image.data`` to plot.
        axes: Sequence of three matplotlib axes (sagittal, coronal, axial).
            If ``None``, a new 1x3 figure is created.
        cmap: Matplotlib colormap, or a dict mapping label values to colors,
            which is passed to :func:`color_labels` to build RGB slices. If
            ``None``, ``'cubehelix'`` is used for label maps and ``'gray'``
            otherwise.
        output_path: If given and the figure was created here (``axes`` is
            ``None``), the figure is saved to this path. It is ignored when
            ``axes`` is provided.
        show: If ``True``, call ``plt.show()`` at the end.
        xlabels: If ``True``, set the x-axis label of each view.
        percentiles: Lower and upper percentiles of the plotted channel used
            as ``vmin``/``vmax`` for display. ``None`` disables windowing.
            Ignored for label maps.
        figsize: Forwarded to ``plt.subplots`` when ``axes`` is ``None``.
        reorient: If ``True``, apply :class:`ToCanonical` before slicing.
        indices: Optional ``(i, j, k)`` voxel indices of the sagittal,
            coronal and axial slices. Defaults to the center of the volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    imshow_kwargs = {'origin': 'lower'}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        imshow_kwargs['cmap'] = cmap
    if is_label:
        imshow_kwargs['interpolation'] = 'none'

    # Percentile windowing only makes sense for intensities. On a label map
    # with a small foreground both percentiles can equal the background value
    # (vmin == vmax), which paints every label with the same color.
    if percentiles is not None and not is_label:
        p1, p2 = np.percentile(data, percentiles)
        imshow_kwargs['vmin'] = p1
        imshow_kwargs['vmax'] = p2

    # Spacing is used to compute each view's aspect ratio
    sr, sa, ss = image.spacing
    _plot_slice(
        sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal', xlabels, imshow_kwargs)
    _plot_slice(
        cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal', xlabels, imshow_kwargs)
    _plot_slice(
        axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial', xlabels, imshow_kwargs)

    plt.tight_layout()
    if output_path is not None and fig is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
106
107
108
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot sagittal, coronal and axial slices of every image in a subject.

    With up to two images, each image gets one row of three axes. With more
    images, the grid is transposed so that each image gets one column.

    Args:
        subject: Subject whose images (including label maps) are plotted.
        cmap_dict: Optional dict mapping image names to the ``cmap`` passed
            to :func:`plot_volume` for that image.
        show: If ``True``, call ``plt.show()`` at the end.
        output_path: If given, the figure is saved to this path.
        figsize: Forwarded to ``plt.subplots``.
        clear_axes: If ``True`` and
            ``subject.check_consistent_spatial_shape()`` does not raise
            ``RuntimeError``, the axes share their x and y limits.
        **kwargs: Forwarded to :func:`plot_volume`. Must not contain
            ``axes``, ``show``, ``cmap`` or ``xlabels``, which are set here.

    Raises:
        ValueError: If the subject contains no images.
    """
    _, plt = import_mpl_plt()
    # Size the grid from the same collection that is iterated below so that
    # there is exactly one row (or column) of axes per plotted image
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)
    if num_images == 0:
        raise ValueError('The subject does not contain any images to plot')
    many_images = num_images > 2

    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share

    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)

    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
156
157
158
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of label arrays (NumPy arrays or array-likes with a
            ``shape`` and element-wise ``==``). Any number of dimensions is
            supported.
        cmap_dict: Dict mapping label values to colors. A color is either a
            string accepted by ``matplotlib.colors.to_rgb`` (converted to
            0-255 integers) or a sequence of three 0-255 values used as given.

    Returns:
        List with one ``uint8`` array of shape ``array.shape + (3,)`` per
        input array. Elements whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve string colors once instead of once per array. matplotlib is
    # imported only if a string color is actually used.
    rgb_colors = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            # Round rather than truncate: 255 * c may land just below an
            # integer due to floating-point error, and assigning it into a
            # uint8 array would truncate it to the wrong value.
            color = [int(round(255 * c)) for c in mpl.colors.to_rgb(color)]
        rgb_colors[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in rgb_colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
171
172
173
def _import_pil_image():
    """Return Pillow's ``Image`` module, raising a helpful error if missing."""
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e
    return ImagePIL


def _tensor_to_gif_frames(tensor, axis, rescale, pil_image):
    """Convert a channels-first 4D tensor into one PIL image per slice along ``axis``."""
    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    single_channel = len(tensor) == 1

    # Move channels dimension to the end and bring selected axis to 0
    axes = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*axes, 0)

    if single_channel:
        mode = 'P'
        tensor = tensor[..., 0]
    else:
        mode = 'RGB'
    array = tensor.byte().numpy()
    # Number of 90-degree rotations rotate() applies to each frame
    n = 2 if axis == 1 else 1
    return [
        pil_image.fromarray(rotate(frame, n=n)).convert(mode)
        for frame in array
    ]


def _get_frame_duration_ms(duration, num_frames):
    """Return the per-frame duration in ms, clamped to the 10 ms GIF minimum."""
    frame_duration_ms = duration / num_frames * 1000
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        # stacklevel=3 attributes the warning to the caller of make_gif()
        warnings.warn(message, stacklevel=3)
    return frame_duration_ms


def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save a 4D tensor as an animated GIF with one frame per slice.

    Args:
        tensor: 4D tensor with the channels dimension first. A single channel
            produces palette (``'P'``) frames; otherwise frames are ``'RGB'``.
        axis: Spatial axis (0, 1 or 2) along which the frames are taken.
        duration: Duration of the whole GIF, in seconds.
        output_path: Path of the GIF file to write.
        loop: Forwarded to Pillow's ``Image.save`` as ``loop``.
        optimize: Forwarded to Pillow's ``Image.save`` as ``optimize``.
        rescale: If ``True``, intensities are first rescaled to [0, 255] with
            :class:`RescaleIntensity`.
        reverse: If ``True``, the frames are written in reverse order.

    Raises:
        ValueError: If ``duration`` is not positive or the tensor has no
            slices along ``axis``.
        RuntimeError: If Pillow is not installed.

    Warns:
        UserWarning: If the requested frame rate exceeds 100 fps. Each frame
            then lasts 10 ms and the GIF is longer than requested.
    """
    if duration <= 0:
        raise ValueError(f'The duration must be positive, but got {duration}')
    pil_image = _import_pil_image()
    frames = _tensor_to_gif_frames(tensor, axis, rescale, pil_image)
    if not frames:
        raise ValueError(f'The tensor has no slices along axis {axis}')
    if reverse:
        frames.reverse()
    frame_duration_ms = _get_frame_duration_ms(duration, len(frames))
    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
228