Passed
Pull Request — main (#1346)
by Fernando
01:27
created

torchio.visualization   F

Complexity

Total Complexity 60

Size/Duplication

Total Lines 343
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 269
dl 0
loc 343
rs 3.6
c 0
b 0
f 0
wmc 60

10 Functions

Rating   Name   Duplication   Size   Complexity  
A color_labels() 0 13 4
A get_num_bins() 0 16 1
F plot_volume() 0 84 17
D plot_subject() 0 48 13
A rotate() 0 6 2
A plot_histogram() 0 9 3
A import_mpl_plt() 0 7 2
A _create_categorical_colormap() 0 13 2
B make_gif() 0 52 7
C make_video() 0 54 9

How to fix   Complexity   

Complexity

Complex modules like torchio.visualization often do many different things. To break such a module down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share common prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from __future__ import annotations
2
3
import warnings
4
from pathlib import Path
5
from typing import TYPE_CHECKING
6
7
import numpy as np
8
import torch
9
10
from .data.image import Image
11
from .data.image import LabelMap
12
from .data.subject import Subject
13
from .external.imports import get_ffmpeg
14
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
15
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
16
from .types import TypePath
17
18
if TYPE_CHECKING:
19
    from matplotlib.colors import ListedColormap
20
21
22
def import_mpl_plt():
    """Import Matplotlib on demand, keeping it an optional dependency.

    Returns:
        Tuple ``(mpl, plt)`` with the ``matplotlib`` package and the
        ``matplotlib.pyplot`` module.

    Raises:
        ImportError: If Matplotlib cannot be imported.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
29
30
31
def rotate(image, radiological=True, n=-1):
    """Orient a 2D slice for display.

    Args:
        image: 2D array-like slice.
        radiological: If ``True``, the rotated slice is also flipped
            left-right.
        n: Number of 90-degree rotations passed to ``np.rot90`` (negative
            values rotate clockwise).

    Returns:
        The rotated (and possibly flipped) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
37
38
39
def _create_categorical_colormap(data: torch.Tensor) -> ListedColormap:
    """Build a colormap for label maps with a black background.

    Entry 0 is always black. If the largest label value is 1, the only
    foreground color is white. Otherwise, ``int(data.max())`` colors are
    requested from ``distinctipy`` with ``rng=0`` (presumably so that colors
    are reproducible across calls; confirm against distinctipy docs).

    Args:
        data: Label values; only their maximum is used.

    Returns:
        A Matplotlib ``ListedColormap``: black followed by the foreground
        colors.
    """
    largest_label = int(data.max())
    mpl, _ = import_mpl_plt()

    if largest_label != 1:
        from .external.imports import get_distinctipy

        foreground = get_distinctipy().get_colors(largest_label, rng=0)
    else:
        foreground = [(1, 1, 1)]  # a single label is drawn in white
    background = (0, 0, 0)
    return mpl.colors.ListedColormap([background, *foreground])
52
53
54
def plot_volume(
    image: Image,
    radiological=True,
    channel=-1,  # default to foreground for binary maps
    axes=None,
    cmap=None,
    output_path=None,
    show=True,
    xlabels=True,
    percentiles: tuple[float, float] = (0.5, 99.5),
    figsize=None,
    title=None,
    reorient=True,
    indices=None,
    **imshow_kwargs,
):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. For a :class:`LabelMap`, intensity percentiles
            are not applied, ``interpolation='none'`` is used and, if no
            ``cmap`` is given, a categorical colormap is built.
        radiological: If ``True``, slices are flipped left-right after
            rotation (see :func:`rotate`).
        channel: Index of the channel to plot. The default ``-1`` selects the
            last channel.
        axes: Sequence of three Matplotlib axes (sagittal, coronal, axial).
            If ``None``, a new figure with three subplots is created.
        cmap: Colormap passed to ``imshow``, or a dict mapping label values to
            colors, in which case slices are converted to RGB with
            :func:`color_labels`. If ``None``, ``'gray'`` is used for
            intensity images.
        output_path: If given, the figure containing the axes is saved there.
            This also applies when ``axes`` are provided by the caller.
        show: If ``True``, call ``plt.show()`` at the end.
        xlabels: If ``True``, label the horizontal axis of each view.
        percentiles: Lower and upper percentiles of the selected channel used
            as ``vmin``/``vmax`` for non-label images; ``None`` disables this.
        figsize: Figure size, used only when ``axes`` is ``None``.
        title: Optional figure title, set with ``plt.suptitle``.
        reorient: If ``True``, apply :class:`ToCanonical` before plotting.
        indices: Voxel indices ``(i, j, k)`` of the slices to show. Defaults
            to the center of the volume.
        **imshow_kwargs: Extra keyword arguments for ``Axes.imshow``.

    Returns:
        The newly created figure if ``axes`` was ``None``, otherwise ``None``.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)  # type: ignore[assignment]
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = _create_categorical_colormap(data) if is_label else 'gray'
        imshow_kwargs['cmap'] = cmap

    if is_label:
        imshow_kwargs['interpolation'] = 'none'

    sr, sa, ss = image.spacing
    imshow_kwargs['origin'] = 'lower'

    if percentiles is not None and not is_label:
        p1, p2 = np.percentile(data, percentiles)
        imshow_kwargs['vmin'] = p1
        imshow_kwargs['vmax'] = p2

    # (axis, slice, aspect, x label, y label, title) for each view. The
    # aspect ratios compensate for anisotropic voxel spacing.
    views = (
        (sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal'),
        (cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal'),
        (axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial'),
    )
    for axis, view_slice, aspect, xlabel, ylabel, view_title in views:
        _plot_view(
            axis,
            view_slice,
            aspect=aspect,
            xlabel=xlabel if xlabels else None,
            ylabel=ylabel,
            title=view_title,
            imshow_kwargs=imshow_kwargs,
        )

    plt.tight_layout()
    if title is not None:
        plt.suptitle(title)

    if output_path is not None:
        # When the caller supplies the axes there is no new figure, so save
        # the one the axes belong to instead of ignoring output_path.
        figure = fig if fig is not None else sag_axis.figure
        figure.savefig(output_path)
    if show:
        plt.show()
    return fig


def _plot_view(axis, view_slice, aspect, xlabel, ylabel, title, imshow_kwargs):
    """Draw one orthogonal view on ``axis`` and label it (helper of plot_volume).

    If ``xlabel`` is ``None``, the horizontal axis is left unlabeled. The x
    axis is inverted after drawing, as for every view in ``plot_volume``.
    """
    axis.imshow(view_slice, aspect=aspect, **imshow_kwargs)
    if xlabel is not None:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)
138
139
140
def plot_subject(
    subject: Subject,
    cmap_dict=None,
    show=True,
    output_path=None,
    figsize=None,
    clear_axes=True,
    **plot_volume_kwargs,
):
    """Plot the three orthogonal views of every image in a subject.

    With up to two images, each image occupies one row of three axes. With
    more than two, the grid is transposed so that each image occupies one
    column.

    Args:
        subject: Subject whose images (``intensity_only=False``) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image.
        show: If ``True``, call ``plt.show()`` at the end.
        output_path: If given, the figure is saved to this path.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If ``True`` and ``subject.check_consistent_spatial_shape()``
            succeeds, subplots share their x and y axes.
        **plot_volume_kwargs: Extra keyword arguments for :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    images = subject.get_images_dict(intensity_only=False)
    # Size the grid from the same collection that is iterated below, so that
    # there is exactly one set of three axes per plotted image.
    num_images = len(images)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:
            pass  # different shapes in subject: do not share axes
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **plot_volume_kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
188
189
190
def get_num_bins(x: np.ndarray) -> int:
    """Get the optimal number of bins for a histogram.

    This method uses the Freedman–Diaconis rule to compute the histogram that
    minimizes "the integral of the squared difference between the histogram
    (i.e., relative frequency density) and the density of the theoretical
    probability distribution" (`Wikipedia <https://en.wikipedia.org/wiki/Freedman%E2%80%93Diaconis_rule>`_).

    Args:
        x: Input values. Arrays of any shape are treated as one flat sample.

    Returns:
        Number of bins, always at least 1. If the interquartile range is zero
        (for example, for constant input), the rule's bin width is zero and a
        single bin is returned.
    """
    x = np.asarray(x)
    # Freedman–Diaconis bin width: 2 * IQR * n^(-1/3), where n is the number
    # of samples (all elements, consistent with percentile/min/max below)
    q25, q75 = np.percentile(x, [25, 75])
    bin_width = 2 * (q75 - q25) * x.size ** (-1 / 3)
    if bin_width <= 0:
        # Avoid dividing by zero; one bin covers a degenerate distribution
        return 1
    num_bins = round(float(x.max() - x.min()) / float(bin_width))
    # A tiny range relative to the bin width can round to 0, which hist rejects
    return max(1, num_bins)
206
207
208
def plot_histogram(x: np.ndarray, show=True, **kwargs) -> None:
    """Plot a histogram of intensity values with Matplotlib.

    Args:
        x: Values to histogram.
        show: If ``True``, call ``plt.show()`` after plotting.
        **kwargs: Extra keyword arguments for ``plt.hist``. If ``bins`` is
            not given (or is ``None``), it is computed with
            :func:`get_num_bins`. If ``density`` is truthy, the y axis is
            labeled ``'Density'``; otherwise it is labeled ``'Frequency'``.
    """
    _, plt = import_mpl_plt()
    # Honor a caller-supplied bins value; passing it previously collided with
    # the computed one and raised a "multiple values" TypeError.
    bins = kwargs.pop('bins', None)
    if bins is None:
        bins = get_num_bins(x)
    plt.hist(x, bins=bins, **kwargs)
    plt.xlabel('Intensity')
    ylabel = 'Density' if kwargs.get('density', False) else 'Frequency'
    plt.ylabel(ylabel)
    if show:
        plt.show()
217
218
219
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of label arrays (any shape).
        cmap_dict: Mapping from label value to color. String colors are
            converted with ``matplotlib.colors.to_rgb`` and scaled to 0-255.
            Any other value is assigned to the matching pixels as-is
            (presumably an RGB triplet in 0-255; confirm with callers).

    Returns:
        List of ``uint8`` arrays with shape ``array.shape + (3,)``, one per
        input array. Pixels whose label is not in ``cmap_dict`` are black.
    """
    # Resolve every color once, instead of once per array and label.
    colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            mpl, _ = import_mpl_plt()
            # Round rather than truncate: storing e.g. 127.99999 into a uint8
            # array would otherwise yield 127.
            color = [round(255 * channel) for channel in mpl.colors.to_rgb(color)]
        colors[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
232
233
234
def make_gif(
    tensor: torch.Tensor,
    axis: int,
    duration: float,  # of full gif
    output_path: TypePath,
    loop: int = 0,
    optimize: bool = True,
    rescale: bool = True,
    reverse: bool = False,
) -> None:
    """Save the slices of a 4D tensor along one spatial axis as an animated GIF.

    Args:
        tensor: Tensor with the channels dimension first, followed by three
            spatial dimensions. A single channel produces frames in ``'P'``
            mode; otherwise frames are converted to ``'RGB'``.
        axis: Spatial axis (not counting channels) along which frames are
            taken.
        duration: Duration of the whole animation, in seconds. If this gives
            less than 10 ms per frame, frames are set to 10 ms and a
            ``RuntimeWarning`` is emitted.
        output_path: Path of the output GIF file.
        loop: Passed to Pillow's ``Image.save``.
        optimize: Passed to Pillow's ``Image.save``.
        rescale: If ``True``, intensities are rescaled to [0, 255] with
            :class:`RescaleIntensity` before conversion to bytes.
        reverse: If ``True``, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = 'Please install Pillow to use Image.to_gif(): pip install Pillow'
        raise RuntimeError(message) from e

    rescaler = RescaleIntensity((0, 255))
    if rescale:
        tensor = rescaler(tensor)  # type: ignore[assignment]
    has_one_channel = len(tensor) == 1

    # (C, d0, d1, d2) -> (frames along `axis`, remaining two axes, C)
    permutation = [*np.roll(range(1, 4), -axis), 0]
    tensor = tensor.permute(*permutation)

    if has_one_channel:
        mode = 'P'
        tensor = tensor[..., 0]
    else:
        mode = 'RGB'

    # rotate() turns each frame by this many quarter turns (half a turn for
    # axis 1) and, with its default radiological=True, flips it left-right.
    num_quarter_turns = 2 if axis == 1 else 1
    frames = []
    for plane in tensor.byte().numpy():
        oriented = rotate(plane, n=num_quarter_turns)
        frames.append(ImagePIL.fromarray(oriented).convert(mode))
    if reverse:
        frames.reverse()

    num_frames = len(frames)
    frame_duration_ms = duration / num_frames * 1000
    if frame_duration_ms < 10:
        # Clamp to 10 ms per frame (100 fps) and tell the user how the total
        # duration changed.
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
    )
287
288
289
def make_video(
    tensor: torch.Tensor,
    output_path: TypePath,
    duration: float | None = None,
    frame_rate: float | None = None,
) -> None:
    """Encode a single-channel 3D tensor into an MP4 video with ffmpeg.

    The channel dimension is dropped and the remaining three axes are
    reversed (``.T``), so each frame is a slice along the last axis of the
    input. Frames are piped to ffmpeg as raw ``gray`` video and encoded with
    ``libx264`` using the ``yuv420p`` pixel format.

    NOTE(review): libx264 with yuv420p generally requires even frame width
    and height; odd sizes are expected to make ffmpeg fail (now reported as
    a RuntimeError) — confirm and consider padding.

    Args:
        tensor: Tensor of dtype ``uint8`` whose first dimension (channels) has
            size 1, followed by three spatial dimensions.
        output_path: Output path; must have an ``.mp4`` suffix. An existing
            file is overwritten.
        duration: Total duration of the video in seconds. Mutually exclusive
            with ``frame_rate``.
        frame_rate: Frames per second. Mutually exclusive with ``duration``.

    Raises:
        ValueError: If both or neither of ``duration`` and ``frame_rate`` are
            given, if the tensor has more than one channel, if the suffix is
            not ``.mp4``, or if the tensor dtype is not ``uint8``.
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    # Validate the arguments before requiring ffmpeg or touching the data.
    if duration is None and frame_rate is None:
        message = 'Either duration or frame_rate must be provided.'
        raise ValueError(message)
    if duration is not None and frame_rate is not None:
        message = 'Provide either duration or frame_rate, not both.'
        raise ValueError(message)
    if len(tensor) > 1:
        message = 'Only single-channel tensors are supported for video output for now.'
        raise ValueError(message)
    output_path = Path(output_path)
    if output_path.suffix.lower() != '.mp4':
        message = 'Only .mp4 files are supported for video output.'
        raise ValueError(message)

    # detach/cpu so tensors that require grad or live on a GPU also work
    frames = tensor.detach().cpu().numpy()[0].T
    # pix_fmt='gray' reads one byte per pixel; any other dtype would be
    # misinterpreted and produce a corrupted video.
    if frames.dtype != np.uint8:
        message = f'Only uint8 tensors are supported for video output, got {frames.dtype}.'
        raise ValueError(message)
    num_frames = len(frames)
    if duration is not None:
        frame_rate = num_frames / duration

    ffmpeg = get_ffmpeg()
    height, width = frames[0].shape

    process = (
        ffmpeg.input(
            'pipe:',
            format='rawvideo',
            pix_fmt='gray',
            s=f'{width}x{height}',
            framerate=frame_rate,
        )
        .output(
            str(output_path),
            vcodec='libx264',
            pix_fmt='yuv420p',
        )
        .overwrite_output()
        .run_async(pipe_stdin=True)
    )

    for frame in frames:
        process.stdin.write(frame.tobytes())

    process.stdin.close()
    return_code = process.wait()
    if return_code != 0:
        # Previously a failed encode was silently ignored.
        message = f'ffmpeg exited with code {return_code} while writing {output_path}'
        raise RuntimeError(message)
343