torchio.visualization   F
last analyzed

Complexity

Total Complexity 78

Size/Duplication

Total Lines 468
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 364
dl 0
loc 468
rs 2.16
c 0
b 0
f 0
wmc 78

10 Functions

Rating   Name   Duplication   Size   Complexity  
A rotate() 0 6 2
A import_mpl_plt() 0 7 2
A _create_categorical_colormap() 0 23 2
A color_labels() 0 13 4
A get_num_bins() 0 16 1
F plot_volume() 0 126 23
D plot_subject() 0 48 13
F make_video() 0 117 21
A plot_histogram() 0 9 3
B make_gif() 0 52 7

How to fix   Complexity   

Complexity

Complex modules like torchio.visualization often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to find such a component is to look for functions or fields that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from __future__ import annotations
2
3
import warnings
4
from itertools import cycle
5
from pathlib import Path
6
from typing import TYPE_CHECKING
7
from typing import Any
8
9
import numpy as np
10
import torch
11
from einops import rearrange
12
13
from .data.image import Image
14
from .data.image import LabelMap
15
from .data.image import ScalarImage
16
from .data.subject import Subject
17
from .external.imports import get_ffmpeg
18
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
19
from .transforms.preprocessing.intensity.to import To
20
from .transforms.preprocessing.spatial.ensure_shape_multiple import EnsureShapeMultiple
21
from .transforms.preprocessing.spatial.resample import Resample
22
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
23
from .transforms.preprocessing.spatial.to_orientation import ToOrientation
24
from .types import TypePath
25
26
if TYPE_CHECKING:
27
    from matplotlib.colors import BoundaryNorm
28
    from matplotlib.colors import ListedColormap
29
    from matplotlib.figure import Figure
30
31
32
def import_mpl_plt():
    """Import matplotlib lazily so plotting stays an optional dependency.

    Returns:
        A ``(matplotlib, matplotlib.pyplot)`` tuple.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib as mpl
        import matplotlib.pyplot as plt
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    else:
        return mpl, plt
39
40
41
def rotate(image: np.ndarray, *, radiological: bool = True, n: int = -1) -> np.ndarray:
    """Rotate an array in its first two axes for display.

    Args:
        image: Array with at least two dimensions; only axes 0 and 1 are rotated.
        radiological: If True, the rotated array is also flipped left-right.
        n: Number of 90-degree rotations passed to ``np.rot90`` (negative
            values rotate clockwise).

    Returns:
        The rotated (and optionally flipped) array.
    """
    rotated = np.rot90(image, k=n, axes=(0, 1))
    return np.fliplr(rotated) if radiological else rotated
47
48
49
def _create_categorical_colormap(
    data: torch.Tensor,
    cmap_name: str = 'glasbey_category10',
) -> tuple[ListedColormap, BoundaryNorm]:
    """Build a discrete colormap and matching norm for integer label values.

    Label 0 is drawn in black and label 1 in white. Labels 2 and above take
    successive colors from the colorcet colormap ``cmap_name``, cycling
    through its colors if there are more labels than colors.

    Args:
        data: Label values; its maximum (cast to ``int``) is taken as the
            highest label.
        cmap_name: Name of a colormap attribute of ``colorcet.cm``.

    Returns:
        A ``ListedColormap`` with one color per label and a ``BoundaryNorm``
        whose bins are centered on the integers ``0..max(data)``.
    """
    highest_label = int(data.max())
    mpl, _ = import_mpl_plt()

    palette = [
        (0, 0, 0),  # background
        (1, 1, 1),  # label 1
    ]
    num_extra_colors = highest_label - 1
    if num_extra_colors > 0:
        # colorcet is only needed when there are labels beyond 0 and 1
        from .external.imports import get_colorcet

        source_colors = getattr(get_colorcet().cm, cmap_name).colors
        color_cycle = cycle(source_colors)
        palette.extend([next(color_cycle) for _ in range(num_extra_colors)])

    colormap = mpl.colors.ListedColormap(palette)
    # Bin edges at half-integers so that each integer label maps to one color
    boundaries = np.arange(highest_label + 2) - 0.5
    return colormap, mpl.colors.BoundaryNorm(boundaries, ncolors=colormap.N)
72
73
74
def plot_volume(
    image: Image,
    radiological=True,
    channel=None,
    axes=None,
    cmap=None,
    output_path=None,
    show=True,
    xlabels=True,
    percentiles: tuple[float, float] = (0, 100),
    figsize=None,
    title=None,
    reorient=True,
    indices=None,
    rgb=True,
    savefig_kwargs: dict[str, Any] | None = None,
    **imshow_kwargs,
) -> Figure | None:
    """Plot sagittal, coronal and axial slices of an image.

    Args:
        image: Image to plot. ``LabelMap`` instances are shown without
            interpolation and, by default, with a categorical colormap.
        radiological: Passed to ``rotate()``; if True, slices are also
            flipped left-right.
        channel: Channel to show for non-label images that are not displayed
            as RGB. If None, the first channel is used.
        axes: Three matplotlib axes (sagittal, coronal, axial). If None, a
            new figure with one row of three subplots is created.
        cmap: Matplotlib colormap, or a dict mapping label values to colors
            (names or RGB triplets in [0, 255]) that is applied with
            ``color_labels()``. If None, ``'gray'`` is used for scalar images
            and a categorical colormap for label maps.
        output_path: If given, the figure is saved to this path. When
            ``axes`` is given, the figure that owns them is saved.
        show: Whether to call ``plt.show()``.
        xlabels: Whether to set the x-axis labels.
        percentiles: Lower and upper percentiles of the displayed intensities
            used as default ``vmin``/``vmax`` (non-label images only).
        figsize: Figure size, used only when ``axes`` is None.
        title: Optional figure title set with ``plt.suptitle``.
        reorient: If True, ``ToCanonical`` is applied before plotting.
        indices: Voxel indices ``(i, j, k)`` of the slices. Defaults to the
            center of the volume.
        rgb: If True and the image has 3 channels, they are shown as RGB.
        savefig_kwargs: Extra keyword arguments for ``savefig``.
        **imshow_kwargs: Extra keyword arguments for ``Axes.imshow``.

    Returns:
        The figure created by this function, or None if ``axes`` was given.
    """
    _, plt = import_mpl_plt()
    # Local variable annotations are never evaluated at runtime, so Figure
    # only needs to be importable under TYPE_CHECKING.
    fig: Figure | None = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)

    if reorient:
        image = ToCanonical()(image)  # type: ignore[assignment]

    is_label = isinstance(image, LabelMap)
    if is_label:
        # Only the last channel of a label map is displayed
        data = image.data[np.newaxis, -1]
    elif rgb and image.num_channels == 3:
        data = image.data  # keep image as it is
    elif channel is None:
        data = image.data[0:1]  # just use the first channel
    else:
        data = image.data[np.newaxis, channel]
    data = rearrange(data, 'c x y z -> x y z c')
    data_numpy: np.ndarray = data.cpu().numpy()

    if indices is None:
        indices = np.array(data_numpy.shape[:3]) // 2
    i, j, k = indices
    slice_x = rotate(data_numpy[i, :, :], radiological=radiological)
    slice_y = rotate(data_numpy[:, j, :], radiological=radiological)
    slice_z = rotate(data_numpy[:, :, k], radiological=radiological)

    if isinstance(cmap, dict):
        # Slices become RGB arrays, so no colormap is passed to imshow
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        boundary_norm = None
        if cmap is None:
            if is_label:
                cmap, boundary_norm = _create_categorical_colormap(data)
            else:
                cmap = 'gray'
        imshow_kwargs['cmap'] = cmap
        imshow_kwargs['norm'] = boundary_norm

    if is_label:
        # Interpolation would blend distinct label values
        imshow_kwargs['interpolation'] = 'none'
    else:
        imshow_kwargs.setdefault('interpolation', 'bicubic')
    imshow_kwargs['origin'] = 'lower'

    if not is_label:
        displayed_data = np.concatenate(
            [
                slice_x.flatten(),
                slice_y.flatten(),
                slice_z.flatten(),
            ]
        )
        p1, p2 = np.percentile(displayed_data, percentiles)
        imshow_kwargs.setdefault('vmin', p1)
        imshow_kwargs.setdefault('vmax', p2)

    spacing_r, spacing_a, spacing_s = image.spacing
    sag_axis, cor_axis, axi_axis = axes
    slices_dict = {
        'Sagittal': {
            'aspect': spacing_s / spacing_a,
            'slice': slice_x,
            'xlabel': 'A',
            'ylabel': 'S',
            'axis': sag_axis,
        },
        'Coronal': {
            'aspect': spacing_s / spacing_r,
            'slice': slice_y,
            'xlabel': 'R',
            'ylabel': 'S',
            'axis': cor_axis,
        },
        'Axial': {
            'aspect': spacing_a / spacing_r,
            'slice': slice_z,
            'xlabel': 'R',
            'ylabel': 'A',
            'axis': axi_axis,
        },
    }

    for axis_title, info in slices_dict.items():
        axis = info['axis']
        axis.imshow(info['slice'], aspect=info['aspect'], **imshow_kwargs)
        if xlabels:
            axis.set_xlabel(info['xlabel'])
        axis.set_ylabel(info['ylabel'])
        axis.invert_xaxis()
        axis.set_title(axis_title)

    plt.tight_layout()
    if title is not None:
        plt.suptitle(title)

    if output_path is not None:
        # When the caller passed axes, fig is None: save the figure that owns
        # them instead of silently ignoring output_path.
        figure_to_save = fig if fig is not None else sag_axis.figure
        figure_to_save.savefig(output_path, **(savefig_kwargs or {}))
    if show:
        plt.show()
    return fig
200
201
202
def plot_subject(
    subject: Subject,
    cmap_dict=None,
    show=True,
    output_path=None,
    figsize=None,
    clear_axes=True,
    **plot_volume_kwargs,
):
    """Plot sagittal, coronal and axial views of every image in a subject.

    With more than two images, each image gets one column of a 3 x N grid;
    otherwise each image gets one row of an N x 3 grid.

    Args:
        subject: Subject whose images (intensity and label) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to ``plot_volume()`` for that image.
        show: Whether to call ``plt.show()``.
        output_path: If given, the figure is saved to this path.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If True and all images share the same spatial shape,
            subplot axes are shared to reduce tick clutter.
        **plot_volume_kwargs: Extra keyword arguments for ``plot_volume()``.

    Raises:
        ValueError: If the subject contains no images.
    """
    _, plt = import_mpl_plt()
    # Count and iterate over the same collection so the grid always matches
    # the images that are actually plotted
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)
    if num_images == 0:
        raise ValueError('The subject does not contain any images to plot.')
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject: do not share axes
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    axes_names = 'sagittal', 'coronal', 'axial'
    last_index = len(axes) - 1
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = cmap_dict.get(name) if cmap_dict is not None else None
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=image_index == last_index,
            **plot_volume_kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
250
251
252
def get_num_bins(x: np.ndarray) -> int:
    """Get the optimal number of bins for a histogram.

    This method uses the Freedman–Diaconis rule to compute the histogram that
    minimizes "the integral of the squared difference between the histogram
    (i.e., relative frequency density) and the density of the theoretical
    probability distribution" (`Wikipedia <https://en.wikipedia.org/wiki/Freedman%E2%80%93Diaconis_rule>`_).

    If the interquartile range is zero (e.g., an image dominated by a constant
    background), the Freedman–Diaconis bin width is zero, so Sturges' rule is
    used instead. Constant input yields a single bin.

    Args:
        x: Input values. Arrays of any shape are flattened.

    Returns:
        Number of bins, always at least 1.
    """
    values = np.asarray(x).ravel()
    num_values = values.size
    value_range = values.max() - values.min()
    if value_range == 0:
        return 1
    # Freedman–Diaconis bin width
    q25, q75 = np.percentile(values, [25, 75])
    bin_width = 2 * (q75 - q25) * num_values ** (-1 / 3)
    if bin_width > 0:
        bins = round(float(value_range / bin_width))
    else:
        # Sturges' rule: fallback when the interquartile range is zero
        bins = int(np.ceil(np.log2(num_values))) + 1
    return max(1, int(bins))
268
269
270
def plot_histogram(x: np.ndarray, show=True, **kwargs) -> None:
    """Plot a histogram of intensity values with matplotlib.

    Args:
        x: Intensity values.
        show: Whether to call ``plt.show()`` after plotting.
        **kwargs: Extra keyword arguments for ``plt.hist``. If ``bins`` is not
            given, it is computed with ``get_num_bins()``. If ``density`` is
            truthy, the y axis is labeled ``'Density'`` instead of
            ``'Frequency'``.
    """
    _, plt = import_mpl_plt()
    # Let callers override the automatic bin count (previously a duplicate
    # keyword argument error)
    if 'bins' not in kwargs:
        kwargs['bins'] = get_num_bins(x)
    plt.hist(x, **kwargs)
    plt.xlabel('Intensity')
    ylabel = 'Density' if kwargs.get('density', False) else 'Frequency'
    plt.ylabel(ylabel)
    if show:
        plt.show()
279
280
281
def color_labels(arrays, cmap_dict):
    """Convert label slices to RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of arrays of shape ``(I, J, C)``; only channel 0 is
            read as the label value.
        cmap_dict: Mapping from label value to color. A color is either a
            matplotlib color name, converted to an RGB triplet in [0, 255],
            or an RGB triplet in [0, 255].

    Returns:
        List of ``uint8`` arrays of shape ``(I, J, 3)``. Pixels whose label is
        not in ``cmap_dict`` are black.
    """
    # Resolve colors once instead of once per slice (loop-invariant work)
    mpl = None
    rgb_colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        rgb_colors[label] = color

    results = []
    for slice_array in arrays:
        si, sj, _ = slice_array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        label_slice = slice_array[..., 0]
        for label, color in rgb_colors.items():
            rgb[label_slice == label] = color
        results.append(rgb)
    return results
294
295
296
def make_gif(
    tensor: torch.Tensor,
    axis: int,
    duration: float,  # of full gif
    output_path: TypePath,
    loop: int = 0,
    optimize: bool = True,
    rescale: bool = True,
    reverse: bool = False,
) -> None:
    """Save a channels-first 4D tensor as an animated GIF.

    Each frame is one slice along the chosen spatial axis.

    Args:
        tensor: 4D tensor with channels first. One channel produces a
            palette (``'P'``) GIF; otherwise frames are converted to RGB
            (presumably 3 channels are expected — confirm).
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Total duration of the GIF in seconds.
        output_path: Destination file.
        loop: Passed to Pillow's ``save`` (0 loops forever).
        optimize: Passed to Pillow's ``save``.
        rescale: If True, intensities are rescaled to [0, 255] first.
        reverse: If True, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = 'Please install Pillow to use Image.to_gif(): pip install Pillow'
        raise RuntimeError(message) from e
    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)  # type: ignore[assignment]
    is_single_channel = len(tensor) == 1

    # Bring the selected spatial axis first and the channels last
    spatial_order = np.roll(range(1, 4), -axis)
    frames = tensor.permute(*spatial_order, 0)

    if is_single_channel:
        frames = frames[..., 0]
        mode = 'P'
    else:
        mode = 'RGB'
    frame_arrays = frames.byte().numpy()
    quarter_turns = 2 if axis == 1 else 1
    images = [
        ImagePIL.fromarray(rotate(frame, n=quarter_turns)).convert(mode)
        for frame in frame_arrays
    ]
    num_images = len(images)
    if reverse:
        images = images[::-1]

    frame_duration_ms = duration / num_images * 1000
    if frame_duration_ms < 10:
        # GIF frame delays are limited to 10 ms (100 fps)
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_images / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)

    first_image = images[0]
    first_image.save(
        output_path,
        save_all=True,
        append_images=images[1:],
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
349
350
351
def make_video(
    image: ScalarImage,
    output_path: TypePath,
    seconds: float | None = None,
    frame_rate: float | None = None,
    direction: str = 'I',
    verbosity: str = 'error',
) -> None:
    """Encode a single-channel scalar image as an H.265 MP4 video.

    The image is rescaled/cast to ``uint8`` if needed, reoriented according to
    ``direction``, resampled so that frames have square pixels, and cropped
    to even height and width before being piped to ffmpeg.

    Args:
        image: Single-channel scalar image.
        output_path: Output file; only the ``.mp4`` suffix is supported.
        seconds: Total video duration. Mutually exclusive with
            ``frame_rate``.
        frame_rate: Frames per second. Mutually exclusive with ``seconds``.
        direction: Case-insensitive slicing direction: ``'I'``/``'S'``
            (axial, top to bottom / bottom to top), ``'A'``/``'P'`` (coronal,
            back to front / front to back) or ``'R'``/``'L'`` (sagittal, left
            to right / right to left).
        verbosity: Log level passed to ffmpeg and x265.

    Raises:
        ValueError: If both or neither of ``seconds`` and ``frame_rate`` are
            given, if the image has more than one channel, or if
            ``direction`` is invalid.
        NotImplementedError: If ``output_path`` does not end with ``.mp4``.
    """
    ffmpeg = get_ffmpeg()

    # Validate all arguments before doing any expensive preprocessing
    if seconds is None and frame_rate is None:
        message = 'Either seconds or frame_rate must be provided.'
        raise ValueError(message)
    if seconds is not None and frame_rate is not None:
        message = 'Provide either seconds or frame_rate, not both.'
        raise ValueError(message)
    if image.num_channels > 1:
        message = 'Only single-channel tensors are supported for video output for now.'
        raise ValueError(message)
    output_path = Path(output_path)
    if output_path.suffix.lower() != '.mp4':
        message = 'Only .mp4 files are supported for video output.'
        raise NotImplementedError(message)

    # Target orientations chosen so the output looks like in typical
    # visualization software
    orientations = {
        'I': 'IPL',  # axial top to bottom
        'S': 'SPL',  # axial bottom to top
        'A': 'AIL',  # coronal back to front
        'P': 'PIL',  # coronal front to back
        'R': 'RIP',  # sagittal left to right
        'L': 'LIP',  # sagittal right to left
    }
    direction = direction.upper()
    try:
        target = orientations[direction]
    except KeyError:
        message = (
            'Direction must be one of "I", "S", "P", "A", "R" or "L".'
            f' Got {direction!r}.'
        )
        raise ValueError(message) from None

    tmin, tmax = image.data.min(), image.data.max()
    if tmin < 0 or tmax > 255:
        message = (
            'The tensor must be in the range [0, 256) for video output.'
            ' The image data will be rescaled to this range.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = RescaleIntensity((0, 255))(image)
    if image.data.dtype != torch.uint8:
        message = (
            'Only uint8 tensors are supported for video output. The image data'
            ' will be cast to uint8.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = To(torch.uint8)(image)

    image = ToOrientation(target)(image)

    # Frames need square pixels: resample height and width to the finer one
    spacing_f, spacing_h, spacing_w = image.spacing
    if spacing_h != spacing_w:
        spacing_iso = min(spacing_h, spacing_w)
        message = (
            'The height and width spacings should be the same for video output.'
            f' Got {spacing_h:.2f} and {spacing_w:.2f}.'
            f' Resampling both to {spacing_iso:.2f}.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        target_spacing = spacing_f, spacing_iso, spacing_iso
        image = Resample(target_spacing)(image)  # type: ignore[assignment]

    # Check that height and width are multiples of 2 for H.265 encoding
    num_frames, height, width = image.spatial_shape
    if height % 2 != 0 or width % 2 != 0:
        message = (
            f'The height ({height}) and width ({width}) must be even.'
            ' The image will be cropped to the nearest even number.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = EnsureShapeMultiple((1, 2, 2), method='crop')(image)

    if seconds is not None:
        frame_rate = num_frames / seconds

    frames = image.numpy()[0]
    height, width = frames.shape[1:]

    process = (
        ffmpeg.input(
            'pipe:',
            format='rawvideo',
            pix_fmt='gray',
            s=f'{width}x{height}',
            framerate=frame_rate,
        )
        .output(
            str(output_path),
            vcodec='libx265',
            pix_fmt='yuv420p',
            loglevel=verbosity,
            **{'x265-params': f'log-level={verbosity}'},
        )
        .overwrite_output()
        .run_async(pipe_stdin=True)
    )

    # Always close the pipe and reap the ffmpeg process, even if a write
    # fails (e.g. BrokenPipeError when ffmpeg exits early)
    try:
        for frame in frames:
            process.stdin.write(frame.tobytes())
    finally:
        try:
            process.stdin.close()
        finally:
            process.wait()
468