Passed
Pull Request — master (#682)
by
unknown
01:17
created

torchio.visualization.plot_volume()   F

Complexity

Conditions 14

Size

Total Lines 76
Code Lines 67

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 67
dl 0
loc 76
rs 3.6
c 0
b 0
f 0
cc 14
nop 11

How to fix   Long Method    Complexity    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

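Commonly applied refactorings include Extract Method and, if many parameters or temporary variables get in the way, Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.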

Complexity

Complex functions like torchio.visualization.plot_volume() often do many different things. To break such a function down, identify a cohesive component within it. A common approach to finding such a component is to look for variables or statements that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Replace Parameter with Method, Introduce Parameter Object, and Preserve Whole Object.

1
import warnings
2
3
import torch
4
import numpy as np
5
6
from .typing import TypePath
7
from .data.subject import Subject
8
from .data.image import Image, LabelMap
9
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
10
from .transforms.preprocessing.intensity.rescale import RescaleIntensity
11
12
13
def import_mpl_plt():
    """Import matplotlib lazily and return ``(matplotlib, matplotlib.pyplot)``.

    Raises:
        ImportError: If matplotlib cannot be imported; the original error is
            chained as the cause.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as pyplot
    except ImportError as error:
        raise ImportError('Install matplotlib for plotting support') from error
    return matplotlib, pyplot
20
21
22
def rotate(image, radiological=True, n=-1):
    """Reorient a 2D slice for display.

    Args:
        image: Array to rotate (rotation happens in the first two axes).
        radiological: If ``True``, also flip the result left-right.
        n: Number of 90-degree turns passed to ``np.rot90``; the default
            ``-1`` turns clockwise once.

    Returns:
        The rotated (and optionally flipped) array.
    """
    rotated = np.rot90(image, n)
    return np.fliplr(rotated) if radiological else rotated
28
29
30
def _draw_view(axis, view, aspect, xlabel, ylabel, xlabels, imshow_kwargs):
    """Draw one orthogonal view on ``axis`` and hide the axis decorations.

    Args:
        axis: Matplotlib axes to draw on.
        view: 2D slice (or RGB image from ``color_labels``) to display.
        aspect: Pixel aspect ratio for ``imshow``, derived from the voxel
            spacing so that anisotropic voxels keep their proportions.
        xlabel: Horizontal-axis label, set only when ``xlabels`` is true.
        ylabel: Vertical-axis label.
        xlabels: Whether to set ``xlabel``.
        imshow_kwargs: Extra keyword arguments forwarded to ``imshow``.
    """
    axis.imshow(view, aspect=aspect, **imshow_kwargs)
    if xlabels:
        axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.invert_xaxis()
    # Note: axis('off') also hides the labels set above
    axis.axis('off')


def plot_volume(
        image: Image,
        radiological=True,
        channel=-1,  # default to foreground for binary maps
        axes=None,
        cmap=None,
        output_path=None,
        show=True,
        xlabels=True,
        percentiles=(0.5, 99.5),
        figsize=None,
        reorient=True,
        indices=None,
        ):
    """Plot the sagittal, coronal and axial views through an image.

    Args:
        image: Image to plot. A ``LabelMap`` is drawn with nearest-neighbour
            interpolation and without percentile windowing.
        radiological: Passed to ``rotate``; flips each view left-right.
        channel: Index of the channel of ``image.data`` to display.
        axes: Sequence of three axes (sagittal, coronal, axial). If ``None``,
            a new 1x3 figure is created.
        cmap: Matplotlib colormap, or a dict mapping label values to colors
            (rendered with ``color_labels``). Defaults to ``'cubehelix'``
            for label maps and ``'gray'`` otherwise.
        output_path: If given, the figure is saved there, but only when the
            figure was created by this function (``axes is None``).
        show: Whether to call ``plt.show()`` at the end.
        xlabels: Whether to set the horizontal-axis labels.
        percentiles: Pair of intensity percentiles used as ``vmin``/``vmax``.
            Ignored for label maps; ``None`` disables windowing.
        figsize: Figure size used when a new figure is created.
        reorient: If ``True``, apply ``ToCanonical`` before slicing.
        indices: Voxel indices ``(i, j, k)`` of the slices to display.
            Defaults to the center of the volume.
    """
    _, plt = import_mpl_plt()
    fig = None
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    sag_axis, cor_axis, axi_axis = axes

    if reorient:
        image = ToCanonical()(image)
    data = image.data[channel]
    if indices is None:
        indices = np.array(data.shape) // 2
    i, j, k = indices
    slice_x = rotate(data[i, :, :], radiological=radiological)
    slice_y = rotate(data[:, j, :], radiological=radiological)
    slice_z = rotate(data[:, :, k], radiological=radiological)

    kwargs = {'origin': 'lower'}
    is_label = isinstance(image, LabelMap)
    if isinstance(cmap, dict):
        slices = slice_x, slice_y, slice_z
        slice_x, slice_y, slice_z = color_labels(slices, cmap)
    else:
        if cmap is None:
            cmap = 'cubehelix' if is_label else 'gray'
        kwargs['cmap'] = cmap
    if is_label:
        kwargs['interpolation'] = 'none'
    elif percentiles is not None:
        # Percentile windowing only makes sense for intensities; applied to
        # label maps it would clip (or collapse) the label range.
        p1, p2 = np.percentile(data, percentiles)
        kwargs['vmin'] = p1
        kwargs['vmax'] = p2

    sr, sa, ss = image.spacing
    _draw_view(sag_axis, slice_x, ss / sa, 'A', 'S', xlabels, kwargs)
    _draw_view(cor_axis, slice_y, ss / sr, 'R', 'S', xlabels, kwargs)
    _draw_view(axi_axis, slice_z, sa / sr, 'R', 'A', xlabels, kwargs)

    # Spacing between subplots is a figure-level setting; Axes objects have
    # no subplots_adjust() method.
    plt.subplots_adjust(wspace=0, hspace=0)
    if output_path is not None and fig is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
116
117
118
def plot_subject(
        subject: Subject,
        cmap_dict=None,
        show=True,
        output_path=None,
        figsize=None,
        clear_axes=True,
        **kwargs,
        ):
    """Plot the three orthogonal views of every image in a subject.

    With more than two images the grid is 3 rows x N columns (one column per
    image); otherwise it is N rows x 3 columns (one row per image).

    Args:
        subject: Subject whose images (intensity images and label maps) are
            plotted.
        cmap_dict: Optional mapping from image name to the ``cmap`` argument
            passed to ``plot_volume`` for that image.
        show: Whether to call ``plt.show()`` at the end.
        output_path: If given, the figure is saved there.
        figsize: Figure size passed to ``plt.subplots``.
        clear_axes: If ``True`` and all images share the same spatial shape,
            subplot axes are shared along rows or columns.
        **kwargs: Forwarded to ``plot_volume``. These must not include
            ``axes``, ``show``, ``cmap`` or ``xlabels``, which are set here.
    """
    _, plt = import_mpl_plt()
    num_images = len(subject)
    many_images = num_images > 2
    subplots_kwargs = {'figsize': figsize}
    if clear_axes:
        try:
            subject.check_consistent_spatial_shape()
        except RuntimeError:  # different shapes in subject: cannot share axes
            pass
        else:
            share = 'row' if many_images else 'col'
            subplots_kwargs['sharex'] = share
            subplots_kwargs['sharey'] = share
    args = (3, num_images) if many_images else (num_images, 3)
    fig, axes = plt.subplots(*args, **subplots_kwargs)
    # The array of axes must be 2D so that it can be indexed correctly within
    # the plot_volume() function
    axes = axes.T if many_images else axes.reshape(-1, 3)
    images = subject.get_images_dict(intensity_only=False).items()
    for image_index, (name, image) in enumerate(images):
        cmap = None if cmap_dict is None else cmap_dict.get(name)
        last_row = image_index == len(axes) - 1
        plot_volume(
            image,
            axes=axes[image_index],
            show=False,
            cmap=cmap,
            xlabels=last_row,
            **kwargs,
        )

    plt.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)
    if output_path is not None:
        fig.savefig(output_path, transparent=True)
    if show:
        plt.show()
    plt.close(fig)
169
170
171
def color_labels(arrays, cmap_dict):
    """Render 2D label arrays as RGB images using an explicit label->color map.

    Args:
        arrays: Iterable of 2D arrays of label values.
        cmap_dict: Mapping from label value to color. A string color is
            converted with ``matplotlib.colors.to_rgb`` and scaled by 255.
            Any other value is assigned as-is (presumably an RGB triple in
            0-255, since the output dtype is ``uint8``; TODO confirm).

    Returns:
        List of ``uint8`` arrays of shape ``(H, W, 3)``, one per input array.
        Pixels whose label is not in ``cmap_dict`` stay black.
    """
    # Resolve every color once, instead of once per array, and import
    # matplotlib only if a string color actually needs converting.
    rgb_colors = {}
    mpl = None
    for label, color in cmap_dict.items():
        if isinstance(color, str):
            if mpl is None:
                mpl, _ = import_mpl_plt()
            # Fractional values are truncated when assigned to the uint8 array
            color = [255 * n for n in mpl.colors.to_rgb(color)]
        rgb_colors[label] = color

    results = []
    for array in arrays:
        si, sj = array.shape
        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
        for label, color in rgb_colors.items():
            rgb[array == label] = color
        results.append(rgb)
    return results
184
185
186
def make_gif(
        tensor: torch.Tensor,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        optimize: bool = True,
        rescale: bool = True,
        reverse: bool = False,
        ) -> None:
    """Write the slices of a 4D tensor along ``axis`` as an animated GIF.

    Args:
        tensor: 4D tensor with channels first (``len(tensor)`` is treated as
            the number of channels). One channel produces a palette (``'P'``)
            GIF; more channels are converted to ``'RGB'`` (presumably three
            channels are expected; TODO confirm other counts).
        axis: Spatial axis (0, 1 or 2) along which frames are extracted.
        duration: Total duration of the animation, in seconds.
        output_path: Destination file.
        loop: Forwarded to Pillow's ``save``.
        optimize: Forwarded to Pillow's ``save``.
        rescale: If ``True``, rescale intensities to (0, 255) with
            ``RescaleIntensity`` before casting to ``uint8``.
        reverse: If ``True``, write the frames in reverse order.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    try:
        from PIL import Image as ImagePIL
    except ModuleNotFoundError as e:
        message = (
            'Please install Pillow to use Image.to_gif():'
            ' pip install Pillow'
        )
        raise RuntimeError(message) from e

    if rescale:
        tensor = RescaleIntensity((0, 255))(tensor)
    is_single_channel = len(tensor) == 1

    # Put channels last and make the requested spatial axis the frame axis
    frame_first_order = np.roll(range(1, 4), -axis)
    tensor = tensor.permute(*frame_first_order, 0)

    if is_single_channel:
        tensor = tensor[..., 0]
        mode = 'P'
    else:
        mode = 'RGB'
    frames_array = tensor.byte().numpy()

    quarter_turns = 2 if axis == 1 else 1
    frames = [
        ImagePIL.fromarray(rotate(frame, n=quarter_turns)).convert(mode)
        for frame in frames_array
    ]
    if reverse:
        frames = frames[::-1]

    num_frames = len(frames)
    frame_duration_ms = duration / num_frames * 1000
    # The GIF format caps the frame rate at 100 fps (10 ms per frame)
    if frame_duration_ms < 10:
        fps = round(1000 / frame_duration_ms)
        frame_duration_ms = 10
        new_duration = frame_duration_ms * num_frames / 1000
        message = (
            'The computed frame rate from the given duration is too high'
            f' ({fps} fps). The highest possible frame rate in the GIF'
            ' file format specification is 100 fps. The duration has been set'
            f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
        )
        warnings.warn(message)

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        optimize=optimize,
        duration=frame_duration_ms,
        loop=loop,
    )
241