Passed
Pull Request — master (#682)
by
unknown
01:13
created

torchio.visualization   B

Complexity

Total Complexity 43

Size/Duplication

Total Lines 234
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 43
eloc 190
dl 0
loc 234
rs 8.96
c 0
b 0
f 0

6 Functions

Rating   Name   Duplication   Size   Complexity  
A rotate() 0 6 2
A import_mpl_plt() 0 7 2
A color_labels() 0 13 4
B make_gif() 0 54 7
F plot_volume() 0 83 15
D plot_subject() 0 48 13

How to fix   Complexity   

Complexity

Complex classes or modules like torchio.visualization often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods (or, for a module, functions) that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib lazily and return ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If matplotlib cannot be imported. The original error is
            chained as the cause.
    """
    try:
        import matplotlib as mpl
        import matplotlib.pyplot as plt
    except ImportError as error:
        raise ImportError(
            'Install matplotlib for plotting support'
        ) from error
    else:
        return mpl, plt
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate a 2D slice for visualization purposes.

    Args:
        image: Array to rotate.
        radiological: If ``True``, the rotated array is also flipped
            left-right.
        n: Number of 90-degree rotations, passed to :func:`numpy.rot90`.

    Returns:
        The rotated (and possibly flipped) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot one sagittal, one coronal and one axial slice of an image.

    Args:
        image: Image to plot. The volume ``image.data[channel]`` is sliced
            and ``image.spacing`` sets the aspect ratio of every panel.
        radiological: Forwarded to :func:`rotate`; if ``True``, each slice
            is additionally flipped left-right.
        channel: Index of the channel to plot.
        axes: Three matplotlib axes, in the order sagittal, coronal, axial.
            If ``None``, a new figure with 1x3 subplots is created.
        cmap: Either a ``dict`` mapping label values to colors (the slices
            are then converted to RGB with :func:`color_labels`) or a value
            passed to ``imshow`` as ``cmap``. If ``None``, ``'cubehelix'``
            is used for :class:`LabelMap` instances and ``'gray'``
            otherwise.
        output_path: If given, and the figure was created by this function,
            the figure is saved there with a transparent background. It is
            ignored when ``axes`` is supplied by the caller.
        show: If ``True``, call ``plt.show()`` at the end.
        xlabels: If ``True``, set the x label of every panel. Note that the
            axes are switched off afterwards with ``axis('off')``.
        percentiles: Pair of percentiles computed over the whole selected
            channel and used as ``vmin``/``vmax``. ``None`` disables this.
        figsize: Passed to ``plt.subplots`` when a figure is created here.
        reorient: If ``True``, apply ``ToCanonical`` to the image first.
        indices: Slice indices ``(i, j, k)``. Defaults to the center of the
            volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    imshow_kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        # A dict maps label values to colors: draw RGB slices directly
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        imshow_kwargs['cmap'] = cmap
    if is_label:
        # Do not blend neighbouring label values when resampling for display
        imshow_kwargs['interpolation'] = 'none'

    sr, sa, ss = image.spacing
    imshow_kwargs['origin'] = 'lower'

    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        imshow_kwargs['vmin'] = p1
        imshow_kwargs['vmax'] = p2

    # (axis, slice, aspect ratio, x label, y label) for each panel. The
    # axial panel now honours the voxel spacing like the other two; its
    # aspect ratio used to be computed but never passed to imshow().
    panels = (
        (sag_axis, slice_x, ss / sa, 'A', 'S'),
        (cor_axis, slice_y, ss / sr, 'R', 'S'),
        (axi_axis, slice_z, sa / sr, 'R', 'A'),
    )
    for axis, slice_2d, aspect, xlabel, ylabel in panels:
        axis.imshow(slice_2d, aspect=aspect, **imshow_kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.axis('off')

    # Only a figure created here is saved; caller-owned axes are left alone
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
113
114
115
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot three orthogonal slices of every image in a subject.

    Args:
        subject: Subject whose images (intensity and non-intensity) are
            plotted, one group of three axes per image.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            forwarded to :func:`plot_volume` for that image.
        show: If ``True``, call ``plt.show()`` at the end.
        output_path: If given, the figure is saved to this path.
        figsize: Passed to ``plt.subplots``.
        clear_axes: If ``True`` and
            ``subject.check_consistent_spatial_shape()`` does not raise
            ``RuntimeError``, the subplots share their x and y axes.
        **kwargs: Extra keyword arguments forwarded to :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    num_images = len(subject)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            shared = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = shared
            subplots_kwargs['sharey'] = shared
    grid_shape = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*grid_shape, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    if many_images:
        axes = axes.T
    else:
        axes = axes.reshape(-1, 3)
    images = subject.get_images_dict(intensity_only=False)
    view_names = ('sagittal', 'coronal', 'axial')
    final_index = len(axes) - 1
    for image_index, (name, image) in enumerate(images.items()):
        image_axes = axes[image_index]
        has_cmap = cmap_dict is not None and name in cmap_dict
        plot_volume(
            image,
            axes=image_axes,
            show=False,
            cmap=cmap_dict[name] if has_cmap else None,
            xlabels=image_index == final_index,
            **kwargs,
        )
        for axis, view_name in zip(image_axes, view_names):
            axis.set_title(f'{name} ({view_name})')
    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
163
164
165
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images using a label-to-color map.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color. A color is either a
            string understood by ``matplotlib.colors.to_rgb`` or a sequence
            of three values that can be stored as ``uint8``.

    Returns:
        A list with one ``uint8`` array of shape ``(rows, columns, 3)`` per
        input array. Pixels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once up front instead of once per label per array:
    # the conversion does not depend on the array being painted.
    palette = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            mpl, _ = import_mpl_plt()
            # to_rgb() yields floats in [0, 1]; scale them to the 0-255 range
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        palette[label] = color

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in palette.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
178
179
180
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Write a 4D tensor to an animated GIF, one frame per slice.

    Args:
        tensor: Tensor with the channels in the first dimension followed by
            three more dimensions. A single channel produces palette
            (``'P'``) frames, otherwise the frames are converted to
            ``'RGB'``.
        axis: Index (0, 1 or 2) of the non-channel dimension along which the
            frames are taken.
        duration: Duration of the whole GIF, in seconds.
        output_path: Path of the GIF file to write.
        loop: Passed to Pillow's ``save`` as ``loop``.
        optimize: Passed to Pillow's ``save`` as ``optimize``.
        rescale: If ``True``, apply ``RescaleIntensity((0, 255))`` first.
        reverse: If ``True``, the frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e
    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    single_channel = len(tensor) == 1

    # Move channels dimension to the end and bring selected axis to 0
    spatial_order = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*spatial_order, 0)

    mode = 'P' if single_channel else 'RGB'
    if single_channel:
        tensor = tensor[..., 0]
    frames_array = tensor.byte().numpy()
    n = 2 if axis == 1 else 1
    images = [
        ImagePIL.fromarray(rotate(frame, n=n)).convert(mode)
        for frame in frames_array
    ]
    num_images = len(images)
    if reverse:
        images.reverse()
    frame_duration_ms = duration / num_images * 1000
    if frame_duration_ms < 10:
        # Clamp to the shortest frame duration and tell the user about it
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_images / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)
    first_frame, remaining_frames = images[0], images[1:]
    first_frame.save(
        output_path,
        save_all=True,
        append_images=remaining_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
235