torchio.visualization   F
last analyzed

Complexity

Total Complexity 76

Size/Duplication

Total Lines 466
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 362
dl 0
loc 466
rs 2.32
c 0
b 0
f 0
wmc 76

10 Functions

Rating   Name   Duplication   Size   Complexity  
A color_labels() 0 13 4
A get_num_bins() 0 16 1
F plot_volume() 0 124 21
D plot_subject() 0 48 13
F make_video() 0 117 21
A rotate() 0 6 2
A plot_histogram() 0 9 3
A import_mpl_plt() 0 7 2
A _create_categorical_colormap() 0 23 2
B make_gif() 0 52 7

How to fix   Complexity   

Complexity

Complex modules like torchio.visualization often do many different things. To break such a module down, identify a cohesive component within it. A common way to find such a component is to look for functions that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from __future__ import annotations
2
3
import warnings
4
from itertools import cycle
5
from pathlib import Path
6
from typing import TYPE_CHECKING
7
from typing import Any
8
9
import numpy as np
10
import torch
11
from einops import rearrange
12
13
from .data.image import Image
14
from .data.image import LabelMap
15
from .data.image import ScalarImage
16
from .data.subject import Subject
17
from .external.imports import get_ffmpeg
18
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
19
from .transforms.preprocessing.intensity.to import To
20
from .transforms.preprocessing.spatial.ensure_shape_multiple import EnsureShapeMultiple
21
from .transforms.preprocessing.spatial.resample import Resample
22
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
23
from .transforms.preprocessing.spatial.to_orientation import ToOrientation
24
from .types import TypePath
25
26
if TYPE_CHECKING:
27
    from matplotlib.colors import BoundaryNorm
28
    from matplotlib.colors import ListedColormap
29
    from matplotlib.figure import Figure
30
31
32
def import_mpl_plt():
    """Import matplotlib on demand so that it stays an optional dependency.

    Returns:
        A ``(matplotlib, matplotlib.pyplot)`` tuple.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib
        from matplotlib import pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    return matplotlib, pyplot
39
40
41
def rotate(image: np.ndarray, *, radiological: bool = True, n: int = -1) -> np.ndarray:
    """Rotate an array for display.

    The array is turned by ``n`` quarter-turns in the plane of its first two
    axes (``np.rot90`` with ``axes=(0, 1)``), so trailing axes such as RGB
    channels are left untouched.

    Args:
        image: Array with at least two dimensions.
        radiological: If True, additionally flip the result left-right.
        n: Number of quarter-turns passed to ``np.rot90``.

    Returns:
        The rotated (and possibly flipped) array.
    """
    rotated = np.rot90(image, k=n, axes=(0, 1))
    return np.fliplr(rotated) if radiological else rotated
47
48
49
def _create_categorical_colormap(
    data: torch.Tensor,
    cmap_name: str = 'glasbey_category10',
) -> tuple[ListedColormap, BoundaryNorm]:
    """Build a discrete colormap and matching norm for integer label data.

    Label 0 is drawn black and label 1 white. Labels 2 and above take
    successive colors from the colorcet colormap ``cmap_name``, cycling
    through its colors if there are more labels than colors.

    Args:
        data: Label tensor. ``int(data.max())`` is taken as the highest
            label value.
        cmap_name: Name of a colormap attribute of ``colorcet.cm``.

    Returns:
        A ``(ListedColormap, BoundaryNorm)`` pair. The norm's boundaries are
        the half-integers from -0.5 to ``max_label + 0.5``, so each integer
        label falls in its own bin.
    """
    highest_label = int(data.max())
    mpl, _ = import_mpl_plt()

    palette = [(0, 0, 0), (1, 1, 1)]  # background: black, class 1: white
    num_extra_colors = highest_label - 1
    if num_extra_colors > 0:
        from .external.imports import get_colorcet

        source_colors = getattr(get_colorcet().cm, cmap_name).colors
        color_iterator = cycle(source_colors)
        palette.extend(next(color_iterator) for _ in range(num_extra_colors))

    # Half-integer edges: -0.5, 0.5, ..., highest_label + 0.5
    edges = np.arange(highest_label + 2) - 0.5
    listed = mpl.colors.ListedColormap(palette)
    norm = mpl.colors.BoundaryNorm(edges, ncolors=listed.N)
    return listed, norm
72
73
74
def _select_plot_data(
    image: Image,
    channel,
    rgb: bool,
    is_label: bool,
) -> torch.Tensor:
    """Pick the channel(s) of ``image`` to display, as an (x, y, z, c) tensor."""
    if is_label:
        # Label maps: keep only the last channel (e.g. probabilistic maps)
        data = image.data[np.newaxis, -1]
    elif rgb and image.num_channels == 3:
        data = image.data  # three channels are shown as RGB
    elif channel is None:
        data = image.data[0:1]  # default to the first channel
    else:
        data = image.data[np.newaxis, channel]
    return rearrange(data, 'c x y z -> x y z c')


def _orthogonal_slices(
    data_numpy: np.ndarray,
    indices,
    radiological: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Extract the three orthogonal slices through ``indices``, rotated for display.

    If ``indices`` is None, the central voxel of the volume is used.
    """
    if indices is None:
        indices = np.array(data_numpy.shape[:3]) // 2
    i, j, k = indices
    return (
        rotate(data_numpy[i, :, :], radiological=radiological),
        rotate(data_numpy[:, j, :], radiological=radiological),
        rotate(data_numpy[:, :, k], radiological=radiological),
    )


def plot_volume(
    image: Image,
    radiological=True,
    channel=None,
    axes=None,
    cmap=None,
    output_path=None,
    show=True,
    xlabels=True,
    percentiles: tuple[float, float] = (0, 100),
    figsize=None,
    title=None,
    reorient=True,
    indices=None,
    rgb=True,
    savefig_kwargs: dict[str, Any] | None = None,
    **imshow_kwargs,
) -> Figure | None:
    """Plot sagittal, coronal and axial slices of an image.

    Args:
        image: Image to plot.
        radiological: If True, slices are flipped left-right after rotation
            (see ``rotate()``).
        channel: Channel to show for multichannel images that are not shown
            as RGB. If None, the first channel is used.
        axes: Three matplotlib axes (sagittal, coronal, axial). If None, a
            new 1x3 figure is created.
        cmap: Colormap passed to ``imshow``, or a dict mapping label values to
            colors, in which case slices are converted to RGB with
            ``color_labels()``. If None, label maps get a categorical
            colormap and other images ``'gray'``.
        output_path: If given, the figure is saved to this path. If ``axes``
            were provided, the figure that owns them is saved.
        show: If True, call ``plt.show()``.
        xlabels: If True, set the x-axis labels.
        percentiles: Percentiles of the displayed intensities used as
            ``vmin``/``vmax`` for non-label images.
        figsize: Passed to ``plt.subplots`` when ``axes`` is None.
        title: Optional figure super-title.
        reorient: If True, apply ``ToCanonical`` before plotting.
        indices: Voxel indices ``(i, j, k)`` of the slices. If None, the
            center of the volume is used.
        rgb: If True and the image has three channels, show it as RGB.
        savefig_kwargs: Extra keyword arguments for ``savefig``.
        **imshow_kwargs: Passed to ``Axes.imshow``.

    Returns:
        The figure created by this function, or None if ``axes`` were given.
    """
    _, plt = import_mpl_plt()
    # Local variable annotations are never evaluated at runtime (PEP 526),
    # so the TYPE_CHECKING-only import of Figure is sufficient. A static
    # analyzer flags this line as possibly undefined; that is a false positive.
    fig: Figure | None = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)

    if reorient:
        image = ToCanonical()(image)  # type: ignore[assignment]

    is_label = isinstance(image, LabelMap)
    data = _select_plot_data(image, channel, rgb, is_label)
    data_numpy: np.ndarray = data.cpu().numpy()
    slice_x, slice_y, slice_z = _orthogonal_slices(data_numpy, indices, radiological)

    if isinstance(cmap, dict):
        slice_x, slice_y, slice_z = color_labels((slice_x, slice_y, slice_z), cmap)
    else:
        boundary_norm = None
        if cmap is None:
            if is_label:
                cmap, boundary_norm = _create_categorical_colormap(data)
            else:
                cmap = 'gray'
        imshow_kwargs['cmap'] = cmap
        imshow_kwargs['norm'] = boundary_norm

    if is_label:
        # Never blend neighboring labels into non-existent classes
        imshow_kwargs['interpolation'] = 'none'
    else:
        imshow_kwargs.setdefault('interpolation', 'bicubic')
    imshow_kwargs['origin'] = 'lower'

    if not is_label:
        displayed_data = np.concatenate(
            [view.flatten() for view in (slice_x, slice_y, slice_z)]
        )
        p1, p2 = np.percentile(displayed_data, percentiles)
        imshow_kwargs['vmin'] = p1
        imshow_kwargs['vmax'] = p2

    spacing_r, spacing_a, spacing_s = image.spacing
    sag_axis, cor_axis, axi_axis = axes
    # (title, axis, slice, aspect ratio, x label, y label) for each view
    views = (
        ('Sagittal', sag_axis, slice_x, spacing_s / spacing_a, 'A', 'S'),
        ('Coronal', cor_axis, slice_y, spacing_s / spacing_r, 'R', 'S'),
        ('Axial', axi_axis, slice_z, spacing_a / spacing_r, 'R', 'A'),
    )
    for axis_title, axis, slice_array, aspect, xlabel, ylabel in views:
        axis.imshow(slice_array, aspect=aspect, **imshow_kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.set_title(axis_title)

    plt.tight_layout()
    if title is not None:
        plt.suptitle(title)

    if output_path is not None:
        # When axes were passed in, no figure was created here; save the
        # figure owning them instead of silently ignoring output_path
        figure_to_save = fig if fig is not None else sag_axis.figure
        figure_to_save.savefig(output_path, **(savefig_kwargs or {}))
    if show:
        plt.show()
    return fig
198
199
200
def plot_subject(
    subject: Subject,
    cmap_dict=None,
    show=True,
    output_path=None,
    figsize=None,
    clear_axes=True,
    **plot_volume_kwargs,
):
    """Plot every image of a subject with ``plot_volume()`` in one figure.

    With more than two images, the grid has three rows (one per view) and
    one column per image. Otherwise it has one row per image and three
    columns.

    Args:
        subject: Subject whose images are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` passed to
            ``plot_volume()`` for that image.
        show: If True, call ``plt.show()``.
        output_path: If given, the figure is saved to this path.
        figsize: Passed to ``plt.subplots``.
        clear_axes: If True and all images share the same spatial shape
            (``check_consistent_spatial_shape`` does not raise), axes showing
            the same view share their x and y limits.
        **plot_volume_kwargs: Extra keyword arguments for ``plot_volume()``.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If the subject contains no images.
    """
    _, plt = import_mpl_plt()
    # Size the grid from the same dict that is iterated below, so the number
    # of axes rows always matches the number of plotted images
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)
    if num_images == 0:
        raise ValueError('The subject does not contain any images to plot')
    many_images = num_images > 2

    subplots_kwargs = {'figsize': figsize}
    try:
        if clear_axes:
            subject.check_consistent_spatial_shape()
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share
    except RuntimeError:  # different shapes in subject: do not share axes
        pass

    grid_shape = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*grid_shape, **subplots_kwargs)
    # The array of axes must be 2D (one row of three axes per image) so that
    # it can be indexed correctly within the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)

    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **plot_volume_kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    return fig
248
249
250
def get_num_bins(x: np.ndarray) -> int:
    """Get the optimal number of bins for a histogram.

    This method uses the Freedman–Diaconis rule to compute the histogram that
    minimizes "the integral of the squared difference between the histogram
    (i.e., relative frequency density) and the density of the theoretical
    probability distribution" (`Wikipedia <https://en.wikipedia.org/wiki/Freedman%E2%80%93Diaconis_rule>`_).

    The rule is undefined when the 25th and 75th percentiles coincide (for
    instance, constant input, or data dominated by a single value), because
    the bin width is then zero. In that case Sturges' rule,
    ``ceil(log2(n)) + 1``, is used instead.

    Args:
        x: Non-empty array of input values; any number of dimensions is
            accepted and all values are used.

    Returns:
        A positive number of bins.
    """
    values = np.asarray(x)
    num_values = values.size
    # Freedman–Diaconis bin width
    q25, q75 = np.percentile(values, [25, 75])
    bin_width = 2 * (q75 - q25) * num_values ** (-1 / 3)
    if not bin_width > 0:  # zero IQR (also catches NaN)
        return int(np.ceil(np.log2(num_values))) + 1
    bins = round((values.max() - values.min()) / bin_width)
    return max(1, int(bins))
266
267
268
def plot_histogram(x: np.ndarray, show=True, **kwargs) -> None:
    """Plot a histogram of intensity values on the current matplotlib axes.

    Args:
        x: Intensity values.
        show: If True, call ``plt.show()`` after plotting.
        **kwargs: Passed to ``plt.hist``. If ``bins`` is not given, it is
            computed with ``get_num_bins()``. If ``density`` is truthy, the
            y-axis is labeled "Density" instead of "Frequency".
    """
    _, plt = import_mpl_plt()
    # Let callers override the bin count instead of failing with a
    # duplicate-keyword TypeError
    if 'bins' not in kwargs:
        kwargs['bins'] = get_num_bins(x)
    plt.hist(x, **kwargs)
    plt.xlabel('Intensity')
    ylabel = 'Density' if kwargs.get('density', False) else 'Frequency'
    plt.ylabel(ylabel)
    if show:
        plt.show()
277
278
279
def color_labels(arrays, cmap_dict):
    """Convert label slices into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of 3D arrays; only channel 0 (``array[..., 0]``) is
            read as the label image.
        cmap_dict: Mapping from label value to color. String colors are
            converted with ``matplotlib.colors.to_rgb`` and scaled to 0-255;
            other values are assigned to the output as given.

    Returns:
        A list of ``uint8`` arrays of shape ``(H, W, 3)``, one per input.
        Pixels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve colors once instead of once per slice; matplotlib is only
    # imported if a string color needs converting
    color_table = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            color = [255 * component for component in mpl.colors.to_rgb(color)]
        color_table[label] = color

    results = []
    for slice_array in arrays:
        si, sj, _ = slice_array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        labels = slice_array[..., 0]
        for label, color in color_table.items():
            rgb[labels == label] = color
        results.append(rgb)
    return results
292
293
294
def make_gif(
    tensor: torch.Tensor,
    axis: int,
    duration: float,  # of full gif
    output_path: TypePath,
    loop: int = 0,
    optimize: bool = True,
    rescale: bool = True,
    reverse: bool = False,
) -> None:
    """Save the slices of a 4D tensor along one spatial axis as an animated GIF.

    Args:
        tensor: Channels-first 4D tensor. A single channel produces a
            palette ('P') GIF; otherwise frames are converted to 'RGB'
            (presumably three channels are expected — confirm with callers).
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Total duration of the animation, in seconds.
        output_path: Destination file.
        loop: Passed to Pillow's ``save`` as ``loop``.
        optimize: Passed to Pillow's ``save`` as ``optimize``.
        rescale: If True, intensities are first rescaled to [0, 255] with
            ``RescaleIntensity``.
        reverse: If True, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.

    Warns:
        RuntimeWarning: If the per-frame duration would be below 10 ms; it is
            then clamped to 10 ms, lengthening the animation.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = 'Please install Pillow to use Image.to_gif(): pip install Pillow'
        raise RuntimeError(message) from e

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)  # type: ignore[assignment]
    is_single_channel = len(tensor) == 1

    # Bring the selected spatial axis to the front and channels to the end
    spatial_order = np.roll(range(1, 4), -axis)
    frames_tensor = tensor.permute(*spatial_order, 0)
    if is_single_channel:
        mode = 'P'
        frames_tensor = frames_tensor[..., 0]
    else:
        mode = 'RGB'

    frames_array = frames_tensor.byte().numpy()
    quarter_turns = 2 if axis == 1 else 1
    frames = [
        ImagePIL.fromarray(rotate(frame, n=quarter_turns)).convert(mode)
        for frame in frames_array
    ]
    num_frames = len(frames)
    if reverse:
        frames.reverse()

    frame_duration_ms = duration / num_frames * 1000
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
347
348
349
# Target orientation codes passed to ToOrientation, keyed by the `direction`
# argument of make_video()
_VIDEO_TARGET_ORIENTATIONS = {
    'I': 'IPL',  # axial top to bottom
    'S': 'SPL',  # axial bottom to top
    'A': 'AIL',  # coronal back to front
    'P': 'PIL',  # coronal front to back
    'R': 'RIP',  # sagittal left to right
    'L': 'LIP',  # sagittal right to left
}


def make_video(
    image: ScalarImage,
    output_path: TypePath,
    seconds: float | None = None,
    frame_rate: float | None = None,
    direction: str = 'I',
    verbosity: str = 'error',
) -> None:
    """Encode a single-channel image as an MP4 video, one frame per slice.

    Frames are piped to ffmpeg as raw grayscale and encoded with libx265.

    Args:
        image: Single-channel image to encode.
        output_path: Destination path; must have a ``.mp4`` suffix.
        seconds: Total video duration. Mutually exclusive with
            ``frame_rate``; must be positive.
        frame_rate: Frames per second. Mutually exclusive with ``seconds``;
            must be positive.
        direction: Case-insensitive slicing direction: ``'I'``/``'S'``
            (axial, top to bottom / bottom to top), ``'A'``/``'P'`` (coronal,
            back to front / front to back) or ``'R'``/``'L'`` (sagittal, left
            to right / right to left).
        verbosity: ffmpeg log level, also used as the x265 log level.

    Raises:
        ValueError: If the timing arguments, channel count or direction are
            invalid.
        NotImplementedError: If the output suffix is not ``.mp4``.

    Warns:
        RuntimeWarning: When the data are rescaled, cast to uint8, resampled
            to isotropic in-plane spacing, or cropped to even dimensions.
    """
    ffmpeg = get_ffmpeg()

    # Validate all cheap arguments before doing any expensive processing
    if seconds is None and frame_rate is None:
        message = 'Either seconds or frame_rate must be provided.'
        raise ValueError(message)
    if seconds is not None and frame_rate is not None:
        message = 'Provide either seconds or frame_rate, not both.'
        raise ValueError(message)
    if seconds is not None and seconds <= 0:
        message = f'The duration in seconds must be positive. Got {seconds}.'
        raise ValueError(message)
    if frame_rate is not None and frame_rate <= 0:
        message = f'The frame rate must be positive. Got {frame_rate}.'
        raise ValueError(message)
    if image.num_channels > 1:
        message = 'Only single-channel tensors are supported for video output for now.'
        raise ValueError(message)
    output_path = Path(output_path)
    if output_path.suffix.lower() != '.mp4':
        message = 'Only .mp4 files are supported for video output.'
        raise NotImplementedError(message)
    direction = direction.upper()
    try:
        target = _VIDEO_TARGET_ORIENTATIONS[direction]
    except KeyError:
        message = (
            'Direction must be one of "I", "S", "P", "A", "R" or "L".'
            f' Got {direction!r}.'
        )
        raise ValueError(message) from None

    tmin, tmax = image.data.min(), image.data.max()
    if tmin < 0 or tmax > 255:
        message = (
            'The tensor must be in the range [0, 256) for video output.'
            ' The image data will be rescaled to this range.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = RescaleIntensity((0, 255))(image)
    if image.data.dtype != torch.uint8:
        message = (
            'Only uint8 tensors are supported for video output. The image data'
            ' will be cast to uint8.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = To(torch.uint8)(image)

    # Reorient so the output looks like in typical visualization software
    image = ToOrientation(target)(image)

    # Check in-plane isotropy; resample to the finer of the two spacings
    spacing_f, spacing_h, spacing_w = image.spacing
    if spacing_h != spacing_w:
        spacing_iso = min(spacing_h, spacing_w)
        message = (
            'The height and width spacings should be the same for video output.'
            f' Got {spacing_h:.2f} and {spacing_w:.2f}.'
            f' Resampling both to {spacing_iso:.2f}.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        target_spacing = spacing_f, spacing_iso, spacing_iso
        image = Resample(target_spacing)(image)  # type: ignore[assignment]

    # Check that height and width are multiples of 2 for H.265 encoding
    num_frames, height, width = image.spatial_shape
    if height % 2 != 0 or width % 2 != 0:
        message = (
            f'The height ({height}) and width ({width}) must be even.'
            ' The image will be cropped to the nearest even number.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = EnsureShapeMultiple((1, 2, 2), method='crop')(image)

    if seconds is not None:
        frame_rate = num_frames / seconds

    frames = image.numpy()[0]
    height, width = frames[0].shape

    process = (
        ffmpeg.input(
            'pipe:',
            format='rawvideo',
            pix_fmt='gray',
            s=f'{width}x{height}',
            framerate=frame_rate,
        )
        .output(
            str(output_path),
            vcodec='libx265',
            pix_fmt='yuv420p',
            loglevel=verbosity,
            **{'x265-params': f'log-level={verbosity}'},
        )
        .overwrite_output()
        .run_async(pipe_stdin=True)
    )

    # Always close the pipe and reap the ffmpeg process, even if a write fails
    try:
        for frame in frames:
            process.stdin.write(frame.tobytes())
    finally:
        process.stdin.close()
        process.wait()
466