Passed
Pull Request — master (#670)
by
unknown
03:32
created

tests.data.test_subjects_dataset   A

Complexity

Total Complexity 22

Size/Duplication

Total Lines 65
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 22
eloc 48
dl 0
loc 65
rs 10
c 0
b 0
f 0

12 Methods

Rating   Name   Duplication   Size   Complexity  
A TestSubjectsDataset.test_wrong_subjects_type() 0 3 2
A TestSubjectsDataset.test_empty_subjects_list() 0 3 2
A TestSubjectsDataset.test_wrong_index() 0 3 2
A TestSubjectsDataset.test_images() 0 2 1
A TestSubjectsDataset.test_wrong_subject_type_int() 0 3 2
A TestSubjectsDataset.test_empty_subjects_tuple() 0 3 2
A TestSubjectsDataset.test_wrong_subject_type_dict() 0 3 2
A TestSubjectsDataset.test_data_loader() 0 8 2
A TestSubjectsDataset.test_wrong_transform_arg() 0 3 2
A TestSubjectsDataset.test_indexing_nonint() 0 3 1
A TestSubjectsDataset.test_wrong_transform_init() 0 5 2
A TestSubjectsDataset.iterate_dataset() 0 5 2
1
#!/usr/bin/env python
2
3
from torchio import DATA, SubjectsDataset
4
from ..utils import TorchioTestCase
5
6
7
class TestSubjectsDataset(TorchioTestCase):
8
    
9
    def test_indexing_nonint(self):
        """Index the dataset with a zero-dimensional tensor instead of an int.

        Smoke test: it passes as long as the lookup does not raise.
        """
        # ``torch`` is not imported at module level in this file, so bring it
        # into scope here (the same way ``test_data_loader`` imports
        # ``DataLoader``); without it this test dies with ``NameError``.
        import torch

        dset = SubjectsDataset(self.subjects_list)
        dset[torch.tensor(0)]
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable torch does not seem to be defined.
Loading history...
12
13
    def test_images(self):
        """Iterate over a dataset built from the fixture's subjects list.

        Smoke test: it passes as long as the iteration does not raise.
        """
        fixture_subjects = self.subjects_list
        self.iterate_dataset(fixture_subjects)
15
16
    def test_empty_subjects_list(self):
        """Building/iterating a dataset from ``[]`` must raise ``ValueError``."""
        no_subjects = []
        with self.assertRaises(ValueError):
            self.iterate_dataset(no_subjects)
19
20
    def test_empty_subjects_tuple(self):
        """Building/iterating a dataset from ``()`` must raise ``ValueError``."""
        no_subjects = ()
        with self.assertRaises(ValueError):
            self.iterate_dataset(no_subjects)
23
24
    def test_wrong_subjects_type(self):
        """Passing the int ``0`` as the subjects argument must raise ``TypeError``."""
        not_a_sequence = 0
        with self.assertRaises(TypeError):
            self.iterate_dataset(not_a_sequence)
27
28
    def test_wrong_subject_type_int(self):
        """A subjects list holding an int (``[0]``) must raise ``TypeError``."""
        subjects_with_int = [0]
        with self.assertRaises(TypeError):
            self.iterate_dataset(subjects_with_int)
31
32
    def test_wrong_subject_type_dict(self):
        """A subjects list holding a plain dict (``[{}]``) must raise ``TypeError``."""
        subjects_with_dict = [{}]
        with self.assertRaises(TypeError):
            self.iterate_dataset(subjects_with_dict)
35
36
    def test_wrong_index(self):
        """Indexing the fixture dataset with the slice ``[:3]`` must raise ``ValueError``."""
        first_three = slice(None, 3)  # same object the ``[:3]`` syntax builds
        with self.assertRaises(ValueError):
            self.dataset[first_three]
39
40
    def test_wrong_transform_init(self):
        """Constructing a dataset with ``transform={}`` must raise ``ValueError``."""
        invalid_transform = {}
        with self.assertRaises(ValueError):
            SubjectsDataset(self.subjects_list, transform=invalid_transform)
46
47
    def test_wrong_transform_arg(self):
        """Calling ``set_transform(1)`` on the fixture dataset must raise ``ValueError``."""
        invalid_transform = 1
        with self.assertRaises(ValueError):
            self.dataset.set_transform(invalid_transform)
50
51
    @staticmethod
    def iterate_dataset(subjects_list):
        """Build a ``SubjectsDataset`` from *subjects_list* and exhaust it.

        The yielded items are discarded; the helper exists so that tests can
        trigger whatever construction or iteration raises.
        """
        for _subject in SubjectsDataset(subjects_list):
            continue
56
57
    def test_data_loader(self):
        """Feed a one-subject dataset through a shuffling ``DataLoader``.

        Smoke test: for every batch (``batch_size=1``), the ``DATA`` entry of
        both the ``'t1'`` and ``'label'`` items must be retrievable.
        """
        from torch.utils.data import DataLoader

        single_subject_dataset = SubjectsDataset([self.sample_subject])
        batches = DataLoader(
            single_subject_dataset,
            batch_size=1,
            shuffle=True,
        )
        for collated in batches:
            for entry_name in ('t1', 'label'):
                # Bare lookup on purpose: only checks that the key exists.
                collated[entry_name][DATA]
65