Passed
Push — main ( 287682...fc78a5 )
by Fernando
01:27
created

torchio.visualization.get_num_bins()   A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 1
dl 0
loc 16
rs 10
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import warnings
4
from pathlib import Path
5
from typing import TYPE_CHECKING
6
7
import numpy as np
8
import torch
9
10
from .data.image import Image
11
from .data.image import LabelMap
12
from .data.image import ScalarImage
13
from .data.subject import Subject
14
from .external.imports import get_ffmpeg
15
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
16
from .transforms.preprocessing.intensity.to import To
17
from .transforms.preprocessing.spatial.ensure_shape_multiple import EnsureShapeMultiple
18
from .transforms.preprocessing.spatial.resample import Resample
19
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
20
from .transforms.preprocessing.spatial.to_orientation import ToOrientation
21
from .types import TypePath
22
23
if TYPE_CHECKING:
24
    from matplotlib.colors import ListedColormap
25
26
27
def import_mpl_plt():
    """Import Matplotlib lazily so plotting stays an optional dependency.

    Returns:
        A tuple ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If Matplotlib is not installed.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
34
35
36
def rotate(image, radiological=True, n=-1):
    """Rotate a 2D slice for display.

    The slice is rotated by ``n`` quarter turns with :func:`numpy.rot90`
    and, if ``radiological`` is ``True``, then mirrored left-right.

    Args:
        image: 2D array-like slice.
        radiological: Whether to flip the rotated slice left-right.
        n: Number of 90-degree rotations passed to :func:`numpy.rot90`.

    Returns:
        The rotated (and possibly flipped) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
42
43
44
def _create_categorical_colormap(data: torch.Tensor) -> ListedColormap:
    """Build a discrete colormap for a label map.

    The number of classes is taken as the maximum value in ``data``. The
    returned colormap lists black first (background), followed by one color
    per class: white when there is exactly one class, otherwise colors from
    ``distinctipy.get_colors`` with a fixed ``rng`` seed.

    Args:
        data: Label tensor; its maximum value is used as the class count.

    Returns:
        A Matplotlib ``ListedColormap``.
    """
    num_classes = int(data.max())
    mpl, _ = import_mpl_plt()

    if num_classes != 1:
        from .external.imports import get_distinctipy

        class_colors = get_distinctipy().get_colors(num_classes, rng=0)
    else:
        class_colors = [(1, 1, 1)]  # a single foreground class is drawn white

    background = (0, 0, 0)
    return mpl.colors.ListedColormap([background, *class_colors])
57
58
59
def plot_volume(
    image: Image,
    radiological=True,
    channel=-1,  # default to foreground for binary maps
    axes=None,
    cmap=None,
    output_path=None,
    show=True,
    xlabels=True,
    percentiles: tuple[float, float] = (0.5, 99.5),
    figsize=None,
    title=None,
    reorient=True,
    indices=None,
    **imshow_kwargs,
):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. For a ``LabelMap``, a categorical colormap is
            used by default, ``interpolation='none'`` is set and no
            percentile-based intensity windowing is applied.
        radiological: Passed to :func:`rotate`; flips each slice left-right.
        channel: Index of the channel to plot.
        axes: Three Matplotlib axes (sagittal, coronal, axial). If ``None``,
            a new ``1 x 3`` figure is created.
        cmap: Colormap for ``imshow``, or a dict mapping label values to
            colors, in which case the slices are converted to RGB with
            :func:`color_labels`.
        output_path: Where to save the figure. Only used when this function
            created the figure (i.e., ``axes`` is ``None``).
        show: Whether to call ``plt.show()`` at the end.
        xlabels: Whether to set labels on the horizontal axes.
        percentiles: Percentiles of the channel data used as ``vmin`` and
            ``vmax`` for non-label images; ``None`` disables this.
        figsize: Size of the figure created when ``axes`` is ``None``.
        title: Optional figure title set with ``plt.suptitle``.
        reorient: If ``True``, apply :class:`ToCanonical` before plotting.
        indices: Slice indices ``(i, j, k)``; defaults to the center voxel.
        **imshow_kwargs: Extra keyword arguments for ``Axes.imshow``.

    Returns:
        The created figure, or ``None`` if ``axes`` were provided.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)  # type: ignore[assignment]
    volume = image.data[channel]
    if indices is None:
        indices = np.array(volume.shape) // 2
    i, j, k = indices
    planes = (volume[i, :, :], volume[:, j, :], volume[:, :, k])
    views = [rotate(plane, radiological=radiological) for plane in planes]

    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        # Explicit label -> color mapping: render the slices as RGB images
        views = color_labels(views, cmap)
    else:
        if cmap is None:
            cmap = _create_categorical_colormap(volume) if is_label else 'gray'
        imshow_kwargs['cmap'] = cmap

    if is_label:
        imshow_kwargs['interpolation'] = 'none'  # keep label boundaries crisp

    sr, sa, ss = image.spacing
    imshow_kwargs['origin'] = 'lower'

    if percentiles is not None and not is_label:
        low, high = np.percentile(volume, percentiles)
        imshow_kwargs['vmin'] = low
        imshow_kwargs['vmax'] = high

    # Each panel: axis, slice, (numerator, denominator) of the pixel aspect
    # ratio from the voxel spacing, x label, y label, title
    panels = (
        (sag_axis, views[0], (ss, sa), 'A', 'S', 'Sagittal'),
        (cor_axis, views[1], (ss, sr), 'R', 'S', 'Coronal'),
        (axi_axis, views[2], (sa, sr), 'R', 'A', 'Axial'),
    )
    for axis, view, (numerator, denominator), xlabel, ylabel, name in panels:
        axis.imshow(view, aspect=numerator / denominator, **imshow_kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.set_title(name)

    plt.tight_layout()
    if title is not None:
        plt.suptitle(title)

    if output_path is not None and fig is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    return fig
143
144
145
def plot_subject(
    subject: Subject,
    cmap_dict=None,
    show=True,
    output_path=None,
    figsize=None,
    clear_axes=True,
    **plot_volume_kwargs,
):
    """Plot orthogonal slices of every image in a subject on one figure.

    With one or two images, each image occupies a row of three panels
    (sagittal, coronal, axial). With more images, the grid is transposed so
    that each image occupies a column.

    Args:
        subject: Subject whose images (intensity and label) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image.
        show: Whether to call ``plt.show()`` at the end.
        output_path: If given, the figure is saved to this path.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If ``True`` and all images have the same spatial shape,
            subplots share their x and y axes.
        **plot_volume_kwargs: Extra keyword arguments for :func:`plot_volume`.

    Raises:
        ValueError: If the subject contains no images.
    """
    _, plt = import_mpl_plt()
    # Size the grid from the same collection that is plotted below, so the
    # number of panels always matches the number of images drawn
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)
    if num_images == 0:
        message = 'The subject does not contain any images to plot'
        raise ValueError(message)
    many_images = num_images > 2

    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject: do not share axes
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share

    grid_shape = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*grid_shape, **subplots_kwargs)
    # The array of axes must be 2D, indexed as axes[image_index], so that each
    # entry holds the three axes that plot_volume() expects
    axes = axes.T if many_images else axes.reshape(-1, 3)

    axes_names = 'sagittal', 'coronal', 'axial'
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        # Horizontal-axis labels are only drawn for the last image
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **plot_volume_kwargs,
        )
        for axis, axis_name in zip(image_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
193
194
195
def get_num_bins(x: np.ndarray) -> int:
    """Get the optimal number of bins for a histogram.

    This method uses the Freedman–Diaconis rule to compute the histogram that
    minimizes "the integral of the squared difference between the histogram
    (i.e., relative frequency density) and the density of the theoretical
    probability distribution" (`Wikipedia <https://en.wikipedia.org/wiki/Freedman%E2%80%93Diaconis_rule>`_).

    Args:
        x: Input values. Multidimensional inputs are flattened, so the rule
            uses the total number of values.

    Returns:
        The number of bins, always at least 1. If ``x`` is empty, or the
        interquartile range is zero or not finite (e.g., constant data), a
        single bin is returned instead of dividing by zero.
    """
    values = np.asarray(x).ravel()
    if values.size == 0:
        return 1
    # Freedman–Diaconis bin width: 2 * IQR * n^(-1/3)
    q25, q75 = np.percentile(values, [25, 75])
    bin_width = 2 * (q75 - q25) * values.size ** (-1 / 3)
    value_range = values.max() - values.min()
    if not np.isfinite(bin_width) or bin_width <= 0:
        return 1
    if not np.isfinite(value_range):
        return 1
    # int() because round() on NumPy scalars may return a NumPy float
    return max(1, int(round(value_range / bin_width)))
211
212
213
def plot_histogram(x: np.ndarray, show=True, **kwargs) -> None:
    """Plot a histogram of intensity values with Matplotlib.

    Args:
        x: Values to plot.
        show: Whether to call ``plt.show()`` after plotting.
        **kwargs: Keyword arguments for ``plt.hist``. If ``bins`` is not
            given (or is ``None``), it is computed with :func:`get_num_bins`.
            If ``density`` is truthy, the y axis is labeled ``'Density'``
            instead of ``'Frequency'``.
    """
    _, plt = import_mpl_plt()
    # Honor a caller-supplied ``bins`` instead of clashing with the default
    bins = kwargs.pop('bins', None)
    if bins is None:
        bins = get_num_bins(x)
    density = kwargs.get('density', False)
    plt.hist(x, bins=bins, **kwargs)
    plt.xlabel('Intensity')
    plt.ylabel('Density' if density else 'Frequency')
    if show:
        plt.show()
222
223
224
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color. A color is either a
            3-element RGB sequence written as-is into a ``uint8`` image, or a
            Matplotlib color string, which is converted with
            ``matplotlib.colors.to_rgb`` and scaled by 255.

    Returns:
        A list with one ``uint8`` array of shape ``(rows, cols, 3)`` per
        input array. Pixels whose label is not in ``cmap_dict`` stay black.
    """

    def paint(labels):
        num_rows, num_cols = labels.shape
        canvas = np.zeros((num_rows, num_cols, 3), dtype=np.uint8)
        for value, color in cmap_dict.items():
            if isinstance(color, str):
                mpl, _ = import_mpl_plt()
                color = [255 * component for component in mpl.colors.to_rgb(color)]
            canvas[labels == value] = color
        return canvas

    return [paint(labels) for labels in arrays]
237
238
239
def make_gif(
    tensor: torch.Tensor,
    axis: int,
    duration: float,  # of full gif
    output_path: TypePath,
    loop: int = 0,
    optimize: bool = True,
    rescale: bool = True,
    reverse: bool = False,
) -> None:
    """Save a tensor as an animated GIF with one frame per slice along an axis.

    Args:
        tensor: Channels-first tensor with three spatial dimensions (the
            permutation below assumes exactly four dimensions). A single
            channel produces a palette GIF; otherwise frames are RGB
            (presumably three channels — TODO confirm for other counts).
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Duration of the whole GIF, in seconds.
        output_path: Destination file.
        loop: Number of loops passed to Pillow (``0`` loops forever).
        optimize: Passed to Pillow's ``save``.
        rescale: If ``True``, intensities are rescaled to ``[0, 255]``
            before casting to bytes.
        reverse: If ``True``, frames are played in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
        ValueError: If ``duration`` is not positive or there are no slices
            along ``axis``.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = 'Please install Pillow to use Image.to_gif(): pip install Pillow'
        raise RuntimeError(message) from e
    if duration <= 0:
        # Otherwise the frame duration computation below divides by zero
        message = f'The GIF duration must be positive, but got {duration}'
        raise ValueError(message)

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)  # type: ignore[assignment]
    # NOTE(review): without rescaling, values outside [0, 255] are not
    # clipped before .byte() below
    single_channel = len(tensor) == 1

    # Move channels dimension to the end and bring selected axis to 0
    axes = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*axes, 0)

    if single_channel:
        mode = 'P'
        tensor = tensor[..., 0]
    else:
        mode = 'RGB'
    array = tensor.byte().numpy()
    # Quarter turns for rotate(), which also flips left-right by default
    n = 2 if axis == 1 else 1
    images = [ImagePIL.fromarray(rotate(i, n=n)).convert(mode) for i in array]
    num_images = len(images)
    if num_images == 0:
        message = 'Cannot create a GIF: there are no slices along the given axis'
        raise ValueError(message)
    images = list(reversed(images)) if reverse else images
    frame_duration_ms = duration / num_images * 1000
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_images / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
    images[0].save(
        output_path,
        save_all=True,
        append_images=images[1:],
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
292
293
294
def make_video(
    image: ScalarImage,
    output_path: TypePath,
    seconds: float | None = None,
    frame_rate: float | None = None,
    direction: str = 'I',
    verbosity: str = 'error',
) -> None:
    """Write a single-channel image to an MP4 file, one frame per slice.

    The image is rescaled to ``[0, 255]`` and cast to ``uint8`` if needed. It
    is then reoriented so that frames advance along ``direction``. It is
    resampled so that frame height and width spacings match, and cropped to
    an even height and width. Finally it is streamed as raw grayscale frames
    to ``ffmpeg`` using the ``libx265`` encoder.

    Args:
        image: Single-channel image to encode.
        output_path: Destination file; must have a ``.mp4`` suffix.
        seconds: Total duration of the video. Mutually exclusive with
            ``frame_rate``.
        frame_rate: Frames per second. Mutually exclusive with ``seconds``.
        direction: Direction in which frames advance; one of ``'I'``,
            ``'S'``, ``'A'``, ``'P'``, ``'R'`` or ``'L'`` (case-insensitive).
        verbosity: Log level passed to ``ffmpeg`` and to the x265 encoder.

    Raises:
        ValueError: If neither or both of ``seconds`` and ``frame_rate`` are
            given, the image has more than one channel, or ``direction`` is
            invalid.
        NotImplementedError: If ``output_path`` is not an ``.mp4`` file.
    """
    ffmpeg = get_ffmpeg()

    if seconds is None and frame_rate is None:
        message = 'Either seconds or frame_rate must be provided.'
        raise ValueError(message)
    if seconds is not None and frame_rate is not None:
        message = 'Provide either seconds or frame_rate, not both.'
        raise ValueError(message)
    if image.num_channels > 1:
        message = 'Only single-channel tensors are supported for video output for now.'
        raise ValueError(message)

    # Validate the remaining arguments before any (potentially slow) image
    # processing, so bad calls fail fast
    output_path = Path(output_path)
    if output_path.suffix.lower() != '.mp4':
        message = 'Only .mp4 files are supported for video output.'
        raise NotImplementedError(message)

    # Target orientations so the output looks like in typical visualization
    # software. The first axis is the one frames are taken along.
    orientations = {
        'I': 'IPL',  # axial top to bottom
        'S': 'SPL',  # axial bottom to top
        'A': 'AIL',  # coronal back to front
        'P': 'PIL',  # coronal front to back
        'R': 'RIP',  # sagittal left to right
        'L': 'LIP',  # sagittal right to left
    }
    direction = direction.upper()
    if direction not in orientations:
        message = (
            'Direction must be one of "I", "S", "P", "A", "R" or "L".'
            f' Got {direction!r}.'
        )
        raise ValueError(message)
    target = orientations[direction]

    tmin, tmax = image.data.min(), image.data.max()
    if tmin < 0 or tmax > 255:
        message = (
            'The tensor must be in the range [0, 256) for video output.'
            ' The image data will be rescaled to this range.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = RescaleIntensity((0, 255))(image)
    if image.data.dtype != torch.uint8:
        message = (
            'Only uint8 tensors are supported for video output. The image data'
            ' will be cast to uint8.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = To(torch.uint8)(image)

    image = ToOrientation(target)(image)

    # Frame pixels must be square: resample height and width to the finer one
    spacing_f, spacing_h, spacing_w = image.spacing
    if spacing_h != spacing_w:
        spacing_iso = min(spacing_h, spacing_w)
        message = (
            'The height and width spacings should be the same for video output.'
            f' Got {spacing_h:.2f} and {spacing_w:.2f}.'
            f' Resampling both to {spacing_iso:.2f}.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        target_spacing = spacing_f, spacing_iso, spacing_iso
        image = Resample(target_spacing)(image)  # type: ignore[assignment]

    # Check that height and width are multiples of 2 for H.265 encoding
    num_frames, height, width = image.spatial_shape
    if height % 2 != 0 or width % 2 != 0:
        message = (
            f'The height ({height}) and width ({width}) must be even.'
            ' The image will be cropped to the nearest even number.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        image = EnsureShapeMultiple((1, 2, 2), method='crop')(image)

    if seconds is not None:
        frame_rate = num_frames / seconds

    frames = image.numpy()[0]
    height, width = frames.shape[1:]

    process = (
        ffmpeg.input(
            'pipe:',
            format='rawvideo',
            pix_fmt='gray',
            s=f'{width}x{height}',
            framerate=frame_rate,
        )
        .output(
            str(output_path),
            vcodec='libx265',
            pix_fmt='yuv420p',
            loglevel=verbosity,
            **{'x265-params': f'log-level={verbosity}'},
        )
        .overwrite_output()
        .run_async(pipe_stdin=True)
    )

    # Always close the pipe and reap the encoder, even if writing fails
    try:
        for frame in frames:
            process.stdin.write(frame.tobytes())
    finally:
        process.stdin.close()
        process.wait()
411