Passed
Pull Request — master (#332)
by Fernando
01:14
created

torchio.visualization   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 100
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 87
dl 0
loc 100
rs 10
c 0
b 0
f 0
wmc 23

5 Functions

Rating   Name   Duplication   Size   Complexity  
A color_labels() 0 13 4
C plot_volume() 0 39 9
B plot_subject() 0 23 7
A rotate() 0 2 1
A import_mpl_plt() 0 7 2
1
import numpy as np
2
3
from .data.image import Image, LabelMap
4
from .data.subject import Subject
5
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
6
7
8
def import_mpl_plt():
    """Lazily import matplotlib and its pyplot interface.

    matplotlib is an optional dependency, so it is only imported when a
    plotting function is actually called.

    Returns:
        tuple: ``(matplotlib, matplotlib.pyplot)`` modules.

    Raises:
        ImportError: If matplotlib is not installed; the original import
            error is chained as the cause.
    """
    try:
        import matplotlib
        from matplotlib import pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
15
16
17
def rotate(image):
    """Rotate a 2D slice by 90 degrees for display.

    The rotation goes from the first axis towards the second one
    (``numpy.rot90`` with ``k=1``), i.e. counterclockwise when the array is
    viewed as an image.

    Args:
        image: Array-like whose first two dimensions are rotated.

    Returns:
        numpy.ndarray: The rotated array.
    """
    quarter_turns = 1
    return np.rot90(image, quarter_turns, (0, 1))
19
20
21
def plot_volume(
        image: Image,
        channel=0,
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        ):
    """Plot three central orthogonal slices of one channel of an image.

    The image is first reoriented with :class:`ToCanonical`. Each panel shows
    a slice perpendicular to the first, second and third voxel axes of the
    reoriented data, taken at the middle index of that axis. Each slice is
    rotated with :func:`rotate` before display. The plot extents are taken
    from ``image.bounds``.

    Args:
        image: Image to plot. For a :class:`LabelMap`, slices are drawn
            without interpolation, and the default colormap is ``'inferno'``
            instead of ``'gray'``.
        channel: Index into the first dimension of ``image.data`` selecting
            the channel to plot.
        axes: Sequence of at least three matplotlib axes to draw into. If
            ``None``, a new figure with a 1x3 grid of axes is created.
        cmap: A matplotlib colormap used for all slices, or a dict mapping
            label values to colors. A dict causes the slices to be converted
            to RGB with :func:`color_labels`.
        output_path: If not ``None``, the figure containing the axes is saved
            to this path.
        show: Whether to call ``plt.show()`` at the end.
    """
    _, plt = import_mpl_plt()
    if axes is None:
        fig, axes = plt.subplots(1, 3)
    else:
        # Use the figure that owns the supplied axes. This lets output_path
        # be honored, and the layout be adjusted, for caller-provided axes
        # too, not only for figures created here.
        fig = axes[0].figure
    image = ToCanonical()(image)
    data = image.data[channel]
    i, j, k = np.array(data.shape) // 2
    slice_x = rotate(data[i, :, :])
    slice_y = rotate(data[:, j, :])
    slice_z = rotate(data[:, :, k])
    kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        # Explicit label -> color mapping: draw precomputed RGB slices
        # instead of passing a colormap to imshow.
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'inferno' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        # Label values are categorical, so interpolating between neighbors
        # would produce colors that belong to no label.
        kwargs['interpolation'] = 'none'
    x_extent, y_extent, z_extent = [tuple(b) for b in image.bounds.T]
    axes[0].imshow(slice_x, extent=y_extent + z_extent, **kwargs)
    axes[1].imshow(slice_y, extent=x_extent + z_extent, **kwargs)
    axes[2].imshow(slice_z, extent=x_extent + y_extent, **kwargs)
    fig.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
60
61
62
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        ):
    """Plot the central orthogonal slices of every image in a subject.

    The figure has one row per image returned by
    ``subject.get_images_dict(intensity_only=False)``. Each row has three
    columns titled sagittal, coronal and axial, drawn with
    :func:`plot_volume`.

    Args:
        subject: Subject whose images are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image. Images not listed
            use :func:`plot_volume`'s default.
        show: Whether to call ``plt.show()`` at the end.
        output_path: If not ``None``, the figure is saved to this path.
    """
    _, plt = import_mpl_plt()
    images = subject.get_images_dict(intensity_only=False)
    # Size the grid from the images actually plotted; len(subject) is not
    # guaranteed to match that count. squeeze=False keeps `axes` 2D even for
    # a single image, so each row is always a sequence of three axes.
    fig, axes = plt.subplots(len(images), 3, squeeze=False)
    axes_names = 'sagittal', 'coronal', 'axial'
    for row_axes, (name, image) in zip(axes, images.items()):
        cmap = None
        if cmap_dict is not None and name in cmap_dict:
            cmap = cmap_dict[name]
        plot_volume(image, axes=row_axes, show=False, cmap=cmap)
        for axis, axis_name in zip(row_axes, axes_names):
            axis.set_title(f'{name} ({axis_name})')
    fig.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
85
86
87
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images using a label -> color map.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color.
            A string color is converted with ``matplotlib.colors.to_rgb`` and
            scaled to the 0-255 range.
            Any other value is assigned as-is, so it is expected to be an RGB
            triplet already in the 0-255 range.

    Returns:
        list: One ``(H, W, 3)`` ``uint8`` array per input array. Pixels whose
        label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once, instead of once per array.
    colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            mpl, _ = import_mpl_plt()
            # Round rather than truncate when storing into uint8. to_rgb
            # returns c / 255 for hex colors, and 255 * (c / 255) may land a
            # hair below c in floating point, which truncation turns into
            # c - 1.
            color = [round(255 * n) for n in mpl.colors.to_rgb(color)]
        colors[label] = color
    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
100