Passed
Push — master ( 53ab14...c2608f )
by Fernando
01:07
created

TestImage.test_wrong_affine()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
#!/usr/bin/env python
2
3
"""Tests for Image."""
4
5
import copy
6
import torch
7
import numpy as np
8
from torchio import ScalarImage, LabelMap, Subject, INTENSITY, LABEL, STEM
9
from ..utils import TorchioTestCase
10
from torchio import RandomFlip, RandomAffine
11
12
13
class TestImage(TorchioTestCase):
    """Tests for `Image`."""

    def test_image_not_found(self):
        """A path that does not exist raises ``FileNotFoundError``."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage('nopath')

    def test_wrong_path_value(self):
        """A malformed path string raises ``RuntimeError``."""
        with self.assertRaises(RuntimeError):
            ScalarImage('~&./@#"!?X7=+')

    def test_wrong_path_type(self):
        """A path that is neither a string nor a path-like raises ``TypeError``."""
        with self.assertRaises(TypeError):
            ScalarImage(5)

    def test_wrong_affine(self):
        """An affine of the wrong type raises ``TypeError``.

        Valid tensor data is supplied so that the affine is the only invalid
        argument. Passing ``5`` as the path would already raise
        ``TypeError`` (see ``test_wrong_path_type``), and the affine check
        would never be exercised.
        """
        with self.assertRaises(TypeError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=1)

    def test_tensor_flip(self):
        """``RandomFlip`` accepts a raw 4D tensor and preserves its shape."""
        sample_input = torch.ones((4, 30, 30, 30))
        transformed = RandomFlip()(sample_input)
        self.assertEqual(tuple(transformed.shape), tuple(sample_input.shape))

    def test_tensor_affine(self):
        """``RandomAffine`` accepts a raw 4D tensor and preserves its shape."""
        sample_input = torch.ones((4, 10, 10, 10))
        transformed = RandomAffine()(sample_input)
        self.assertEqual(tuple(transformed.shape), tuple(sample_input.shape))

    def test_crop_attributes(self):
        """Cropping keeps extra image attributes (the same object, not a copy)."""
        cropped = self.sample.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(self.sample.t1['pre_affine'], cropped.t1['pre_affine'])

    def test_crop_does_not_create_wrong_path(self):
        """Cropping an in-memory image does not invent a path."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertIsNone(cropped.path)

    def test_scalar_image_type(self):
        """A ``ScalarImage`` has the intensity type."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        self.assertEqual(image.type, INTENSITY)

    def test_label_map_type(self):
        """A ``LabelMap`` has the label type."""
        data = torch.ones((10, 10, 10))
        label = LabelMap(tensor=data)
        self.assertEqual(label.type, LABEL)

    def test_wrong_scalar_image_type(self):
        """A ``ScalarImage`` cannot be created with the label type."""
        data = torch.ones((10, 10, 10))
        with self.assertRaises(ValueError):
            ScalarImage(tensor=data, type=LABEL)

    def test_wrong_label_map_type(self):
        """A ``LabelMap`` cannot be created with the intensity type."""
        data = torch.ones((10, 10, 10))
        with self.assertRaises(ValueError):
            LabelMap(tensor=data, type=INTENSITY)

    def test_crop_scalar_image_type(self):
        """Cropping a ``ScalarImage`` keeps the intensity type."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertEqual(cropped.type, INTENSITY)

    def test_crop_label_map_type(self):
        """Cropping a ``LabelMap`` keeps the label type."""
        data = torch.ones((10, 10, 10))
        label = LabelMap(tensor=data)
        cropped = label.crop((1, 1, 1), (5, 5, 5))
        self.assertEqual(cropped.type, LABEL)

    def test_no_input(self):
        """Creating an image with neither a path nor a tensor raises ``ValueError``."""
        with self.assertRaises(ValueError):
            ScalarImage()

    def test_bad_key(self):
        """Passing a ``data`` keyword raises ``ValueError``."""
        with self.assertRaises(ValueError):
            ScalarImage(path='', data=5)

    def test_repr(self):
        """The repr includes the shape only after the image has been loaded."""
        sample = Subject(t1=ScalarImage(self.get_image_path('repr_test')))
        self.assertNotIn('shape', repr(sample['t1']))
        sample.load()
        self.assertIn('shape', repr(sample['t1']))

    def test_data_tensor(self):
        """``data`` and ``tensor`` refer to the same object."""
        sample = copy.deepcopy(self.sample)
        sample.load()
        self.assertIs(sample.t1.data, sample.t1.tensor)

    def test_bad_affine(self):
        """An affine with the wrong shape (3x3) raises ``ValueError``."""
        with self.assertRaises(ValueError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=np.eye(3))

    def test_nans_tensor(self):
        """A tensor containing NaNs triggers a ``UserWarning``."""
        tensor = np.random.rand(1, 2, 3, 4)
        tensor[0, 0, 0, 0] = np.nan
        with self.assertWarns(UserWarning):
            image = ScalarImage(tensor=tensor)
        # Exercise the setter; nothing is asserted about its effect here.
        image.set_check_nans(False)

    def test_nans_file(self):
        """Loading a file that contains NaNs triggers a ``UserWarning``."""
        image = ScalarImage(self.get_image_path('repr_test', add_nans=True))
        with self.assertWarns(UserWarning):
            image.load()

    def test_get_center(self):
        """The center of a 3x3x3 identity-affine image is (1, 1, 1) in RAS.

        With ``lps=True`` the first two coordinates are negated.
        """
        tensor = torch.rand(1, 3, 3, 3)
        image = ScalarImage(tensor=tensor)
        ras = image.get_center()
        lps = image.get_center(lps=True)
        self.assertEqual(ras, (1, 1, 1))
        self.assertEqual(lps, (-1, -1, 1))

    def test_with_list_of_missing_files(self):
        """A list of nonexistent paths raises ``FileNotFoundError``."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage(path=['nopath', 'error'])

    def test_with_a_list_of_paths(self):
        """Several paths are stacked as channels, and ``STEM`` lists each stem."""
        shape = (5, 5, 5)
        path1 = self.get_image_path('path1', shape=shape)
        path2 = self.get_image_path('path2', shape=shape)
        image = ScalarImage(path=[path1, path2])
        self.assertEqual(image.shape, (2, 5, 5, 5))
        self.assertEqual(image[STEM], ['path1', 'path2'])

    def test_with_a_list_of_images_with_different_shapes(self):
        """Loading paths whose images differ in shape raises ``RuntimeError``."""
        path1 = self.get_image_path('path1', shape=(5, 5, 5))
        path2 = self.get_image_path('path2', shape=(7, 5, 5))
        image = ScalarImage(path=[path1, path2])
        with self.assertRaises(RuntimeError):
            image.load()

    def test_with_a_list_of_images_with_different_affines(self):
        """Loading paths whose images differ in spacing raises ``RuntimeWarning``."""
        path1 = self.get_image_path('path1', spacing=(1, 1, 1))
        path2 = self.get_image_path('path2', spacing=(1, 2, 1))
        image = ScalarImage(path=[path1, path2])
        with self.assertWarns(RuntimeWarning):
            image.load()

    def test_with_a_list_of_2d_paths(self):
        """2D files in mixed formats stack into a (C, H, W, 1) image."""
        shape = (5, 6)
        path1 = self.get_image_path('path1', shape=shape, suffix='.nii')
        path2 = self.get_image_path('path2', shape=shape, suffix='.img')
        path3 = self.get_image_path('path3', shape=shape, suffix='.hdr')
        image = ScalarImage(path=[path1, path2, path3])
        self.assertEqual(image.shape, (3, 5, 6, 1))
        self.assertEqual(image[STEM], ['path1', 'path2', 'path3'])

    def test_axis_name_2d(self):
        """``height``/``width`` match the shape at the indices given for 'h'/'w'."""
        path = self.get_image_path('im2d', shape=(5, 6))
        image = ScalarImage(path)
        height_idx = image.axis_name_to_index('h')
        width_idx = image.axis_name_to_index('w')
        self.assertEqual(image.height, image.shape[height_idx])
        self.assertEqual(image.width, image.shape[width_idx])
167