Passed
Pull Request — master (#332)
by Fernando
01:14
created

tests.data.test_image   A

Complexity

Total Complexity 41

Size/Duplication

Total Lines 167
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 132
dl 0
loc 167
rs 9.1199
c 0
b 0
f 0
wmc 41

28 Methods

Rating   Name   Duplication   Size   Complexity  
A TestImage.test_get_center() 0 7 1
A TestImage.test_label_map_type() 0 4 1
A TestImage.test_repr() 0 5 1
A TestImage.test_bad_key() 0 3 2
A TestImage.test_crop_label_map_type() 0 5 1
A TestImage.test_image_not_found() 0 3 2
A TestImage.test_wrong_label_map_type() 0 4 2
A TestImage.test_wrong_path_type() 0 3 2
A TestImage.test_wrong_affine() 0 3 2
A TestImage.test_crop_scalar_image_type() 0 5 1
A TestImage.test_with_a_list_of_paths() 0 7 1
A TestImage.test_with_list_of_missing_files() 0 3 2
A TestImage.test_tensor_flip() 0 3 1
A TestImage.test_crop_does_not_create_wrong_path() 0 5 1
A TestImage.test_wrong_scalar_image_type() 0 4 2
A TestImage.test_axis_name_2d() 0 7 1
A TestImage.test_data_tensor() 0 4 1
A TestImage.test_bad_affine() 0 3 2
A TestImage.test_crop_attributes() 0 3 1
A TestImage.test_with_a_list_of_2d_paths() 0 8 1
A TestImage.test_no_input() 0 3 2
A TestImage.test_scalar_image_type() 0 4 1
A TestImage.test_with_a_list_of_images_with_different_affines() 0 6 2
A TestImage.test_nans_tensor() 0 6 2
A TestImage.test_tensor_affine() 0 3 1
A TestImage.test_wrong_path_value() 0 3 2
A TestImage.test_with_a_list_of_images_with_different_shapes() 0 6 2
A TestImage.test_plot() 0 3 1

How to fix   Complexity   

Complexity

Complex classes like TestImage (in tests.data.test_image) often do many different things. To break such a class down, identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share the same prefix or suffix (for example, the `test_crop_*` or `test_with_a_list_of_*` methods).

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
3
"""Tests for Image."""
4
5
import copy
6
import torch
7
import pytest
8
import numpy as np
9
from torchio import ScalarImage, LabelMap, Subject, INTENSITY, LABEL, STEM
10
from ..utils import TorchioTestCase
11
from torchio import RandomFlip, RandomAffine
12
13
14
class TestImage(TorchioTestCase):
    """Tests for `Image`."""

    def test_image_not_found(self):
        """A path that does not exist raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage('nopath')

    def test_wrong_path_value(self):
        """A syntactically invalid path string raises RuntimeError."""
        with self.assertRaises(RuntimeError):
            ScalarImage('~&./@#"!?X7=+')

    def test_wrong_path_type(self):
        """A non-path object passed as path raises TypeError."""
        with self.assertRaises(TypeError):
            ScalarImage(5)

    def test_wrong_affine(self):
        """An affine that is not an array raises TypeError."""
        # Use a valid tensor so that the only invalid argument is the affine.
        # Passing an invalid path (as before) raised TypeError from path
        # parsing and never exercised affine validation.
        with self.assertRaises(TypeError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=1)

    def test_tensor_flip(self):
        """RandomFlip accepts a plain 4D tensor and keeps its shape."""
        sample_input = torch.ones((4, 30, 30, 30))
        flipped = RandomFlip()(sample_input)
        self.assertEqual(tuple(flipped.shape), tuple(sample_input.shape))

    def test_tensor_affine(self):
        """RandomAffine accepts a plain 4D tensor and keeps its shape."""
        sample_input = torch.ones((4, 10, 10, 10))
        transformed = RandomAffine()(sample_input)
        self.assertEqual(tuple(transformed.shape), tuple(sample_input.shape))

    def test_crop_attributes(self):
        """Cropping keeps the same 'pre_affine' object on the cropped image."""
        cropped = self.sample.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(self.sample.t1['pre_affine'], cropped.t1['pre_affine'])

    def test_crop_does_not_create_wrong_path(self):
        """Cropping a tensor-backed image does not invent a path."""
        data = torch.ones((1, 10, 10, 10))
        image = ScalarImage(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertIsNone(cropped.path)

    def test_scalar_image_type(self):
        """A ScalarImage has type INTENSITY."""
        data = torch.ones((1, 10, 10, 10))
        image = ScalarImage(tensor=data)
        self.assertIs(image.type, INTENSITY)

    def test_label_map_type(self):
        """A LabelMap has type LABEL."""
        data = torch.ones((1, 10, 10, 10))
        label = LabelMap(tensor=data)
        self.assertIs(label.type, LABEL)

    def test_wrong_scalar_image_type(self):
        """Forcing type=LABEL on a ScalarImage raises ValueError."""
        data = torch.ones((1, 10, 10, 10))
        with self.assertRaises(ValueError):
            ScalarImage(tensor=data, type=LABEL)

    def test_wrong_label_map_type(self):
        """Forcing type=INTENSITY on a LabelMap raises ValueError."""
        data = torch.ones((1, 10, 10, 10))
        with self.assertRaises(ValueError):
            LabelMap(tensor=data, type=INTENSITY)

    def test_crop_scalar_image_type(self):
        """A cropped ScalarImage keeps type INTENSITY."""
        data = torch.ones((1, 10, 10, 10))
        image = ScalarImage(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(cropped.type, INTENSITY)

    def test_crop_label_map_type(self):
        """A cropped LabelMap keeps type LABEL."""
        data = torch.ones((1, 10, 10, 10))
        label = LabelMap(tensor=data)
        cropped = label.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(cropped.type, LABEL)

    def test_no_input(self):
        """Constructing an image with neither path nor tensor raises."""
        with self.assertRaises(ValueError):
            ScalarImage()

    def test_bad_key(self):
        """Passing the keyword 'data' raises ValueError."""
        with self.assertRaises(ValueError):
            ScalarImage(path='', data=5)

    def test_repr(self):
        """The repr omits 'shape' before loading and includes it after."""
        sample = Subject(t1=ScalarImage(self.get_image_path('repr_test')))
        self.assertNotIn('shape', repr(sample['t1']))
        sample.load()
        self.assertIn('shape', repr(sample['t1']))

    def test_data_tensor(self):
        """After loading, `data` and `tensor` are the same object."""
        sample = copy.deepcopy(self.sample)
        sample.load()
        self.assertIs(sample.t1.data, sample.t1.tensor)

    def test_bad_affine(self):
        """An affine with the wrong shape (3x3) raises ValueError."""
        with self.assertRaises(ValueError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=np.eye(3))

    def test_nans_tensor(self):
        """A tensor containing NaN warns when check_nans=True."""
        tensor = np.random.rand(1, 2, 3, 4)
        tensor[0, 0, 0, 0] = np.nan
        with self.assertWarns(UserWarning):
            image = ScalarImage(tensor=tensor, check_nans=True)
        # NOTE(review): only checks that the setter can be called; nothing
        # is asserted about its effect.
        image.set_check_nans(False)

    def test_get_center(self):
        """Center of a 3x3x3 image (no affine given) in RAS and LPS."""
        tensor = torch.rand(1, 3, 3, 3)
        image = ScalarImage(tensor=tensor)
        ras = image.get_center()
        lps = image.get_center(lps=True)
        self.assertEqual(ras, (1, 1, 1))
        # LPS negates the first two coordinates relative to RAS.
        self.assertEqual(lps, (-1, -1, 1))

    def test_with_list_of_missing_files(self):
        """A list of nonexistent paths raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage(path=['nopath', 'error'])

    def test_with_a_list_of_paths(self):
        """Images from a list of paths are stacked along the channel axis."""
        shape = (5, 5, 5)
        path1 = self.get_image_path('path1', shape=shape)
        path2 = self.get_image_path('path2', shape=shape)
        image = ScalarImage(path=[path1, path2])
        self.assertEqual(image.shape, (2, 5, 5, 5))
        self.assertEqual(image[STEM], ['path1', 'path2'])

    def test_with_a_list_of_images_with_different_shapes(self):
        """Loading a list of images with mismatched shapes raises."""
        path1 = self.get_image_path('path1', shape=(5, 5, 5))
        path2 = self.get_image_path('path2', shape=(7, 5, 5))
        image = ScalarImage(path=[path1, path2])
        with self.assertRaises(RuntimeError):
            image.load()

    def test_with_a_list_of_images_with_different_affines(self):
        """Loading a list of images with mismatched spacing warns."""
        path1 = self.get_image_path('path1', spacing=(1, 1, 1))
        path2 = self.get_image_path('path2', spacing=(1, 2, 1))
        image = ScalarImage(path=[path1, path2])
        with self.assertWarns(RuntimeWarning):
            image.load()

    def test_with_a_list_of_2d_paths(self):
        """2D images in several formats stack into a (C, W, H, 1) image."""
        shape = (5, 6)
        path1 = self.get_image_path('path1', shape=shape, suffix='.nii')
        path2 = self.get_image_path('path2', shape=shape, suffix='.img')
        path3 = self.get_image_path('path3', shape=shape, suffix='.hdr')
        image = ScalarImage(path=[path1, path2, path3])
        self.assertEqual(image.shape, (3, 5, 6, 1))
        self.assertEqual(image[STEM], ['path1', 'path2', 'path3'])

    def test_axis_name_2d(self):
        """For a 2D image, axis 't' gives the height and 'l' the width."""
        path = self.get_image_path('im2d', shape=(5, 6))
        image = ScalarImage(path)
        height_idx = image.axis_name_to_index('t')
        width_idx = image.axis_name_to_index('l')
        self.assertEqual(image.height, image.shape[height_idx])
        self.assertEqual(image.width, image.shape[width_idx])

    def test_plot(self):
        """Plotting to a file without showing the figure does not fail."""
        image = self.sample.t1
        image.plot(show=False, output_path=self.dir / 'image.png')
167