Passed
Push — master ( 6deb01...ca4890 )
by Fernando
01:56 queued 40s
created

torchio.visualization   A

Complexity

Total Complexity 42

Size/Duplication

Total Lines 227
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 42
eloc 189
dl 0
loc 227
rs 9.0399
c 0
b 0
f 0

6 Functions

Rating   Name   Duplication   Size   Complexity  
A color_labels() 0 13 4
F plot_volume() 0 76 14
D plot_subject() 0 48 13
A rotate() 0 6 2
A import_mpl_plt() 0 7 2
B make_gif() 0 54 7

How to fix   Complexity   

Complexity

Complex modules like torchio.visualization often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import Matplotlib on demand, only when plotting is requested.

    Returns:
        Tuple ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If Matplotlib cannot be imported; the original error
            is chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Rotate an array for display in the plane of its first two axes.

    Args:
        image: Array-like accepted by ``np.rot90``.
        radiological: If True, the rotated array is also flipped
            left-right with ``np.fliplr``.
        n: Number of quarter turns forwarded to ``np.rot90`` as ``k``.

    Returns:
        The rotated (and optionally flipped) NumPy array.
    """
    rotated = np.rot90(image, k=n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot sagittal, coronal and axial slices of one channel of an image.

    Args:
        image: Image to plot. For a ``LabelMap``, imshow interpolation is
            disabled and the default colormap is ``'cubehelix'``.
        radiological: Forwarded to :func:`rotate`; if True, each slice is
            flipped left-right after rotation.
        channel: Index of the channel of ``image.data`` to plot.
        axes: Sequence of three Matplotlib axes (sagittal, coronal, axial).
            If None, a new 1x3 figure is created.
        cmap: Colormap for ``imshow``, or a dict mapping label values to
            colors, in which case slices are converted to RGB with
            :func:`color_labels`. Defaults to ``'cubehelix'`` for label maps
            and ``'gray'`` otherwise.
        output_path: If not None, the figure is saved to this path. When
            ``axes`` is given, the figure owning those axes is saved.
        show: If True, ``plt.show()`` is called at the end.
        xlabels: If True, anatomical x-axis labels are drawn.
        percentiles: Lower and upper percentiles of the channel data used as
            the display range (``vmin``/``vmax``). None disables windowing.
            Not applied to label maps.
        figsize: Forwarded to ``plt.subplots`` when ``axes`` is None.
        reorient: If True, ``ToCanonical`` is applied before slicing.
        indices: Optional ``(i, j, k)`` voxel indices of the slices to show.
            Defaults to the center of the volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    kwargs = {'origin': 'lower'}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        kwargs['interpolation'] = 'none'

    # Percentile windowing clips label values: a label covering fewer than
    # 0.5% of voxels falls outside the 99.5th percentile and would become
    # invisible. Only window intensity images.
    if percentiles is not None and not is_label:
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    sr, sa, ss = image.spacing
    # (axis, slice, aspect ratio from voxel spacing, xlabel, ylabel, title)
    views = (
        (sag_axis, slice_x, ss / sa, 'A', 'S', 'Sagittal'),
        (cor_axis, slice_y, ss / sr, 'R', 'S', 'Coronal'),
        (axi_axis, slice_z, sa / sr, 'R', 'A', 'Axial'),
    )
    for axis, slice_, aspect, xlabel, ylabel, title in views:
        axis.imshow(slice_, aspect=aspect, **kwargs)
        if xlabels:
            axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.invert_xaxis()
        axis.set_title(title)

    plt.tight_layout()
    if output_path is not None:
        if fig is None:
            # Axes were supplied by the caller: save the figure they live in
            # instead of silently ignoring the requested output path.
            fig = sag_axis.figure
        fig.savefig(output_path)
    if show:
        plt.show()
106
107
108
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot sagittal, coronal and axial slices of every image in a subject.

    Args:
        subject: Subject whose images (intensity and label) are plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` passed
            to :func:`plot_volume` for that image.
        show: If True, ``plt.show()`` is called at the end.
        output_path: If not None, the figure is saved to this path.
        figsize: Forwarded to ``plt.subplots``.
        clear_axes: If True and the subject's images have a consistent
            spatial shape, subplot axes are shared (per row when there are
            more than two images, per column otherwise).
        **kwargs: Forwarded to :func:`plot_volume`.
    """
    _, plt = import_mpl_plt()
    n_images = len(subject)
    # With more than two images the grid is 3 x N (one row per view)
    # instead of N x 3 (one row per image).
    transposed = n_images > 2
    share = 'row' if transposed else 'col'

    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject
            pass
        else:
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share

    grid_shape = (3, n_images) if transposed else (n_images, 3)
    fig, axes = plt.subplots(*grid_shape, **subplots_kwargs)
    # plot_volume() needs one row of three axes per image, so normalize the
    # axes array to a 2D (n_images, 3) layout.
    axes = axes.T if transposed else axes.reshape(-1, 3)

    view_names = ('sagittal', 'coronal', 'axial')
    last_index = len(axes) - 1
    images = subject.get_images_dict(intensity_only=False)
    for index, (name, image) in enumerate(images.items()):
        row = axes[index]
        has_cmap = cmap_dict is not None and name in cmap_dict
        cmap = cmap_dict[name] if has_cmap else None
        plot_volume(
            image,
            axes=row,
            show=False,
            cmap=cmap,
            xlabels=index == last_index,
            **kwargs,
        )
        for axis, view_name in zip(row, view_names):
            axis.set_title(f'{name} ({view_name})')

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    if show:
        plt.show()
156
157
158
def color_labels(arrays, cmap_dict):
    """Convert label arrays into RGB ``uint8`` images.

    Args:
        arrays: Iterable of label arrays. Each output has the same shape as
            its input plus a trailing axis of size 3.
        cmap_dict: Mapping from label value to color. A string color is
            converted with ``matplotlib.colors.to_rgb`` and scaled to 0-255;
            any other color is assigned as-is to the uint8 RGB channels (so
            presumably three values in 0-255).

    Returns:
        List of ``uint8`` RGB arrays, one per input array. Elements whose
        label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once, instead of once per array per label.
    # Matplotlib is only imported if a string color actually needs it.
    mpl = None
    rgb_colors = {}
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            # Round rather than truncate: 255 * c can land just below an
            # integer due to floating-point error.
            color = [round(255 * c) for c in mpl.colors.to_rgb(color)]
        rgb_colors[label] = color

    results = []
    for array in arrays:
        rgb = np.zeros((*array.shape, 3), dtype=np.uint8)
        for label, color in rgb_colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
171
172
173
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Write an animated GIF that sweeps through a volume along one axis.

    Args:
        tensor: 4D tensor, channels first followed by three spatial
            dimensions. A single channel produces palette ('P') frames;
            otherwise frames are RGB.
        axis: Spatial axis (0, 1 or 2) along which frames are extracted.
        duration: Total GIF duration in seconds. If this implies frames
            shorter than 10 ms, each frame is set to 10 ms and a warning is
            emitted.
        output_path: Destination of the GIF file.
        loop: Forwarded to Pillow's ``save`` as ``loop``.
        optimize: Forwarded to Pillow's ``save`` as ``optimize``.
        rescale: If True, intensities are first rescaled to [0, 255].
        reverse: If True, frames are written in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    is_single_channel = len(tensor) == 1

    # Put the requested spatial axis first (frame axis) and channels last
    spatial_order = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*spatial_order, 0)

    if is_single_channel:
        tensor = tensor[..., 0]
        pil_mode = 'P'
    else:
        pil_mode = 'RGB'
    frames_array = tensor.byte().numpy()

    quarter_turns = 2 if axis == 1 else 1
    frames = [
        ImagePIL.fromarray(rotate(frame, n=quarter_turns)).convert(pil_mode)
        for frame in frames_array
    ]
    num_frames = len(frames)
    if reverse:
        frames = frames[::-1]

    frame_ms = duration / num_frames * 1000
    # GIF frame delays are limited to 10 ms, i.e. at most 100 fps
    if frame_ms < 10:
        fps = round(1000 / frame_ms)
        frame_ms = 10
        new_duration = frame_ms * num_frames / 1000
        warnings.warn(
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_ms,
        loop=loop,
    )
228