Passed
Push — master ( f0c5a2...027e9c )
by Fernando
01:10
created

torchio.visualization   B

Complexity

Total Complexity 43

Size/Duplication

Total Lines 230
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 43
eloc 192
dl 0
loc 230
rs 8.96
c 0
b 0
f 0

6 Functions

Rating   Name   Duplication   Size   Complexity  
A rotate() 0 6 2
A import_mpl_plt() 0 7 2
A color_labels() 0 13 4
F plot_volume() 0 78 15
D plot_subject() 0 49 13
B make_gif() 0 54 7

How to fix   Complexity   

Complexity

Complex modules like torchio.visualization often do many different things. To break such a module down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for functions or fields that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib lazily so it stays an optional dependency.

    Returns:
        A ``(mpl, plt)`` tuple with the ``matplotlib`` package and its
        ``matplotlib.pyplot`` module.

    Raises:
        ImportError: If matplotlib cannot be imported; the original import
            error is chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate an image slice for display, optionally mirroring it.

    Args:
        image: Array-like with at least two dimensions; the first two are
            rotated (extra trailing dimensions, e.g. RGB, are carried along).
        radiological: If ``True``, the rotated result is flipped left-right
            with :func:`numpy.fliplr`.
        n: Number of quarter turns, as passed to :func:`numpy.rot90`.

    Returns:
        The rotated, and possibly mirrored, numpy array.
    """
    turned = np.rot90(image, n)
    return np.fliplr(turned) if radiological else turned
28
29
30
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot one sagittal, one coronal and one axial slice of an image.

    Args:
        image: TorchIO image to plot. :class:`LabelMap` instances are drawn
            without interpolation and with the full label range.
        radiological: If ``True``, slices are mirrored left-right after
            rotation (see :func:`rotate`).
        channel: Index of the channel to plot. The default, ``-1``, selects
            the last channel (the foreground of binary maps).
        axes: Sequence of three matplotlib axes, used in the order sagittal,
            coronal, axial. If ``None``, a new 1 x 3 figure is created.
        cmap: Matplotlib colormap, or a dict mapping label values to colors,
            in which case slices are converted to RGB with
            :func:`color_labels`. Defaults to ``'cubehelix'`` for label maps
            and ``'gray'`` otherwise.
        output_path: If not ``None`` and the figure was created by this
            function, the figure is saved to this path.
        show: If ``True``, ``plt.show()`` is called at the end.
        xlabels: If ``True``, orientation labels are drawn on the x axes.
        percentiles: Lower and upper percentiles of the plotted channel used
            as display window for intensity images; ``None`` disables the
            window. Not used for label maps.
        figsize: Forwarded to ``plt.subplots`` when ``axes`` is ``None``.
        reorient: If ``True``, the image is first transformed with
            ``ToCanonical``.
        indices: Voxel indices ``(i, j, k)`` of the plotted slices. Defaults
            to the center of the volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    kwargs = {'origin': 'lower'}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap

    if is_label:
        kwargs['interpolation'] = 'none'
        # Percentile windowing is not applied to labels: a label covering
        # fewer voxels than the lower/upper percentile would be clipped to the
        # color of another label (or of the background). The full label range
        # of the whole channel keeps each label's color identical in all views
        kwargs['vmin'] = float(data.min())
        kwargs['vmax'] = float(data.max())
    elif percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    sr, sa, ss = image.spacing
    # Aspect ratios use voxel spacing so that anisotropic images look right
    _plot_slice(
        sag_axis, slice_x, ss / sa,
        'A' if xlabels else None, 'S', 'Sagittal', kwargs,
    )
    _plot_slice(
        cor_axis, slice_y, ss / sr,
        'R' if xlabels else None, 'S', 'Coronal', kwargs,
    )
    _plot_slice(
        axi_axis, slice_z, sa / sr,
        'R' if xlabels else None, 'A', 'Axial', kwargs,
    )

    plt.tight_layout()
    # Only save figures created here; caller-provided axes belong to the caller
    if output_path is not None and fig is not None:
        fig.savefig(output_path)
    if show:
        plt.show()


def _plot_slice(axis, array, aspect, xlabel, ylabel, title, imshow_kwargs):
    """Draw one 2D slice on ``axis`` with its orientation labels and title.

    ``xlabel`` may be ``None`` to leave the x axis unlabeled. The x axis is
    inverted after drawing.
    """
    axis.imshow(array, aspect=aspect, **imshow_kwargs)
    if xlabel is not None:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    axis.set_title(title)
108
109
110
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot sagittal, coronal and axial slices of every image in a subject.

    With at most two images, each image fills a row of three axes. With more,
    each image fills a column, so the figure stays wide rather than tall.

    Args:
        subject: Subject whose images (``intensity_only=False``, so label
            maps included) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            forwarded to :func:`plot_volume` for that image.
        show: If ``True``, ``plt.show()`` is called before the figure is
            closed.
        output_path: If not ``None``, the figure is saved to this path.
        figsize: Forwarded to ``plt.subplots``.
        clear_axes: If ``True`` and all images have the same spatial shape,
            subplots share their x and y axes (per row with more than two
            images, per column otherwise).
        **kwargs: Additional keyword arguments forwarded to
            :func:`plot_volume`.

    Raises:
        ValueError: If the subject contains no images.
    """
    _, plt = import_mpl_plt()
    # Size the grid from the same mapping that is iterated below, so the
    # number of axes rows always matches the number of plotted images
    images = subject.get_images_dict(intensity_only=False)
    num_images = len(images)
    if num_images == 0:
        raise ValueError('The subject does not contain any images to plot')
    many_images = num_images > 2

    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject: do not share axes
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share

    grid_shape = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*grid_shape, **subplots_kwargs)
    try:
        # The array of axes must be 2D (one row of three axes per image) so
        # that it can be indexed correctly within plot_volume()
        axes = axes.T if many_images else axes.reshape(-1, 3)
        axes_names = 'sagittal', 'coronal', 'axial'
        for image_index, (name, image) in enumerate(images.items()):
            image_axes = axes[image_index]
            cmap = None
            if cmap_dict is not None and name in cmap_dict:
                cmap = cmap_dict[name]
            last_row = image_index == num_images - 1
            plot_volume(
                image,
                axes=image_axes,
                show=False,
                cmap=cmap,
                xlabels=last_row,
                **kwargs,
            )
            for axis, axis_name in zip(image_axes, axes_names):
                axis.set_title(f'{name} ({axis_name})')
        plt.tight_layout()
        if output_path is not None:
            fig.savefig(output_path)
        if show:
            plt.show()
    finally:
        # Release the figure even if plotting or saving fails
        plt.close(fig)
159
160
161
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB images using a label-to-color mapping.

    Args:
        arrays: Iterable of numpy arrays of label values.
        cmap_dict: Mapping from label value to color. Colors given as strings
            (any format accepted by ``matplotlib.colors.to_rgb``) are
            converted to 0-255 integer RGB triplets. Any other value is
            assigned to the output pixels as is, so it should already be an
            RGB triplet in the 0-255 range.

    Returns:
        List with one ``uint8`` array of shape ``array.shape + (3,)`` per
        input array. Pixels whose label is not in ``cmap_dict`` are black.
    """
    # Resolve every color once, instead of once per array
    rgb_colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            mpl, _ = import_mpl_plt()
            # Round rather than truncate: 255 * (k / 255) may land just
            # below k, which a uint8 cast would turn into k - 1
            color = tuple(
                round(255 * component)
                for component in mpl.colors.to_rgb(color)
            )
        rgb_colors[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in rgb_colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
174
175
176
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save an animated GIF that sweeps through ``tensor`` along a spatial axis.

    Args:
        tensor: 4D tensor with channels in the first dimension. With a single
            channel, frames are converted to palette (``'P'``) mode; otherwise
            to ``'RGB'`` mode (presumably three channels are expected).
        axis: Spatial axis (0, 1 or 2) traversed to produce the frames.
        duration: Duration of the whole animation, in seconds.
        output_path: Destination of the GIF file.
        loop: Forwarded to Pillow's ``save`` as ``loop``.
        optimize: Forwarded to Pillow's ``save`` as ``optimize``.
        rescale: If ``True``, intensities are first rescaled to [0, 255] with
            ``RescaleIntensity``; otherwise values are cast to bytes as is.
        reverse: If ``True``, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.

    If the resulting frame duration would be shorter than 10 ms, it is set to
    10 ms and a warning reports the new total duration.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as error:
        raise RuntimeError(
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        ) from error

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    is_single_channel = len(tensor) == 1

    # Channels become the last dimension and the requested spatial axis the
    # first one, so iterating over the array yields one frame per slice
    permutation = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*permutation, 0)
    if is_single_channel:
        tensor = tensor[..., 0]
    mode = 'P' if is_single_channel else 'RGB'

    num_rotations = 2 if axis == 1 else 1
    frames = [
        ImagePIL.fromarray(rotate(plane, n=num_rotations)).convert(mode)
        for plane in tensor.byte().numpy()
    ]
    if reverse:
        frames = frames[::-1]
    num_frames = len(frames)

    frame_duration_ms = duration / num_frames * 1000
    if frame_duration_ms < 10:
        # 100 fps (10 ms per frame) is the fastest rate allowed, so clamp the
        # frame duration and tell the user the resulting total duration
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        warnings.warn(
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
231