Completed
Push — main ( 0c57ec...f6b5bf )
by Yunguan
18s queued 13s
created

test.unit.test_interface.test_file_loader()   C

Complexity

Conditions 11

Size

Total Lines 53
Code Lines 36

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 36
dl 0
loc 53
rs 5.4
c 0
b 0
f 0
cc 11
nop 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Complexity

Complex classes like test.unit.test_interface.test_file_loader() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# coding=utf-8
2
3
"""
4
Tests for deepreg/dataset/loader/interface.py
5
"""
6
from test.unit.util import is_equal_np
7
8
import numpy as np
9
import pytest
10
11
from deepreg.dataset.loader.interface import (
12
    AbstractPairedDataLoader,
13
    AbstractUnpairedDataLoader,
14
    DataLoader,
15
    FileLoader,
16
    GeneratorDataLoader,
17
)
18
from deepreg.dataset.loader.nifti_loader import NiftiFileLoader
19
from deepreg.dataset.loader.paired_loader import PairedDataLoader
20
from deepreg.dataset.loader.util import normalize_array
21
22
23
class TestDataLoader:
0 ignored issues
show
introduced by
Missing class docstring
Loading history...
24
    @pytest.mark.parametrize(
25
        "labeled,num_indices,sample_label,seed",
26
        [
27
            (True, 1, "all", 0),
28
            (False, 1, "all", 0),
29
            (None, 1, "all", 0),
30
            (True, 1, "sample", 0),
31
            (True, 1, "all", 0),
32
            (True, 1, None, 0),
33
            (True, 1, "sample", None),
34
        ],
0 ignored issues
show
introduced by
"labeled, num_indices, sample_label, seed" missing in parameter type documentation
Loading history...
introduced by
Redundant returns documentation
Loading history...
35
    )
36
    def test_init(self, labeled, num_indices, sample_label, seed):
37
        """
38
        Test init function of DataLoader class
39
        :param labeled: bool
40
        :param num_indices: int
41
        :param sample_label: str
42
        :param seed: float/int/None
43
        :return:
44
        """
45
        DataLoader(
46
            labeled=labeled,
47
            num_indices=num_indices,
48
            sample_label=sample_label,
49
            seed=seed,
50
        )
51
52
        data_loader = DataLoader(
53
            labeled=labeled,
54
            num_indices=num_indices,
55
            sample_label=sample_label,
56
            seed=seed,
57
        )
58
59
        with pytest.raises(NotImplementedError):
60
            data_loader.moving_image_shape
0 ignored issues
show
Unused Code introduced by
This statement seems to have no effect and could be removed.

This issue is typically triggered when a function that does not have side-effects is called and the return value is discarded:

class SomeClass:
    def __init__(self):
        self._x = 5

    def squared(self):
        return self._x * self._x

some_class = SomeClass()
some_class.squared()        # Flagged, as the return value is not used
print(some_class.squared()) # Ok
Loading history...
61
        with pytest.raises(NotImplementedError):
62
            data_loader.fixed_image_shape
0 ignored issues
show
Unused Code introduced by
This statement seems to have no effect and could be removed.

This issue is typically triggered when a function that does not have side-effects is called and the return value is discarded:

class SomeClass:
    def __init__(self):
        self._x = 5

    def squared(self):
        return self._x * self._x

some_class = SomeClass()
some_class.squared()        # Flagged, as the return value is not used
print(some_class.squared()) # Ok
Loading history...
63
        with pytest.raises(NotImplementedError):
64
            data_loader.num_samples
0 ignored issues
show
Unused Code introduced by
This statement seems to have no effect and could be removed.

This issue is typically triggered when a function that does not have side-effects is called and the return value is discarded:

class SomeClass:
    def __init__(self):
        self._x = 5

    def squared(self):
        return self._x * self._x

some_class = SomeClass()
some_class.squared()        # Flagged, as the return value is not used
print(some_class.squared()) # Ok
Loading history...
65
        with pytest.raises(NotImplementedError):
66
            data_loader.get_dataset()
67
68
        data_loader.close()
69
70
    @pytest.mark.parametrize(
71
        "labeled,moving_shape,fixed_shape,batch_size,data_augmentation",
72
        [
73
            (True, (9, 9, 9), (9, 9, 9), 1, {}),
74
            (
75
                True,
76
                (9, 9, 9),
77
                (15, 15, 15),
78
                1,
79
                {"data_augmentation": {"name": "affine"}},
80
            ),
81
            (
82
                True,
83
                (9, 9, 9),
84
                (15, 15, 15),
85
                1,
86
                {
87
                    "data_augmentation": [
88
                        {"name": "affine"},
89
                        {
90
                            "name": "ddf",
91
                            "field_strength": 1,
92
                            "low_res_size": (3, 3, 3),
93
                        },
0 ignored issues
show
introduced by
"batch_size, data_augmentation, fixed_shape, labeled, moving_shape" missing in parameter type documentation
Loading history...
introduced by
Redundant returns documentation
Loading history...
94
                    ],
95
                },
96
            ),
97
        ],
98
    )
99
    def test_get_dataset_and_preprocess(
100
        self, labeled, moving_shape, fixed_shape, batch_size, data_augmentation
0 ignored issues
show
Unused Code introduced by
The argument labeled seems to be unused.
Loading history...
101
    ):
102
        """
103
        Test get_transforms() function. For that, an Abstract Data Loader is created
104
        only to set the moving  and fixed shapes that are used in get_transforms().
105
        Here we test that the get_transform() returns a function and the shape of
106
        the output of this function. See test_preprocess.py for more testing regarding
107
        the concrete params.
108
109
        :param labeled: bool
110
        :param moving_shape: tuple
111
        :param fixed_shape: tuple
112
        :param batch_size: int
113
        :param data_augmentation: dict
114
        :return:
115
        """
116
        data_dir_path = [
117
            "data/test/nifti/paired/train",
118
            "data/test/nifti/paired/test",
119
        ]
120
        common_args = dict(
121
            file_loader=NiftiFileLoader, labeled=True, sample_label="all", seed=None
122
        )
123
124
        data_loader = PairedDataLoader(
125
            data_dir_paths=data_dir_path,
126
            fixed_image_shape=fixed_shape,
127
            moving_image_shape=moving_shape,
128
            **common_args,
129
        )
130
131
        dataset = data_loader.get_dataset_and_preprocess(
132
            training=True,
133
            batch_size=batch_size,
134
            repeat=True,
135
            shuffle_buffer_num_batch=1,
136
            **data_augmentation,
137
        )
138
139
        for outputs in dataset.take(1):
140
            assert (
141
                outputs["moving_image"].shape
142
                == (batch_size,) + data_loader.moving_image_shape
143
            )
144
            assert (
145
                outputs["fixed_image"].shape
146
                == (batch_size,) + data_loader.fixed_image_shape
147
            )
148
            assert (
149
                outputs["moving_label"].shape
150
                == (batch_size,) + data_loader.moving_image_shape
151
            )
152
            assert (
153
                outputs["fixed_label"].shape
154
                == (batch_size,) + data_loader.fixed_image_shape
155
            )
156
157
158
def test_abstract_paired_data_loader():
159
    """
160
    Test the functions in AbstractPairedDataLoader
161
    """
162
    moving_image_shape = (8, 8, 4)
163
    fixed_image_shape = (6, 6, 4)
164
165
    # test init invalid shape
166
    with pytest.raises(ValueError) as err_info:
167
        AbstractPairedDataLoader(
168
            moving_image_shape=(2, 2),
169
            fixed_image_shape=(3, 3),
170
            labeled=True,
171
            sample_label="sample",
172
        )
173
    assert "moving_image_shape and fixed_image_shape have length of three" in str(
174
        err_info.value
175
    )
176
177
    # test init valid shapes
178
    data_loader = AbstractPairedDataLoader(
179
        moving_image_shape=moving_image_shape,
180
        fixed_image_shape=fixed_image_shape,
181
        labeled=True,
182
        sample_label="sample",
183
    )
184
185
    # test properties
186
    assert data_loader.num_indices == 2
187
    assert data_loader.moving_image_shape == moving_image_shape
188
    assert data_loader.fixed_image_shape == fixed_image_shape
189
    assert data_loader.num_samples is None
190
191
192
def test_abstract_unpaired_data_loader():
193
    """
194
    Test the functions in AbstractUnpairedDataLoader
195
    """
196
    image_shape = (8, 8, 4)
197
198
    # test init invalid shape
199
    with pytest.raises(ValueError) as err_info:
200
        AbstractUnpairedDataLoader(
201
            image_shape=(2, 2), labeled=True, sample_label="sample"
202
        )
203
    assert "image_shape has to be length of three" in str(err_info.value)
204
205
    # test init valid shapes
206
    data_loader = AbstractUnpairedDataLoader(
207
        image_shape=image_shape, labeled=True, sample_label="sample"
208
    )
209
210
    # test properties
211
    assert data_loader.num_indices == 3
212
    assert data_loader.moving_image_shape == image_shape
213
    assert data_loader.fixed_image_shape == image_shape
214
    assert data_loader.num_samples is None
215
216
217
def test_generator_data_loader(caplog):
0 ignored issues
show
introduced by
"caplog" missing in parameter type documentation
Loading history...
218
    """
219
    Test the functions in GeneratorDataLoader
220
    :param caplog: used to check warning message.
221
    """
222
    generator = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
223
224
    # test properties
225
    assert generator.loader_moving_image is None
226
    assert generator.loader_moving_image is None
227
    assert generator.loader_moving_image is None
228
    assert generator.loader_moving_image is None
229
230
    # not implemented properties / functions
231
    with pytest.raises(NotImplementedError):
232
        generator.sample_index_generator()
233
234
    # implemented functions
235
    # test get_Dataset
236
    dummy_array = np.random.random(size=(100, 100, 100)).astype(np.float32)
237
    # for unlabeled data
238
    # mock generator
239
    sequence = [
240
        dict(
241
            moving_image=dummy_array,
242
            fixed_image=dummy_array,
243
            moving_label=dummy_array,
244
            fixed_label=dummy_array,
245
            indices=[1],
246
        )
247
        for i in range(3)
248
    ]
249
250
    def mock_generator():
251
        for el in sequence:
252
            yield el
253
254
    # inputs, no error means passed
255
    generator.data_generator = mock_generator
256
    dataset = generator.get_dataset()
257
258
    # check dataset output
259
    expected = dict(
260
        moving_image=dummy_array,
261
        fixed_image=dummy_array,
262
        moving_label=dummy_array,
263
        fixed_label=dummy_array,
264
        indices=[1],
265
    )
266
    for got in list(dataset.as_numpy_iterator()):
267
        assert all(is_equal_np(got[key], expected[key]) for key in expected.keys())
0 ignored issues
show
unused-code introduced by
Consider iterating the dictionary directly instead of calling .keys()
Loading history...
268
269
    # for unlabeled data
270
    generator_unlabeled = GeneratorDataLoader(
271
        labeled=False, num_indices=1, sample_label="all"
272
    )
273
274
    sequence = [
275
        dict(moving_image=dummy_array, fixed_image=dummy_array, indices=[1])
276
        for i in range(3)
277
    ]
278
279
    # inputs, no error means passed
280
    generator_unlabeled.data_generator = mock_generator
281
    dataset = generator_unlabeled.get_dataset()
282
283
    # check dataset output
284
    expected = dict(moving_image=dummy_array, fixed_image=dummy_array, indices=[1])
285
    for got in list(dataset.as_numpy_iterator()):
286
        assert all(is_equal_np(got[key], expected[key]) for key in expected.keys())
0 ignored issues
show
unused-code introduced by
Consider iterating the dictionary directly instead of calling .keys()
Loading history...
287
288
    # test data_generator
289
    # create mock data loader and sample index generator
290
    class MockDataLoader:
0 ignored issues
show
introduced by
Missing class docstring
Loading history...
291
        def __init__(self, **kwargs):
0 ignored issues
show
introduced by
Useless super delegation in method '__init__'
Loading history...
292
            super().__init__(**kwargs)
293
294
        def get_data(index):
0 ignored issues
show
introduced by
Missing function or method docstring
Loading history...
Coding Style Best Practice introduced by
Methods should have self as first argument.

It is a widespread convention and generally a good practice to name the first argument of methods self.

class SomeClass:
    def some_method(self):
        # ... do something
Loading history...
295
            return dummy_array
296
297
    def mock_sample_index_generator():
298
        return [[[1], [1], [1]]]
299
300
    generator = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
301
    generator.sample_index_generator = mock_sample_index_generator
302
    generator.loader_moving_image = MockDataLoader
303
    generator.loader_fixed_image = MockDataLoader
304
    generator.loader_moving_label = MockDataLoader
305
    generator.loader_fixed_label = MockDataLoader
306
307
    # check data generator output
308
    got = next(generator.data_generator())
309
310
    expected = dict(
311
        moving_image=normalize_array(dummy_array),
312
        fixed_image=normalize_array(dummy_array),
313
        moving_label=dummy_array,
314
        fixed_label=dummy_array,
315
        indices=np.asarray([1] + [0], dtype=np.float32),
316
    )
317
    assert all(is_equal_np(got[key], expected[key]) for key in expected.keys())
0 ignored issues
show
unused-code introduced by
Consider iterating the dictionary directly instead of calling .keys()
Loading history...
318
319
    # test validate_images_and_labels
320
    with pytest.raises(ValueError) as err_info:
321
        generator.validate_images_and_labels(
322
            fixed_image=None,
323
            moving_image=dummy_array,
324
            moving_label=None,
325
            fixed_label=None,
326
            image_indices=[1],
327
        )
328
    assert "moving image and fixed image must not be None" in str(err_info.value)
329
    with pytest.raises(ValueError) as err_info:
330
        generator.validate_images_and_labels(
331
            fixed_image=dummy_array,
332
            moving_image=dummy_array,
333
            moving_label=dummy_array,
334
            fixed_label=None,
335
            image_indices=[1],
336
        )
337
    assert "moving label and fixed label must be both None or non-None" in str(
338
        err_info.value
339
    )
340
    with pytest.raises(ValueError) as err_info:
341
        generator.validate_images_and_labels(
342
            fixed_image=dummy_array,
343
            moving_image=dummy_array + 1.0,
344
            moving_label=None,
345
            fixed_label=None,
346
            image_indices=[1],
347
        )
348
    assert "Sample [1]'s moving_image's values are not between [0, 1]" in str(
349
        err_info.value
350
    )
351
    with pytest.raises(ValueError) as err_info:
352
        generator.validate_images_and_labels(
353
            fixed_image=dummy_array,
354
            moving_image=np.random.random(size=(100, 100)),
355
            moving_label=None,
356
            fixed_label=None,
357
            image_indices=[1],
358
        )
359
    assert "Sample [1]'s moving_image' shape should be 3D. " in str(err_info.value)
360
    with pytest.raises(ValueError) as err_info:
361
        generator.validate_images_and_labels(
362
            fixed_image=dummy_array,
363
            moving_image=dummy_array,
364
            moving_label=np.random.random(size=(100, 100)),
365
            fixed_label=dummy_array,
366
            image_indices=[1],
367
        )
368
    assert "Sample [1]'s moving_label' shape should be 3D or 4D. " in str(
369
        err_info.value
370
    )
371
    with pytest.raises(ValueError) as err_info:
372
        generator.validate_images_and_labels(
373
            fixed_image=dummy_array,
374
            moving_image=dummy_array,
375
            moving_label=np.random.random(size=(100, 100, 100, 3)),
376
            fixed_label=np.random.random(size=(100, 100, 100, 4)),
377
            image_indices=[1],
378
        )
379
    assert (
380
        "Sample [1]'s moving image and fixed image have different numbers of labels."
381
        in str(err_info.value)
382
    )
383
384
    # warning
385
    caplog.clear()  # clear previous log
386
    generator.validate_images_and_labels(
387
        fixed_image=dummy_array,
388
        moving_image=dummy_array,
389
        moving_label=np.random.random(size=(100, 100, 90)),
390
        fixed_label=dummy_array,
391
        image_indices=[1],
392
    )
393
    assert "Sample [1]'s moving image and label have different shapes. " in caplog.text
394
    caplog.clear()  # clear previous log
395
    generator.validate_images_and_labels(
396
        fixed_image=dummy_array,
397
        moving_image=dummy_array,
398
        moving_label=dummy_array,
399
        fixed_label=np.random.random(size=(100, 100, 90)),
400
        image_indices=[1],
401
    )
402
    assert "Sample [1]'s fixed image and label have different shapes. " in caplog.text
403
404
    # test sample_image_label method
405
    # for unlabeled input data
406
    got = next(
407
        generator.sample_image_label(
408
            fixed_image=dummy_array,
409
            moving_image=dummy_array,
410
            moving_label=None,
411
            fixed_label=None,
412
            image_indices=[1],
413
        )
414
    )
415
    expected = dict(
416
        moving_image=dummy_array,
417
        fixed_image=dummy_array,
418
        indices=np.asarray([1] + [-1], dtype=np.float32),
419
    )
420
    assert all(is_equal_np(got[key], expected[key]) for key in expected.keys())
0 ignored issues
show
unused-code introduced by
Consider iterating the dictionary directly instead of calling .keys()
Loading history...
421
422
    # for data with one label
423
    got = next(
424
        generator.sample_image_label(
425
            fixed_image=dummy_array,
426
            moving_image=dummy_array,
427
            moving_label=dummy_array,
428
            fixed_label=dummy_array,
429
            image_indices=[1],
430
        )
431
    )
432
    expected = dict(
433
        moving_image=dummy_array,
434
        fixed_image=dummy_array,
435
        moving_label=dummy_array,
436
        fixed_label=dummy_array,
437
        indices=np.asarray([1] + [0], dtype=np.float32),
438
    )
439
    assert all(is_equal_np(got[key], expected[key]) for key in expected.keys())
0 ignored issues
show
unused-code introduced by
Consider iterating the dictionary directly instead of calling .keys()
Loading history...
440
441
    # for data with multiple labels
442
    dummy_labels = np.random.random(size=(100, 100, 100, 3))
443
    got = generator.sample_image_label(
444
        fixed_image=dummy_array,
445
        moving_image=dummy_array,
446
        moving_label=dummy_labels,
447
        fixed_label=dummy_labels,
448
        image_indices=[1],
449
    )
450
    for label_index in range(dummy_labels.shape[3]):
451
        got_iter = next(got)
452
        expected = dict(
453
            moving_image=dummy_array,
454
            fixed_image=dummy_array,
455
            moving_label=dummy_labels[..., label_index],
456
            fixed_label=dummy_labels[..., label_index],
457
            indices=np.asarray([1] + [label_index], dtype=np.float32),
458
        )
459
        assert all(is_equal_np(got_iter[key], expected[key]) for key in expected.keys())
0 ignored issues
show
unused-code introduced by
Consider iterating the dictionary directly instead of calling .keys()
Loading history...
460
461
462
def test_file_loader():
463
    """
464
    Test the functions in FileLoader
465
    """
466
    # init, no error means passed
467
    loader_grouped = FileLoader(
468
        dir_paths=["/path/grouped_loader/"], name="grouped_loader", grouped=True
469
    )
470
    loader_ungrouped = FileLoader(
471
        dir_paths=["/path/ungrouped_loader/"], name="ungrouped_loader", grouped=False
472
    )
473
474
    # init fails with repeated paths
475
    with pytest.raises(ValueError) as err_info:
476
        FileLoader(
477
            dir_paths=["/path/ungrouped_loader/", "/path/ungrouped_loader/"],
478
            name="ungrouped_loader",
479
            grouped=False,
480
        )
481
    assert "dir_paths have repeated elements" in str(err_info.value)
482
483
    # not implemented properties / functions
484
    with pytest.raises(NotImplementedError):
485
        loader_grouped.set_data_structure()
486
    with pytest.raises(NotImplementedError):
487
        loader_grouped.set_group_structure()
488
    with pytest.raises(NotImplementedError):
489
        loader_grouped.get_data(1)
490
    with pytest.raises(NotImplementedError):
491
        loader_grouped.get_data_ids()
492
    with pytest.raises(NotImplementedError):
493
        loader_grouped.get_num_images()
494
    with pytest.raises(NotImplementedError):
495
        loader_grouped.close()
496
497
    # test grouped file loader functions
498
    assert loader_grouped.group_struct is None
499
500
    # create mock group structure with nested list
501
    loader_grouped.group_struct = [[1, 2], [3, 4], [5, 6]]
502
    assert loader_grouped.get_num_groups() == 3
503
    assert loader_grouped.get_num_images_per_group() == [2, 2, 2]
504
    with pytest.raises(ValueError) as err_info:
505
        loader_grouped.group_struct = [[], [3, 4], [5, 6]]
506
        loader_grouped.get_num_images_per_group()
507
    assert "Groups of ID [0, 2, 2] are empty." in str(err_info.value)
508
509
    # test ungrouped file loader
510
    assert loader_ungrouped.group_struct is None
511
    with pytest.raises(AssertionError):
512
        loader_ungrouped.get_num_groups()
513
    with pytest.raises(AssertionError):
514
        loader_ungrouped.get_num_images_per_group()
515