Passed
Pull Request — master (#682)
by
unknown
01:16
created

torchio.visualization   B

Complexity

Total Complexity 43

Size/Duplication

Total Lines 232
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 43
eloc 193
dl 0
loc 232
rs 8.96
c 0
b 0
f 0

6 Functions

Rating   Name   Duplication   Size   Complexity  
A rotate() 0 6 2
A import_mpl_plt() 0 7 2
A color_labels() 0 13 4
F plot_volume() 0 80 15
D plot_subject() 0 49 13
B make_gif() 0 54 7

How to fix   Complexity   

Complexity

Complex modules like torchio.visualization often do many different things. To break such a module down, identify a cohesive component within it. A common approach is to look for functions or fields that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import Matplotlib lazily and return ``(matplotlib, matplotlib.pyplot)``.

    The import is deferred so that Matplotlib is only required when a
    plotting function is actually used.

    Raises:
        ImportError: If Matplotlib cannot be imported; the original error
            is chained as the cause.
    """
    try:
        import matplotlib
        from matplotlib import pyplot
    except ImportError as error:
        message = 'Install matplotlib for plotting support'
        raise ImportError(message) from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate an array by ``n`` quarter turns for visualization.

    ``np.rot90`` is applied with ``k=n`` (the default ``-1`` turns the
    array clockwise). When ``radiological`` is true, the rotated array is
    also mirrored left-right with ``np.fliplr``.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. If it is a :class:`LabelMap`, ``imshow`` is
            called with ``interpolation='none'`` and the default colormap
            is ``'cubehelix'`` instead of ``'gray'``.
        radiological: Passed to :func:`rotate`; if ``True`` each slice is
            mirrored left-right after rotation.
        channel: Index into the first dimension of ``image.data``.
        axes: Three Matplotlib axes (sagittal, coronal, axial). If ``None``,
            a new figure with a 1 x 3 grid of axes is created.
        cmap: Colormap passed to ``imshow``, or a dict mapping label values
            to colors, in which case slices are converted to RGB with
            :func:`color_labels`.
        output_path: If given and the figure was created by this function,
            the figure is saved there (with a transparent background).
        show: If ``True``, call ``plt.show()`` at the end.
        xlabels: If ``True``, set the x-axis label of each view.
        percentiles: Lower and upper percentiles of the whole channel
            volume, used as ``vmin``/``vmax``. ``None`` disables them.
        figsize: Figure size used when a new figure is created.
        reorient: If ``True``, apply :class:`ToCanonical` before plotting.
        indices: Voxel indices ``(i, j, k)`` of the slices to show.
            Defaults to the center of the volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)
    kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        kwargs['interpolation'] = 'none'

    sr, sa, ss = image.spacing
    kwargs['origin'] = 'lower'

    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    # One entry per view: (axis, slice, aspect ratio from voxel spacing,
    # x label, y label, title)
    views = (
        (sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal'),
        (cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal'),
        (axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial'),
    )
    for axis, slice_, aspect, xlabel, ylabel, title in views:
        axis.imshow(slice_, aspect=aspect, **kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.set_title(title)

    plt.tight_layout()
    # Only save figures created here; with caller-provided axes the caller
    # owns the figure.
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
110
111
112
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot sagittal, coronal and axial slices of every image in a subject.

    With more than two images, the figure is a 3 x N grid (one column per
    image). Otherwise it is an N x 3 grid (one row per image).

    Args:
        subject: Subject whose images are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` passed
            to :func:`plot_volume` for that image.
        show: If ``True``, call ``plt.show()`` before closing the figure.
        output_path: If given, the figure is saved to this path.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If ``True`` and
            ``subject.check_consistent_spatial_shape()`` does not raise
            ``RuntimeError``, the subplots share their x and y axes.
        **kwargs: Extra keyword arguments forwarded to :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    num_images = len(subject)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            shared = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = shared
            subplots_kwargs['sharey'] = shared

    if many_images:
        fig, axes = plt.subplots(3, num_images, **subplots_kwargs)
        # Transpose so that axes[index] holds the three views of one image
        axes = axes.T
    else:
        fig, axes = plt.subplots(num_images, 3, **subplots_kwargs)
        # Force a 2D array (one row of three axes per image) so that it can
        # be indexed per image and handed to plot_volume()
        axes = axes.reshape(-1, 3)

    images = subject.get_images_dict(intensity_only=False)
    view_names = 'sagittal', 'coronal', 'axial'
    last_index = len(axes) - 1
    for index, (name, image) in enumerate(images.items()):
        row = axes[index]
        has_cmap = cmap_dict is not None and name in cmap_dict
        cmap = cmap_dict[name] if has_cmap else None
        # Only the last image gets x-axis labels
        plot_volume(
            image,
            axes=row,
            show=False,
            cmap=cmap,
            xlabels=index == last_index,
            **kwargs,
        )
        for axis, view_name in zip(row, view_names):
            axis.set_title(f'{name} ({view_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
    plt.close(fig)
161
162
163
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images using a label-to-color map.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color. A color is either a
            Matplotlib color string (converted with ``to_rgb`` and scaled
            to 0-255) or a sequence of three values written as-is into the
            ``uint8`` output.

    Returns:
        List of ``uint8`` arrays of shape ``(rows, cols, 3)``, one per input
        array. Pixels whose label is not in ``cmap_dict`` stay black. If
        label masks overlap, later entries of ``cmap_dict`` win.
    """
    # Resolve every color once rather than once per (array, label) pair;
    # Matplotlib is only imported when a color string needs converting.
    colors = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        colors[label] = color

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
176
177
178
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save the slices of a 4D tensor along one spatial axis as a GIF.

    Args:
        tensor: 4D tensor whose first dimension is channels, followed by
            three spatial dimensions. A single channel is written in
            Pillow mode ``'P'``, otherwise in ``'RGB'`` (presumably three
            channels are expected in that case; verify with callers).
        axis: Spatial axis (0, 1 or 2) along which frames are taken.
        duration: Duration of the whole GIF, in seconds.
        output_path: Path of the output file.
        loop: Forwarded to Pillow's ``Image.save`` as ``loop``.
        optimize: Forwarded to Pillow's ``Image.save`` as ``optimize``.
        rescale: If ``True``, intensities are rescaled to [0, 255] with
            :class:`RescaleIntensity` before being cast to bytes.
        reverse: If ``True``, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e
    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    single_channel = len(tensor) == 1

    # Bring the selected spatial axis to the front and channels to the end
    spatial_order = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*spatial_order, 0)

    mode = 'P' if single_channel else 'RGB'
    if single_channel:
        tensor = tensor[..., 0]

    quarter_turns = 2 if axis == 1 else 1
    frames = [
        ImagePIL.fromarray(rotate(frame, n=quarter_turns)).convert(mode)
        for frame in tensor.byte().numpy()
    ]
    num_frames = len(frames)
    if reverse:
        frames.reverse()

    frame_duration_ms = duration / num_frames * 1000
    # Frames shorter than 10 ms are clamped (GIF tops out at 100 fps);
    # the resulting total duration is reported in a warning.
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
233