Passed
Pull Request — master (#334)
by Fernando
01:13
created

  A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
#!/usr/bin/env python
2
3
"""Tests for Image."""
4
5
import copy
6
import torch
7
import pytest
8
import numpy as np
9
from torchio import ScalarImage, LabelMap, Subject, INTENSITY, LABEL, STEM
10
from ..utils import TorchioTestCase
11
from torchio import RandomFlip, RandomAffine
12
13
14
class TestImage(TorchioTestCase):
    """Tests for `Image`."""

    def test_image_not_found(self):
        """A path that does not exist should raise FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage('nopath')

    def test_wrong_path_value(self):
        """A malformed path string is expected to raise RuntimeError."""
        with self.assertRaises(RuntimeError):
            ScalarImage('~&./@#"!?X7=+')

    def test_wrong_path_type(self):
        """A non-path value (an int) given as the path should raise TypeError."""
        with self.assertRaises(TypeError):
            ScalarImage(5)

    def test_wrong_affine(self):
        """A non-array affine should raise TypeError.

        A valid tensor is supplied so that the affine is the only invalid
        argument. Passing an invalid path as well (e.g. ``5``) would raise
        TypeError by itself (see ``test_wrong_path_type``), and the affine
        validation would never be exercised.
        """
        with self.assertRaises(TypeError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=1)

    def test_tensor_flip(self):
        """RandomFlip should accept a bare 4D multichannel tensor."""
        sample_input = torch.ones((4, 30, 30, 30))
        transformed = RandomFlip()(sample_input)
        # Flipping must not change the spatial size or the channel count
        self.assertEqual(tuple(transformed.shape), tuple(sample_input.shape))

    def test_tensor_affine(self):
        """RandomAffine should accept a bare 4D multichannel tensor."""
        sample_input = torch.ones((4, 10, 10, 10))
        transformed = RandomAffine()(sample_input)
        # Resampling is expected to keep the input grid size and channels
        self.assertEqual(tuple(transformed.shape), tuple(sample_input.shape))

    def test_scalar_image_type(self):
        """A ScalarImage built from a tensor has the INTENSITY type."""
        data = torch.ones((1, 10, 10, 10))
        image = ScalarImage(tensor=data)
        self.assertIs(image.type, INTENSITY)

    def test_label_map_type(self):
        """A LabelMap built from a tensor has the LABEL type."""
        data = torch.ones((1, 10, 10, 10))
        label = LabelMap(tensor=data)
        self.assertIs(label.type, LABEL)

    def test_wrong_scalar_image_type(self):
        """Forcing type=LABEL on a ScalarImage should raise ValueError."""
        data = torch.ones((1, 10, 10, 10))
        with self.assertRaises(ValueError):
            ScalarImage(tensor=data, type=LABEL)

    def test_wrong_label_map_type(self):
        """Forcing type=INTENSITY on a LabelMap should raise ValueError."""
        data = torch.ones((1, 10, 10, 10))
        with self.assertRaises(ValueError):
            LabelMap(tensor=data, type=INTENSITY)

    def test_no_input(self):
        """Constructing an image with neither path nor tensor raises ValueError."""
        with self.assertRaises(ValueError):
            ScalarImage()

    def test_bad_key(self):
        """Passing ``data`` as an extra keyword is expected to raise ValueError."""
        with self.assertRaises(ValueError):
            ScalarImage(path='', data=5)

    def test_repr(self):
        """The repr includes the shape only once the image has been loaded."""
        sample = Subject(t1=ScalarImage(self.get_image_path('repr_test')))
        self.assertNotIn('shape', repr(sample['t1']))
        sample.load()
        self.assertIn('shape', repr(sample['t1']))

    def test_data_tensor(self):
        """``data`` and ``tensor`` refer to the same object after loading."""
        sample = copy.deepcopy(self.sample)
        sample.load()
        self.assertIs(sample.t1.data, sample.t1.tensor)

    def test_bad_affine(self):
        """An affine with the wrong shape (3x3) should raise ValueError."""
        with self.assertRaises(ValueError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=np.eye(3))

    def test_nans_tensor(self):
        """A tensor containing NaN warns when check_nans is enabled."""
        tensor = np.random.rand(1, 2, 3, 4)
        tensor[0, 0, 0, 0] = np.nan
        with self.assertWarns(UserWarning):
            image = ScalarImage(tensor=tensor, check_nans=True)
        image.set_check_nans(False)

    def test_get_center(self):
        """The center of a 3x3x3 image in RAS and LPS coordinates."""
        tensor = torch.rand(1, 3, 3, 3)
        image = ScalarImage(tensor=tensor)
        ras = image.get_center()
        lps = image.get_center(lps=True)
        self.assertEqual(ras, (1, 1, 1))
        self.assertEqual(lps, (-1, -1, 1))

    def test_with_list_of_missing_files(self):
        """A list of nonexistent paths should raise FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage(path=['nopath', 'error'])

    def test_with_a_list_of_paths(self):
        """A list of same-shape files is stacked along the channel axis."""
        shape = (5, 5, 5)
        path1 = self.get_image_path('path1', shape=shape)
        path2 = self.get_image_path('path2', shape=shape)
        image = ScalarImage(path=[path1, path2])
        self.assertEqual(image.shape, (2, 5, 5, 5))
        self.assertEqual(image[STEM], ['path1', 'path2'])

    def test_with_a_list_of_images_with_different_shapes(self):
        """Loading files with mismatched shapes should raise RuntimeError."""
        path1 = self.get_image_path('path1', shape=(5, 5, 5))
        path2 = self.get_image_path('path2', shape=(7, 5, 5))
        image = ScalarImage(path=[path1, path2])
        with self.assertRaises(RuntimeError):
            image.load()

    def test_with_a_list_of_images_with_different_affines(self):
        """Loading files with mismatched spacing should emit RuntimeWarning."""
        path1 = self.get_image_path('path1', spacing=(1, 1, 1))
        path2 = self.get_image_path('path2', spacing=(1, 2, 1))
        image = ScalarImage(path=[path1, path2])
        with self.assertWarns(RuntimeWarning):
            image.load()

    def test_with_a_list_of_2d_paths(self):
        """2D files in several formats stack into a (C, W, H, 1) image."""
        shape = (5, 6)
        path1 = self.get_image_path('path1', shape=shape, suffix='.nii')
        path2 = self.get_image_path('path2', shape=shape, suffix='.img')
        path3 = self.get_image_path('path3', shape=shape, suffix='.hdr')
        image = ScalarImage(path=[path1, path2, path3])
        self.assertEqual(image.shape, (3, 5, 6, 1))
        self.assertEqual(image[STEM], ['path1', 'path2', 'path3'])

    def test_axis_name_2d(self):
        """Axis names 't' and 'l' index the height and width of a 2D image."""
        path = self.get_image_path('im2d', shape=(5, 6))
        image = ScalarImage(path)
        height_idx = image.axis_name_to_index('t')
        width_idx = image.axis_name_to_index('l')
        self.assertEqual(image.height, image.shape[height_idx])
        self.assertEqual(image.width, image.shape[width_idx])

    def test_plot(self):
        """Plotting to a file with show=False should not raise."""
        image = self.sample.t1
        image.plot(show=False, output_path=self.dir / 'image.png')