Passed
Pull Request — master (#682)
by
unknown
14:34
created

torchio.visualization   A

Complexity

Total Complexity 40

Size/Duplication

Total Lines 240
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 40
eloc 180
dl 0
loc 240
rs 9.2
c 0
b 0
f 0

6 Functions

Rating   Name   Duplication   Size   Complexity  
A rotate() 0 6 2
A import_mpl_plt() 0 7 2
A color_labels() 0 13 4
D plot_volume() 0 88 12
D plot_subject() 0 49 13
B make_gif() 0 54 7

How to fix   Complexity   

Complexity

Complex modules like torchio.visualization often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to find such a component is to look for functions that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib lazily and return it together with pyplot.

    The import is performed inside the function so that matplotlib is only
    required when a plotting function is actually called.

    Returns:
        Tuple ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If matplotlib cannot be imported. The original error is
            chained as the cause.
    """
    try:
        import matplotlib as mpl
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError('Install matplotlib for plotting support') from e
    return mpl, plt
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate a 2D array for visualization purposes.

    Args:
        image: Array to rotate; it is passed to :func:`numpy.rot90`.
        radiological: If ``True``, the rotated array is additionally flipped
            left-right with :func:`numpy.fliplr`.
        n: Number of 90-degree rotations, forwarded to :func:`numpy.rot90`.
            The default of ``-1`` is a single clockwise rotation.

    Returns:
        The rotated (and possibly flipped) array.
    """
    image = np.rot90(image, n)
    if radiological:
        image = np.fliplr(image)
    return image
28
29
30
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot one slice per spatial axis of a single channel of an image.

    Args:
        image: Image to plot. Its ``data`` is indexed by ``channel`` and its
            ``spacing`` is used to compute the aspect ratio of each view.
        radiological: Forwarded to :func:`rotate`; if ``True`` each slice is
            flipped left-right after rotation.
        channel: Index of the channel to show.
        axes: Sequence of three matplotlib axes (first, second and third
            spatial axis views). If ``None``, a new 1x3 figure is created.
        cmap: Either a ``dict`` mapping label values to colors (the slices
            are then converted to RGB with :func:`color_labels`) or a value
            passed to ``imshow`` as ``cmap``. If ``None``, ``'cubehelix'`` is
            used for :class:`LabelMap` instances and ``'gray'`` otherwise.
        output_path: If given, and the figure was created by this function
            (``axes`` was ``None``), the figure is saved there with a
            transparent background. Nothing is saved when ``axes`` is passed.
        show: If ``True``, ``plt.show()`` is called at the end.
        xlabels: Accepted for backward compatibility. Axis labelling is
            currently disabled (all three axes are switched off), so this
            argument has no effect.
        percentiles: Pair of percentiles of the channel data used as ``vmin``
            and ``vmax`` for ``imshow``. Use ``None`` to disable.
        figsize: Forwarded to ``plt.subplots`` when a figure is created.
        reorient: If ``True``, the image is passed through
            :class:`ToCanonical` before slicing.
        indices: Three slice indices, one per spatial axis. Defaults to the
            middle slice along each axis.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slices = (
        rotate(data[i, :, :], radiological=radiological),
        rotate(data[:, j, :], radiological=radiological),
        rotate(data[:, :, k], radiological=radiological),
    )

    kwargs = {}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        # Explicit label -> color mapping: render the slices as RGB images
        slices = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        # Avoid interpolating between discrete label values
        kwargs['interpolation'] = 'none'

    # NOTE(review): names suggest (right, anterior, superior) spacing of a
    # canonically oriented image - confirm when reorient=False
    sr, sa, ss = image.spacing
    kwargs['origin'] = 'lower'

    if percentiles is not None:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    # Aspect ratios compensate for anisotropic voxel spacing in each view
    aspects = ss / sa, ss / sr, sa / sr
    view_axes = sag_axis, cor_axis, axi_axis
    for view_axis, view_slice, aspect in zip(view_axes, slices, aspects):
        view_axis.imshow(view_slice, aspect=aspect, **kwargs)
        view_axis.axis('off')

    plt.tight_layout()
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
118
119
120
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot every image of a subject on a grid of axes, three per image.

    With more than two images the grid has three rows and one column per
    image; otherwise it has one row per image and three columns.

    Args:
        subject: Subject whose images (as returned by
            ``get_images_dict(intensity_only=False)``) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to :func:`plot_volume` for that image.
        show: If ``True``, ``plt.show()`` is called after plotting.
        output_path: If given, the figure is saved to this path.
        figsize: Forwarded to ``plt.subplots``.
        clear_axes: If ``True`` and the subject passes
            ``check_consistent_spatial_shape()``, x and y axes are shared
            between the views of the different images.
        **kwargs: Forwarded to :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    num_images = len(subject)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    try:
        # The array of axes must be 2D so that it can be indexed correctly
        # within the plot_volume() function
        axes = axes.T if many_images else axes.reshape(-1, 3)
        images_dict = subject.get_images_dict(intensity_only=False)
        axes_names = 'sagittal', 'coronal', 'axial'
        for image_index, (name, image) in enumerate(images_dict.items()):
            image_axes = axes[image_index]
            cmap = None
            if cmap_dict is not None and name in cmap_dict:
                cmap = cmap_dict[name]
            last_row = image_index == len(axes) - 1
            plot_volume(
                image,
                axes=image_axes,
                show=False,
                cmap=cmap,
                xlabels=last_row,
                **kwargs,
            )
            for axis, axis_name in zip(image_axes, axes_names):
                axis.set_title(f'{name} ({axis_name})')
        plt.tight_layout()
        if output_path is not None:
            fig.savefig(output_path)
        if show:
            plt.show()
    finally:
        # Release the figure even if plotting or saving failed
        plt.close(fig)
169
170
171
def color_labels(arrays, cmap_dict):
    """Convert 2D label arrays into RGB images using a label-to-color map.

    Args:
        arrays: Iterable of 2D arrays containing label values.
        cmap_dict: Mapping from label value to color. A color is either a
            string understood by ``matplotlib.colors.to_rgb`` (converted to
            the 0-255 range) or a sequence of three values that is written
            into the output as is.

    Returns:
        List with one ``uint8`` array of shape ``(rows, columns, 3)`` per
        input array. Pixels whose value is not a key of ``cmap_dict`` stay
        black.
    """
    # Resolve color names once, instead of once per array and label
    mpl = None
    palette = []
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        palette.append((label, color))

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in palette:
            rgb[array == label] = color
        results.append(rgb)
    return results
184
185
186
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Save a 4D tensor as an animated GIF, one frame per slice.

    Args:
        tensor: Tensor with the channels in dimension 0 followed by three
            other dimensions. A single channel is saved in palette mode
            (``'P'``); otherwise frames are converted to ``'RGB'``.
        axis: Selects which of the three non-channel dimensions is iterated
            over to produce the frames.
            NOTE(review): presumably one of 0, 1 or 2 - confirm with callers.
        duration: Duration of the full GIF, in seconds.
        output_path: Path of the output file.
        loop: Forwarded to Pillow's ``save`` as ``loop``.
        optimize: Forwarded to Pillow's ``save`` as ``optimize``.
        rescale: If ``True``, the tensor is first passed through
            ``RescaleIntensity((0, 255))``.
        reverse: If ``True``, the frames are saved in reverse order.

    Raises:
        RuntimeError: If Pillow cannot be imported.
    """
    try:
        from PIL import Image as ImagePIL
    except ImportError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e
    tensor = RescaleIntensity((0, 255))(tensor) if rescale else tensor
    single_channel = len(tensor) == 1

    # Move channels dimension to the end and bring selected axis to 0.
    # tolist() yields plain Python ints, which permute() always accepts.
    axes = np.roll(range(1, 4), -axis).tolist()
    tensor = tensor.permute(*axes, 0)

    if single_channel:
        mode = 'P'
        tensor = tensor[..., 0]
    else:
        mode = 'RGB'
    array = tensor.byte().numpy()
    n = 2 if axis == 1 else 1
    images = [ImagePIL.fromarray(rotate(i, n=n)).convert(mode) for i in array]
    num_images = len(images)
    images = list(reversed(images)) if reverse else images
    frame_duration_ms = duration / num_images * 1000
    if frame_duration_ms < 10:
        # Clamp to the shortest allowed frame duration and tell the user
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_images / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        # stacklevel=2 attributes the warning to the caller of make_gif()
        warnings.warn(message, stacklevel=2)
    images[0].save(
        output_path,
        save_all=True,
        append_images=images[1:],
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
241