deepreg.vis.gif_warp()   B
last analyzed

Complexity

Conditions 6

Size

Total Lines 60
Code Lines 36

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 36
dl 0
loc 60
rs 8.0826
c 0
b 0
f 0
cc 6
nop 6

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""
Module to generate visualisations of data
at command line interface.
Requires ffmpeg writer to write gif files
"""
6
7
import argparse
8
import os
9
from typing import List
10
11
import matplotlib.animation as animation
12
import matplotlib.pyplot as plt
13
import numpy as np
14
import numpy.matlib
15
16
from deepreg import log
17
from deepreg.dataset.loader.nifti_loader import load_nifti_file
18
from deepreg.model.layer import Warping
19
20
# Module-level logger obtained from the deepreg ``log`` helper, keyed by this module's name.
logger = log.get(__name__)
21
22
23
def string_to_list(string: str) -> List[str]:
    """
    Convert a comma separated string to a list of strings.

    Leading and trailing whitespace is removed from every element.
    Empty elements are kept, so ``"a,,b"`` gives ``["a", "", "b"]``
    and an empty string gives ``[""]``.

    :param string: string which is to be converted to list
    :return: list of stripped strings
    """
    return [elem.strip() for elem in string.split(",")]
32
33
34
def gif_slices(img_paths, save_path="", interval=50):
    """
    Generate and save a gif animating over the slices of an image.

    Multiple images are supported; one gif is written per image. Each gif
    is named after the image file (the text before the first "." of the
    file name) with a ".gif" extension.

    :param img_paths: list or comma separated string of image paths
    :param save_path: path to directory where visualisation/s is/are to be saved
    :param interval: time in milliseconds between frames of gif
    """
    if isinstance(img_paths, str):
        img_paths = string_to_list(img_paths)

    for img_path in img_paths:
        img = load_nifti_file(img_path)
        # Use the depth of *this* image rather than that of the first one,
        # so images with a different number of slices are animated fully
        # and never indexed out of range.
        num_slices = np.shape(img)[-1]

        fig = plt.figure()
        try:
            # a single axes filling the whole figure, without ticks or frame
            ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
            ax.set_axis_off()
            fig.add_axes(ax)

            frames = [
                [plt.imshow(img[:, :, index], aspect="auto", animated=True)]
                for index in range(num_slices)
            ]

            anim = animation.ArtistAnimation(fig, frames, interval=interval)

            path_to_anim_save = os.path.join(
                save_path, os.path.split(img_path)[-1].split(".")[0] + ".gif"
            )

            anim.save(path_to_anim_save)
            logger.info("Animation saved to: %s.", path_to_anim_save)
        finally:
            # release the figure; pyplot otherwise keeps every figure alive
            plt.close(fig)
69
70
71
def tile_slices(img_paths, save_path="", fname=None, slice_inds=None, col_titles=None):
    """
    Generate a tiled plot of multiple images for multiple slices in the image.

    Rows are different slices, columns are different images.

    :param img_paths: list or comma separated string of image paths
    :param save_path: path to directory where visualisation/s is/are to be saved
    :param fname: file name with extension to save visualisation to,
        "visualisation.png" if None
    :param slice_inds: list of slice indices to plot for each image,
        if None a single index is drawn at random from the last axis
        of the first image
    :param col_titles: list or comma separated string of titles for each
        column, if None then inferred from file names
    """
    if isinstance(img_paths, str):
        img_paths = string_to_list(img_paths)
    # The command line passes titles as one comma separated string; without
    # splitting, indexing below would pick single characters as titles.
    if isinstance(col_titles, str):
        col_titles = string_to_list(col_titles)

    # load every image exactly once
    imgs = [load_nifti_file(p) for p in img_paths]

    if slice_inds is None:
        # uniform draw over the valid indices 0 .. depth - 1
        slice_inds = [int(np.random.randint(np.shape(imgs[0])[-1]))]

    if col_titles is None:
        col_titles = [
            os.path.split(img_path)[-1].split(".")[0] for img_path in img_paths
        ]

    if fname is None:
        fname = "visualisation.png"

    num_inds = len(slice_inds)
    num_imgs = len(img_paths)

    fig = plt.figure(figsize=(num_imgs * 2, num_inds * 2))
    try:
        for col_num, img in enumerate(imgs):
            for row_num, index in enumerate(slice_inds):
                # subplot positions are 1-based and counted row by row
                plt.subplot(num_inds, num_imgs, row_num * num_imgs + col_num + 1)
                plt.imshow(img[:, :, index])
                plt.axis("off")
                if row_num == 0:
                    # only the top row carries the column title
                    plt.title(col_titles[col_num])

        save_fig_to = os.path.join(save_path, fname)
        plt.savefig(save_fig_to)
        logger.info("Plot saved to: %s", save_fig_to)
    finally:
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(fig)
119
120
121
def gif_warp(
    img_paths, ddf_path, slice_inds=None, num_interval=100, interval=50, save_path=""
):
    """
    Apply ddf to image slice/s to generate gif.

    The ddf is multiplied by ``num_interval`` evenly spaced factors from
    0 to 1 (both included). Each scaled ddf warps the image and the requested
    slice of the warped volume becomes one frame of the animation.
    One gif is written per image and slice index, named
    ``<image name>_slice_<slice index>.gif``.

    :param img_paths: list or comma separated string of image paths
    :param ddf_path: path to ddf to use for warping
    :param slice_inds: list of slice indices to use for each image,
        if None a single index is drawn at random from the last axis
        of the first image
    :param num_interval: number of intervals in which to apply ddf
    :param interval: time in milliseconds between frames of gif
    :param save_path: path to directory where visualisation/s is/are to be saved
    """
    if isinstance(img_paths, str):
        img_paths = string_to_list(img_paths)

    if slice_inds is None:
        # uniform draw over the valid indices 0 .. depth - 1
        num_slices = np.shape(load_nifti_file(img_paths[0]))[-1]
        slice_inds = [int(np.random.randint(num_slices))]

    # The ddf is the same for every image, slice and frame:
    # read it from disk once instead of once per frame.
    ddf = load_nifti_file(ddf_path)
    fixed_image_shape = ddf.shape[:3]
    ddf = np.expand_dims(ddf, axis=0)  # add leading axis
    ddf_scalers = np.linspace(0, 1, num=num_interval)

    for img_path in img_paths:
        # likewise the image only needs to be read once per path
        image = np.expand_dims(load_nifti_file(img_path), axis=0)
        img_name = os.path.split(img_path)[-1].split(".")[0]

        for slice_ind in slice_inds:
            fig = plt.figure()
            try:
                # a single axes filling the whole figure, without ticks or frame
                ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
                ax.set_axis_off()
                fig.add_axes(ax)

                frames = []
                for ddf_scaler in ddf_scalers:
                    warped_image = Warping(fixed_image_size=fixed_image_shape)(
                        [ddf * ddf_scaler, image]
                    )
                    warped_image = np.squeeze(warped_image.numpy())

                    frame = plt.imshow(
                        warped_image[:, :, slice_ind], aspect="auto", animated=True
                    )
                    frames.append([frame])

                anim = animation.ArtistAnimation(fig, frames, interval=interval)
                path_to_anim_save = os.path.join(
                    save_path, img_name + "_slice_" + str(slice_ind) + ".gif"
                )

                anim.save(path_to_anim_save)
                logger.info("Animation saved to: %s", path_to_anim_save)
            finally:
                # release the figure; pyplot otherwise keeps every figure alive
                plt.close(fig)
181
182
183
def gif_tile_slices(img_paths, save_path=None, size=(2, 2), fname=None, interval=50):
    """
    Create a tiled gif over slices of multiple images.

    For every slice index the corresponding slices of all images are placed
    side by side in one frame. Images fill the grid row by row, in the
    order they are given.

    :param img_paths: list or comma separated string of image paths
    :param save_path: path to directory where visualisation/s is/are to be saved,
        the current directory if None
    :param size: tuple of two ints giving the grid of tiles, e.g. (2, 2);
        ``size[0]`` tiles are stacked along the first image axis and
        ``size[1]`` tiles along the second one
    :param fname: filename to save visualisation to,
        "visualisation.gif" if None
    :param interval: time in milliseconds between frames of gif
    :raises ValueError: if the number of images does not match ``size``
        or if the images do not all have the same shape
    """
    if isinstance(img_paths, str):
        img_paths = string_to_list(img_paths)
    if save_path is None:
        # os.path.join does not accept None; an empty string means "here"
        save_path = ""
    if fname is None:
        fname = "visualisation.gif"

    num_images = np.prod(size)
    if int(len(img_paths)) != int(num_images):
        raise ValueError(
            "The number of images supplied is "
            + str(len(img_paths))
            + " whereas the number required is "
            + str(num_images)
            + " as size specified is "
            + str(size)
        )

    imgs = [load_nifti_file(img_path) for img_path in img_paths]
    img_shape = np.shape(imgs[0])
    if any(np.shape(img) != img_shape for img in imgs):
        raise ValueError("all images do not have equal shapes")

    tile_rows, tile_cols = img_shape[0], img_shape[1]

    fig = plt.figure()
    try:
        # a single axes filling the whole figure, without ticks or frame
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)

        frames = []
        for index in range(img_shape[-1]):
            canvas = np.ones((size[0] * tile_rows, size[1] * tile_cols))

            for tile_num, img in enumerate(imgs):
                # row-major placement of the tiles in the grid
                i, j = divmod(tile_num, size[1])
                # Rows are offset by the tile height and columns by the tile
                # width, so non-square slices land in the right place.
                canvas[
                    i * tile_rows : (i + 1) * tile_rows,
                    j * tile_cols : (j + 1) * tile_cols,
                ] = img[:, :, index]

            frames.append([plt.imshow(canvas, aspect="auto", animated=True)])

        anim = animation.ArtistAnimation(fig, frames, interval=interval)
        path_to_anim_save = os.path.join(save_path, fname)

        anim.save(path_to_anim_save)
        logger.info("Animation saved to: %s", path_to_anim_save)
    finally:
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(fig)
259
260
261
def main(args=None):
    """
    CLI for deepreg_vis tool.

    Requires ffmpeg writer to write gif files.

    :param args: list of command line argument strings,
        if None the arguments are read from ``sys.argv`` by argparse
    :raises ValueError: if ``--mode 1`` is used without ``--ddf-path``
    """
    parser = argparse.ArgumentParser(
        description="deepreg_vis", formatter_class=argparse.RawTextHelpFormatter
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Mode of visualisation \n"
        "0 for animation over image slices, \n"
        "1 for warp animation, \n"
        "2 for tile plot, \n"
        "3 for tiled animation over image slices",
        type=int,
        # reject unknown modes up front instead of silently doing nothing
        choices=[0, 1, 2, 3],
        required=True,
    )
    parser.add_argument(
        "--image-paths",
        "-i",
        help="File path for image file "
        "(can specify multiple paths using a comma separated string)",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--save-path",
        "-s",
        help="Path to directory where resulting visualisation is saved",
        default="",
    )

    parser.add_argument(
        "--interval",
        help="Interval between frames of animation (in miliseconds)\n"
        "Applicable only if --mode 0 or --mode 1 or --mode 3",
        type=int,
        default=50,
    )
    parser.add_argument(
        "--ddf-path",
        help="Path to ddf used for warping images\n"
        "Applicable only and required if --mode 1",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--num-interval",
        help="Number of intervals to use for warping\n" "Applicable only if --mode 1",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--slice-inds",
        help="Comma separated string of indexes of slices"
        " to be used for the visualisation\n"
        "Applicable only if --mode 1 or --mode 2",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--fname",
        help="File name (with extension like .png, .jpeg, .gif, ...)"
        " to save visualisation to\n"
        "Applicable only if --mode 2 or --mode 3",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--col-titles",
        help="Comma separated string of column titles to use "
        "(inferred from file names if not provided)\n"
        "Applicable only if --mode 2",
        default=None,
    )
    parser.add_argument(
        "--size",
        help="Comma separated string of number of columns and rows (e.g. '2,2')\n"
        "Applicable only if --mode 3",
        default="2,2",
    )

    # init arguments
    args = parser.parse_args(args)

    if args.slice_inds is not None:
        args.slice_inds = [int(elem) for elem in string_to_list(args.slice_inds)]

    if args.col_titles is not None:
        # one title per column; a raw string would be indexed per character
        args.col_titles = string_to_list(args.col_titles)

    if args.mode == 0:
        gif_slices(
            img_paths=args.image_paths, save_path=args.save_path, interval=args.interval
        )
    elif args.mode == 1:
        if args.ddf_path is None:
            raise ValueError("--ddf-path is required when using --mode 1")
        gif_warp(
            img_paths=args.image_paths,
            ddf_path=args.ddf_path,
            slice_inds=args.slice_inds,
            num_interval=args.num_interval,
            interval=args.interval,
            save_path=args.save_path,
        )
    elif args.mode == 2:
        tile_slices(
            img_paths=args.image_paths,
            save_path=args.save_path,
            fname=args.fname,
            slice_inds=args.slice_inds,
            col_titles=args.col_titles,
        )
    elif args.mode == 3:
        size = tuple(int(elem) for elem in string_to_list(args.size))
        gif_tile_slices(
            img_paths=args.image_paths,
            save_path=args.save_path,
            fname=args.fname,
            interval=args.interval,
            size=size,
        )
387
388
389
# Run the command line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()  # pragma: no cover
391