Passed
Pull Request — master (#270)
by Fernando
01:10
created

tests.data.test_image   A

Complexity

Total Complexity 42

Size/Duplication

Total Lines 167
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 132
dl 0
loc 167
rs 9.0399
c 0
b 0
f 0
wmc 42

28 Methods

Rating   Name   Duplication   Size   Complexity  
A TestImage.test_image_not_found() 0 3 2
A TestImage.test_label_map_type() 0 4 1
A TestImage.test_repr() 0 5 1
A TestImage.test_bad_key() 0 3 2
A TestImage.test_crop_label_map_type() 0 5 1
A TestImage.test_wrong_label_map_type() 0 4 2
A TestImage.test_wrong_path_type() 0 3 2
A TestImage.test_wrong_affine() 0 3 2
A TestImage.test_crop_scalar_image_type() 0 5 1
A TestImage.test_with_a_list_of_paths() 0 7 1
A TestImage.test_get_center() 0 7 1
A TestImage.test_with_list_of_missing_files() 0 3 2
A TestImage.test_tensor_flip() 0 3 1
A TestImage.test_crop_does_not_create_wrong_path() 0 5 1
A TestImage.test_wrong_scalar_image_type() 0 4 2
A TestImage.test_nans_file() 0 4 2
A TestImage.test_data_tensor() 0 4 1
A TestImage.test_bad_affine() 0 3 2
A TestImage.test_crop_attributes() 0 3 1
A TestImage.test_no_input() 0 3 2
A TestImage.test_scalar_image_type() 0 4 1
A TestImage.test_with_a_list_of_images_with_different_affines() 0 6 2
A TestImage.test_nans_tensor() 0 6 2
A TestImage.test_tensor_affine() 0 3 1
A TestImage.test_wrong_path_value() 0 3 2
A TestImage.test_with_a_list_of_images_with_different_shapes() 0 6 2
A TestImage.test_axis_name_2d() 0 7 1
A TestImage.test_with_a_list_of_2d_paths() 0 8 1

How to fix   Complexity   

Complexity

Complex classes like `TestImage` (in `tests.data.test_image`) often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
3
"""Tests for Image."""
4
5
import copy
6
import torch
7
import numpy as np
8
from torchio import ScalarImage, LabelMap, Subject, INTENSITY, LABEL, STEM
9
from ..utils import TorchioTestCase
10
from torchio import RandomFlip, RandomAffine
11
12
13
class TestImage(TorchioTestCase):
    """Tests for `Image`."""

    def test_image_not_found(self):
        """A path that does not exist raises ``FileNotFoundError``."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage('nopath')

    def test_wrong_path_value(self):
        """A nonsensical path string raises ``RuntimeError``."""
        with self.assertRaises(RuntimeError):
            ScalarImage('~&./@#"!?X7=+')

    def test_wrong_path_type(self):
        """A path that is not path-like (here an int) raises ``TypeError``."""
        with self.assertRaises(TypeError):
            ScalarImage(5)

    def test_wrong_affine(self):
        """A non-array affine raises ``TypeError``.

        A valid tensor is passed so that the error can only come from the
        affine. An invalid path (e.g. ``5``) raises ``TypeError`` on its own
        (see ``test_wrong_path_type``), which would mask a missing affine
        type check.
        """
        with self.assertRaises(TypeError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=1)

    def test_tensor_flip(self):
        """``RandomFlip`` accepts a bare 4D tensor and keeps its shape."""
        sample_input = torch.ones((4, 30, 30, 30))
        transformed = RandomFlip()(sample_input)
        self.assertEqual(tuple(transformed.shape), tuple(sample_input.shape))

    def test_tensor_affine(self):
        """``RandomAffine`` accepts a bare 4D tensor and keeps its shape."""
        sample_input = torch.ones((4, 10, 10, 10))
        transformed = RandomAffine()(sample_input)
        self.assertEqual(tuple(transformed.shape), tuple(sample_input.shape))

    def test_crop_attributes(self):
        """Cropping keeps the very same ``pre_affine`` object (not a copy)."""
        cropped = self.sample.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(self.sample.t1['pre_affine'], cropped.t1['pre_affine'])

    def test_crop_does_not_create_wrong_path(self):
        """Cropping a tensor-backed image does not invent a path."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertIsNone(cropped.path)

    def test_scalar_image_type(self):
        """A ``ScalarImage`` has type ``INTENSITY``."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        self.assertIs(image.type, INTENSITY)

    def test_label_map_type(self):
        """A ``LabelMap`` has type ``LABEL``."""
        data = torch.ones((10, 10, 10))
        label = LabelMap(tensor=data)
        self.assertIs(label.type, LABEL)

    def test_wrong_scalar_image_type(self):
        """Forcing ``type=LABEL`` on a ``ScalarImage`` raises ``ValueError``."""
        data = torch.ones((10, 10, 10))
        with self.assertRaises(ValueError):
            ScalarImage(tensor=data, type=LABEL)

    def test_wrong_label_map_type(self):
        """Forcing ``type=INTENSITY`` on a ``LabelMap`` raises ``ValueError``."""
        data = torch.ones((10, 10, 10))
        with self.assertRaises(ValueError):
            LabelMap(tensor=data, type=INTENSITY)

    def test_crop_scalar_image_type(self):
        """A cropped ``ScalarImage`` keeps type ``INTENSITY``."""
        data = torch.ones((10, 10, 10))
        image = ScalarImage(tensor=data)
        cropped = image.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(cropped.type, INTENSITY)

    def test_crop_label_map_type(self):
        """A cropped ``LabelMap`` keeps type ``LABEL``."""
        data = torch.ones((10, 10, 10))
        label = LabelMap(tensor=data)
        cropped = label.crop((1, 1, 1), (5, 5, 5))
        self.assertIs(cropped.type, LABEL)

    def test_no_input(self):
        """Constructing an image with neither path nor tensor fails."""
        with self.assertRaises(ValueError):
            ScalarImage()

    def test_bad_key(self):
        """An unexpected ``data`` keyword argument raises ``ValueError``."""
        with self.assertRaises(ValueError):
            ScalarImage(path='', data=5)

    def test_repr(self):
        """The repr only reports a shape once the image has been loaded."""
        subject = Subject(t1=ScalarImage(self.get_image_path('repr_test')))
        self.assertNotIn('shape', repr(subject['t1']))
        subject.load()
        self.assertIn('shape', repr(subject['t1']))

    def test_data_tensor(self):
        """``data`` and ``tensor`` refer to the same object after loading."""
        # Deep copy so loading does not mutate the shared fixture.
        sample = copy.deepcopy(self.sample)
        sample.load()
        self.assertIs(sample.t1.data, sample.t1.tensor)

    def test_bad_affine(self):
        """A 3x3 affine is rejected with ``ValueError``."""
        with self.assertRaises(ValueError):
            ScalarImage(tensor=torch.rand(1, 2, 3, 4), affine=np.eye(3))

    def test_nans_tensor(self):
        """Creating an image from a tensor containing NaNs warns."""
        tensor = np.random.rand(1, 2, 3, 4)
        tensor[0, 0, 0, 0] = np.nan
        with self.assertWarns(UserWarning):
            image = ScalarImage(tensor=tensor)
        # NOTE(review): presumably turns off further NaN checks on this
        # image; the test asserts nothing after this call.
        image.set_check_nans(False)

    def test_nans_file(self):
        """Loading an image file that contains NaNs warns."""
        image = ScalarImage(self.get_image_path('repr_test', add_nans=True))
        with self.assertWarns(UserWarning):
            image.load()

    def test_get_center(self):
        """The center of a 3x3x3 image is (1, 1, 1) in RAS, (-1, -1, 1) in LPS."""
        tensor = torch.rand(1, 3, 3, 3)
        image = ScalarImage(tensor=tensor)
        ras = image.get_center()
        lps = image.get_center(lps=True)
        self.assertEqual(ras, (1, 1, 1))
        self.assertEqual(lps, (-1, -1, 1))

    def test_with_list_of_missing_files(self):
        """A list of non-existent paths raises ``FileNotFoundError``."""
        with self.assertRaises(FileNotFoundError):
            ScalarImage(path=['nopath', 'error'])

    def test_with_a_list_of_paths(self):
        """A list of same-shaped paths is stacked along the channel axis."""
        shape = (5, 5, 5)
        path1 = self.get_image_path('path1', shape=shape)
        path2 = self.get_image_path('path2', shape=shape)
        image = ScalarImage(path=[path1, path2])
        self.assertEqual(image.shape, (2, 5, 5, 5))
        self.assertEqual(image[STEM], ['path1', 'path2'])

    def test_with_a_list_of_images_with_different_shapes(self):
        """Loading paths whose spatial shapes differ raises ``RuntimeError``."""
        path1 = self.get_image_path('path1', shape=(5, 5, 5))
        path2 = self.get_image_path('path2', shape=(7, 5, 5))
        image = ScalarImage(path=[path1, path2])
        with self.assertRaises(RuntimeError):
            image.load()

    def test_with_a_list_of_images_with_different_affines(self):
        """Loading paths whose spacings differ emits a ``RuntimeWarning``."""
        path1 = self.get_image_path('path1', spacing=(1, 1, 1))
        path2 = self.get_image_path('path2', spacing=(1, 2, 1))
        image = ScalarImage(path=[path1, path2])
        with self.assertWarns(RuntimeWarning):
            image.load()

    def test_with_a_list_of_2d_paths(self):
        """2D images in several formats stack with a singleton last axis."""
        shape = (5, 6)
        path1 = self.get_image_path('path1', shape=shape, suffix='.nii')
        path2 = self.get_image_path('path2', shape=shape, suffix='.img')
        path3 = self.get_image_path('path3', shape=shape, suffix='.hdr')
        image = ScalarImage(path=[path1, path2, path3])
        self.assertEqual(image.shape, (3, 5, 6, 1))
        self.assertEqual(image[STEM], ['path1', 'path2', 'path3'])

    def test_axis_name_2d(self):
        """``height``/``width`` match the shape at the named axis indices."""
        path = self.get_image_path('im2d', shape=(5, 6))
        image = ScalarImage(path)
        height_idx = image.axis_name_to_index('h')
        width_idx = image.axis_name_to_index('w')
        self.assertEqual(image.height, image.shape[height_idx])
        self.assertEqual(image.width, image.shape[width_idx])
167