Passed
Push — master ( 7bf0dc...387cc1 )
by Fernando
01:06
created

TestSubjectsDataset.iterate_dataset()   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
#!/usr/bin/env python
2
3
"""Tests for SubjectsDataset."""
4
5
import nibabel as nib
6
import torchio
7
from torchio import DATA, SubjectsDataset, ImagesDataset
8
from ..utils import TorchioTestCase
9
10
11
class TestSubjectsDataset(TorchioTestCase):
    """Unit tests covering construction, iteration and saving for `SubjectsDataset`."""

    def _assert_iteration_raises(self, error_class, subjects):
        """Assert that building and iterating a dataset from ``subjects`` raises ``error_class``."""
        with self.assertRaises(error_class):
            self.iterate_dataset(subjects)

    def test_images(self):
        """A valid list of subjects can be iterated without errors."""
        self.iterate_dataset(self.subjects_list)

    def test_empty_subjects_list(self):
        """An empty list of subjects is rejected with ValueError."""
        self._assert_iteration_raises(ValueError, [])

    def test_empty_subjects_tuple(self):
        """An empty tuple of subjects is rejected with ValueError."""
        self._assert_iteration_raises(ValueError, ())

    def test_wrong_subjects_type(self):
        """A non-sequence subjects argument is rejected with TypeError."""
        self._assert_iteration_raises(TypeError, 0)

    def test_wrong_subject_type_int(self):
        """A sequence holding an int instead of a subject raises TypeError."""
        self._assert_iteration_raises(TypeError, [0])

    def test_wrong_subject_type_dict(self):
        """A sequence holding a plain dict instead of a subject raises TypeError."""
        self._assert_iteration_raises(TypeError, [{}])

    def test_wrong_index(self):
        """Indexing the dataset with a slice raises ValueError."""
        with self.assertRaises(ValueError):
            _ = self.dataset[:3]

    def test_save_sample(self):
        """Saving a sample warns about deprecation and writes a NIfTI file.

        The written volume is expected to have one dimension fewer than the
        in-memory image (the sample's ``shape`` has one extra leading axis).
        """
        subjects_dataset = SubjectsDataset(
            self.subjects_list,
            transform=lambda subject: subject,
        )
        len(subjects_dataset)  # exercise __len__ for coverage
        first_sample = subjects_dataset[0]
        nifti_path = self.dir / 'test.nii.gz'
        with self.assertWarns(DeprecationWarning):
            subjects_dataset.save_sample(first_sample, {'t1': nifti_path})
        saved_image = nib.load(str(nifti_path))
        saved_ndim = len(saved_image.shape)
        sample_ndim = len(first_sample['t1'].shape)
        assert sample_ndim == saved_ndim + 1

    def test_wrong_transform_init(self):
        """A non-callable transform passed to the constructor raises ValueError."""
        with self.assertRaises(ValueError):
            SubjectsDataset(self.subjects_list, transform={})

    def test_wrong_transform_arg(self):
        """A non-callable transform passed to ``set_transform`` raises ValueError."""
        with self.assertRaises(ValueError):
            self.dataset.set_transform(1)

    @staticmethod
    def iterate_dataset(subjects_list):
        """Construct a `SubjectsDataset` and consume every item once."""
        for _subject in SubjectsDataset(subjects_list):
            continue

    def test_data_loader(self):
        """Batches from a DataLoader expose the data of each named image."""
        from torch.utils.data import DataLoader
        loader = DataLoader(
            SubjectsDataset([torchio.datasets.Colin27()]),
            batch_size=1,
            shuffle=True,
        )
        for batch in loader:
            # Indexing is the check: a missing key raises and fails the test.
            for image_name in ('t1', 'brain'):
                batch[image_name][DATA]

    def test_save_deprecated(self):
        """``save_sample`` emits a DeprecationWarning."""
        with self.assertWarns(DeprecationWarning):
            self.dataset.save_sample(self.sample, {})

    def test_images_dataset_deprecated(self):
        """Instantiating the legacy ``ImagesDataset`` emits a DeprecationWarning."""
        with self.assertWarns(DeprecationWarning):
            ImagesDataset(self.subjects_list)
88