Passed
Pull Request — master (#332)
by Fernando
01:14
created

tests.data.test_subject   A

Complexity

Total Complexity 14

Size/Duplication

Total Lines 59
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 44
dl 0
loc 59
rs 10
c 0
b 0
f 0
wmc 14

7 Methods

Rating   Name   Duplication   Size   Complexity  
A TestSubject.test_history() 0 3 1
A TestSubject.test_input_dict() 0 5 2
A TestSubject.test_no_sample() 0 6 3
A TestSubject.test_positional_args() 0 4 3
A TestSubject.test_inconsistent_spatial_shape() 0 7 2
A TestSubject.test_inconsistent_shape() 0 8 2
A TestSubject.test_plot() 0 7 1
1
#!/usr/bin/env python
2
3
"""Tests for Subject."""
4
5
import tempfile
6
import torch
7
from torchio import Subject, ScalarImage, RandomFlip
8
from ..utils import TorchioTestCase
9
10
11
class TestSubject(TorchioTestCase):
    """Tests for `Subject`."""

    def test_positional_args(self):
        """Passing an image positionally to ``Subject`` raises ValueError."""
        with tempfile.NamedTemporaryFile() as f:
            image = ScalarImage(f.name)
            # Keep the assertion scoped to the Subject call only, so a
            # ValueError raised while setting up the image cannot make this
            # test pass by accident.
            with self.assertRaises(ValueError):
                Subject(image)

    def test_input_dict(self):
        """A subject can be built from a dict or from keyword arguments."""
        with tempfile.NamedTemporaryFile() as f:
            input_dict = {'image': ScalarImage(f.name)}
            Subject(input_dict)
            Subject(**input_dict)

    def test_no_sample(self):
        """Transforming a subject backed by an empty temp file fails.

        Nothing is written to the temporary file before the transform is
        applied, and the transform is expected to raise RuntimeError.
        """
        with tempfile.NamedTemporaryFile() as f:
            input_dict = {'image': ScalarImage(f.name)}
            subject = Subject(input_dict)
            with self.assertRaises(RuntimeError):
                RandomFlip()(subject)

    def test_history(self):
        """After one transform, the subject history has exactly one entry."""
        transformed = RandomFlip()(self.sample)
        # Compare integers by value; `assertIs` only happens to work here
        # because CPython caches small integers.
        self.assertEqual(len(transformed.history), 1)

    def test_inconsistent_shape(self):
        """Differing channel counts break ``shape``, not ``spatial_shape``."""
        subject = Subject(
            a=ScalarImage(tensor=torch.rand(1, 2, 3, 4)),
            b=ScalarImage(tensor=torch.rand(2, 2, 3, 4)),
        )
        # The non-channel dimensions agree, so this must succeed; check the
        # value rather than evaluating the property as a bare expression.
        self.assertEqual(tuple(subject.spatial_shape), (2, 3, 4))
        with self.assertRaises(RuntimeError):
            _ = subject.shape

    def test_inconsistent_spatial_shape(self):
        """Differing spatial dimensions make ``spatial_shape`` raise."""
        subject = Subject(
            a=ScalarImage(tensor=torch.rand(1, 3, 3, 4)),
            b=ScalarImage(tensor=torch.rand(2, 2, 3, 4)),
        )
        with self.assertRaises(RuntimeError):
            _ = subject.spatial_shape

    def test_plot(self):
        """Plotting with per-image colormaps and an output path runs."""
        self.sample.plot(
            show=False,
            output_path=self.dir / 'figure.png',
            cmap_dict=dict(
                t2='viridis',
                label={0: 'yellow', 1: 'blue'},
            ),
        )