Passed
Pull Request — master (#246)
by Fernando
01:07
created

tests.data.test_image.TestImage.test_nans_tensor()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
#!/usr/bin/env python
2
3
"""Tests for Image."""
4
5
import copy
6
import torch
7
import numpy as np
8
from torchio import INTENSITY, LABEL, Image, ScalarImage, LabelMap, Subject
9
from ..utils import TorchioTestCase
10
from torchio import RandomFlip, RandomAffine
11
12
13
class TestImage(TorchioTestCase):
    """Tests for `Image`."""

    def test_image_not_found(self):
        """A path that does not exist on disk raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            Image('nopath', type=INTENSITY)

    def test_wrong_path_type(self):
        """A path argument that is not path-like raises TypeError."""
        with self.assertRaises(TypeError):
            Image(5, type=INTENSITY)

    def test_wrong_affine(self):
        """A non-array affine raises TypeError."""
        # Use a valid tensor so that the TypeError can only come from the
        # affine argument. Passing ``5`` as the path (as before) already
        # raised TypeError on its own, duplicating test_wrong_path_type and
        # never exercising the affine check.
        with self.assertRaises(TypeError):
            Image(tensor=torch.rand(1, 2, 3, 4), type=INTENSITY, affine=1)

    def test_tensor_flip(self):
        """RandomFlip accepts a raw 4D tensor (smoke test)."""
        sample_input = torch.ones((4, 30, 30, 30))
        RandomFlip()(sample_input)

    def test_tensor_affine(self):
        """RandomAffine accepts a raw 4D tensor (smoke test)."""
        sample_input = torch.ones((4, 10, 10, 10))
        RandomAffine()(sample_input)

    def test_crop_attributes(self):
        """Cropping keeps the same object for extra image attributes."""
        cropped = self.sample.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(self.sample.t1['pre_affine'], cropped.t1['pre_affine'])

    def test_crop_does_not_create_wrong_path(self):
        """Cropping a tensor-backed image does not invent a path."""
        data = torch.ones((10, 10, 10))
        image = Image(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertIsNone(cropped.path)

    def test_scalar_image_type(self):
        """ScalarImage defaults to the INTENSITY type."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        # Compare by value: the type constants are strings, and identity
        # checks on strings depend on interning.
        self.assertEqual(image.type, INTENSITY)

    def test_label_map_type(self):
        """LabelMap defaults to the LABEL type."""
        data = torch.ones((10, 10, 10))
        label = LabelMap(tensor=data)
        self.assertEqual(label.type, LABEL)

    def test_wrong_scalar_image_type(self):
        """ScalarImage rejects an explicit LABEL type."""
        data = torch.ones((10, 10, 10))
        with self.assertRaises(ValueError):
            ScalarImage(tensor=data, type=LABEL)

    def test_wrong_label_map_type(self):
        """LabelMap rejects an explicit INTENSITY type."""
        data = torch.ones((10, 10, 10))
        with self.assertRaises(ValueError):
            LabelMap(tensor=data, type=INTENSITY)

    def test_crop_scalar_image_type(self):
        """Cropping a ScalarImage preserves the INTENSITY type."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertEqual(cropped.type, INTENSITY)

    def test_crop_label_map_type(self):
        """Cropping a LabelMap preserves the LABEL type."""
        data = torch.ones((10, 10, 10))
        label = LabelMap(tensor=data)
        cropped = label.crop((1, 1, 1), (5, 5, 5))
        self.assertEqual(cropped.type, LABEL)

    def test_no_input(self):
        """Constructing an Image with neither path nor tensor fails."""
        with self.assertRaises(ValueError):
            Image()

    def test_bad_key(self):
        """Passing a reserved keyword ('data') as an attribute fails."""
        with self.assertRaises(ValueError):
            Image(path='', data=5)

    def test_repr(self):
        """The repr only reports the shape once the image is loaded."""
        sample = Subject(t1=ScalarImage(self.get_image_path('repr_test')))
        self.assertNotIn('shape', repr(sample['t1']))
        sample.load()
        self.assertIn('shape', repr(sample['t1']))

    def test_data_tensor(self):
        """`data` and `tensor` refer to the same object after loading."""
        sample = copy.deepcopy(self.sample)
        sample.load()
        self.assertIs(sample.t1.data, sample.t1.tensor)

    def test_bad_affine(self):
        """An affine array that is not 4x4 raises ValueError."""
        with self.assertRaises(ValueError):
            Image(tensor=torch.rand(1, 2, 3, 4), affine=np.eye(3))

    def test_nans_tensor(self):
        """A tensor containing NaNs triggers a UserWarning on construction."""
        tensor = np.random.rand(1, 2, 3, 4)
        tensor[0, 0, 0, 0] = np.nan
        with self.assertWarns(UserWarning):
            image = Image(tensor=tensor)
        # NOTE(review): this only exercises the setter; it does not assert
        # any effect of disabling the NaN check — confirm intent.
        image.set_check_nans(False)

    def test_nans_file(self):
        """An image file containing NaNs warns when it is loaded."""
        image = Image(self.get_image_path('repr_test', add_nans=True))
        with self.assertWarns(UserWarning):
            image._load()

    def test_get_center(self):
        """The center of a 3x3x3 image is (1, 1, 1) in RAS, (-1, -1, 1) in LPS."""
        tensor = torch.rand(1, 3, 3, 3)
        image = Image(tensor=tensor)
        ras = image.get_center()
        lps = image.get_center(lps=True)
        self.assertEqual(ras, (1, 1, 1))
        self.assertEqual(lps, (-1, -1, 1))
120