Passed
Pull Request — main (#1400)
by
unknown
01:35
created

tests.transforms.test_custom_image_subclass   A

Complexity

Total Complexity 19

Size/Duplication

Total Lines 225
Duplicated Lines 6.22 %

Importance

Changes 0
Metric Value
eloc 120
dl 14
loc 225
rs 10
c 0
b 0
f 0
wmc 19

16 Methods

Rating   Name   Duplication   Size   Complexity  
A TestCustomImageSubclass.history_image() 0 6 1
A HistoryImage.__init__() 3 3 1
A MetadataImage.__init__() 0 3 1
A TestCustomImageSubclass.test_new_like_method_directly() 0 18 1
A HistoryImage.new_like() 7 7 2
A MetadataImage.new_like() 0 7 2
A TestCustomImageSubclass.test_backward_compatibility_standard_images() 0 15 1
A TestCustomImageSubclass.test_chained_transforms_preserve_attributes() 0 20 1
A TestCustomImageSubclass.test_to_reference_space_with_custom_image() 0 18 1
A TestCustomImageSubclass.test_label_map_subclass() 0 35 2
A TestCustomImageSubclass.metadata_image() 0 7 1
A TestCustomImageSubclass.test_new_like_with_default_affine() 0 14 1
A TestCustomImageSubclass.test_crop_with_history_image() 0 13 1
A TestCustomImageSubclass.history_subject() 0 4 1
A TestCustomImageSubclass.metadata_subject() 0 4 1
A TestCustomImageSubclass.test_crop_with_metadata_image() 0 13 1

How to fix   Duplicated Code   

Duplicated Code

Duplicated code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.

Common duplication problems and their corresponding solutions are:

1
"""Tests for custom Image subclasses with transforms."""
2
3
import torch
4
import pytest
5
import torchio as tio
6
7
8 View Code Duplication
class HistoryImage(tio.ScalarImage):
    """Custom ScalarImage whose constructor requires an extra ``history`` argument.

    The test suite checks that transforms return instances of this class with
    the ``history`` attribute preserved.
    """

    def __init__(self, tensor, affine, history, **kwargs):
        # Standard arguments go to ScalarImage; ``history`` is stored locally.
        super().__init__(tensor=tensor, affine=affine, **kwargs)
        self.history = history

    def new_like(self, tensor, affine=None):
        """Build an image of the same class around ``tensor``.

        Args:
            tensor: Data for the new image.
            affine: Affine for the new image. When ``None``, ``self.affine``
                is reused.

        Returns:
            A ``type(self)`` instance that receives this image's ``history``
            object, ``check_nans`` flag and ``reader``.
        """
        if affine is None:
            affine = self.affine
        image_class = type(self)
        return image_class(
            tensor=tensor,
            affine=affine,
            history=self.history,
            check_nans=self.check_nans,
            reader=self.reader,
        )
23
24
25
class MetadataImage(tio.ScalarImage):
    """Custom ScalarImage with an optional extra ``metadata`` argument.

    Args:
        tensor: Image data, forwarded to ``tio.ScalarImage``.
        affine: Affine matrix, forwarded to ``tio.ScalarImage``.
        metadata: Optional mapping stored as ``self.metadata``. When omitted
            (``None``), an empty dict is used instead.
        **kwargs: Extra keyword arguments forwarded to ``tio.ScalarImage``.
    """

    def __init__(self, tensor, affine, metadata=None, **kwargs):
        super().__init__(tensor=tensor, affine=affine, **kwargs)
        # Only substitute a fresh dict when nothing was passed. ``metadata or {}``
        # would also discard an explicitly passed empty (or otherwise falsy)
        # mapping and replace it with a different object.
        self.metadata = {} if metadata is None else metadata

    def new_like(self, tensor, affine=None):
        """Build an image of the same class around ``tensor``.

        Args:
            tensor: Data for the new image.
            affine: Affine for the new image. When ``None``, ``self.affine``
                is reused.

        Returns:
            A ``type(self)`` instance that receives this image's ``metadata``
            object, ``check_nans`` flag and ``reader``.
        """
        return type(self)(
            tensor=tensor,
            affine=affine if affine is not None else self.affine,
            metadata=self.metadata,
            check_nans=self.check_nans,
            reader=self.reader,
        )
40
41
42
class TestCustomImageSubclass:
    """Test suite for custom Image subclasses with transforms."""

    @pytest.fixture
    def history_image(self):
        """Create a 10x10x10 HistoryImage with history ``['created']``."""
        tensor = torch.rand(1, 10, 10, 10)
        affine = torch.eye(4)
        return HistoryImage(tensor=tensor, affine=affine, history=['created'])

    @pytest.fixture
    def metadata_image(self):
        """Create a 12x12x12 MetadataImage with a small metadata dict."""
        tensor = torch.rand(1, 12, 12, 12)
        affine = torch.eye(4)
        return MetadataImage(
            tensor=tensor, affine=affine, metadata={'id': 123, 'source': 'test'}
        )

    @pytest.fixture
    def history_subject(self, history_image):
        """Create a Subject with HistoryImage."""
        return tio.Subject(image=history_image)

    @pytest.fixture
    def metadata_subject(self, metadata_image):
        """Create a Subject with MetadataImage."""
        return tio.Subject(image=metadata_image)

    def test_crop_with_history_image(self, history_subject):
        """Test that Crop transform works with custom Image requiring history parameter."""
        transform = tio.Crop(cropping=2)
        result = transform(history_subject)

        # Check that the result is still a HistoryImage
        assert isinstance(result.image, HistoryImage)

        # Check that custom attribute is preserved
        assert result.image.history == ['created']

        # 10 - 2 * 2 = 6 voxels remain along each axis
        assert result.image.shape == (1, 6, 6, 6)

    def test_crop_with_metadata_image(self, metadata_subject):
        """Test that Crop transform works with custom Image with optional parameters."""
        transform = tio.Crop(cropping=1)
        result = transform(metadata_subject)

        # Check that the result is still a MetadataImage
        assert isinstance(result.image, MetadataImage)

        # Check that custom attribute is preserved
        assert result.image.metadata == {'id': 123, 'source': 'test'}

        # 12 - 2 * 1 = 10 voxels remain along each axis
        assert result.image.shape == (1, 10, 10, 10)

    def test_chained_transforms_preserve_attributes(self, history_subject):
        """Test that chained transforms preserve custom attributes."""
        # Chain multiple transforms
        transform = tio.Compose(
            [
                tio.Crop(cropping=1),
                tio.Crop(cropping=1),
            ]
        )

        result = transform(history_subject)

        # Check that the result is still a HistoryImage after multiple transforms
        assert isinstance(result.image, HistoryImage)

        # Check that custom attribute is preserved through the chain
        assert result.image.history == ['created']

        # Check that both crops were applied: 10 - 2 * (1 + 1) = 6
        assert result.image.shape == (1, 6, 6, 6)

    def test_backward_compatibility_standard_images(self):
        """Test that standard Images still work with transforms."""
        # Create a standard ScalarImage
        tensor = torch.rand(1, 10, 10, 10)
        affine = torch.eye(4)
        image = tio.ScalarImage(tensor=tensor, affine=affine)
        subject = tio.Subject(image=image)

        # Apply transform
        transform = tio.Crop(cropping=2)
        result = transform(subject)

        # Check that it still works
        assert isinstance(result.image, tio.ScalarImage)
        assert result.image.shape == (1, 6, 6, 6)

    def test_to_reference_space_with_custom_image(self, history_image):
        """Test that ToReferenceSpace works with custom images."""
        # ``history_image`` itself is the reference space; the embedding has
        # the same spatial shape (10x10x10) as that reference.
        embedding_tensor = torch.rand(1, 10, 10, 10)

        result = tio.ToReferenceSpace.from_tensor(embedding_tensor, history_image)

        # Check that the result preserves the custom class type
        assert isinstance(result, HistoryImage)

        # Check that custom attribute is preserved
        assert result.history == ['created']

        # Check that the new image was built around the embedding tensor
        assert result.shape == (1, 10, 10, 10)

    def test_new_like_method_directly(self, history_image):
        """Test the new_like method directly."""
        new_tensor = torch.rand(1, 5, 5, 5)
        # 2 mm isotropic spacing. ``torch.eye(4) * 2`` would also scale the
        # homogeneous row, which is not a valid affine.
        new_affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))

        # Create new image using new_like
        new_image = history_image.new_like(tensor=new_tensor, affine=new_affine)

        # Check type preservation
        assert isinstance(new_image, HistoryImage)

        # Check attribute preservation
        assert new_image.history == ['created']

        # Check new data
        assert torch.equal(new_image.data, new_tensor)
        assert torch.allclose(
            torch.tensor(new_image.affine).float(), new_affine.float()
        )

    def test_new_like_with_default_affine(self, metadata_image):
        """Test new_like method with default affine (None)."""
        new_tensor = torch.rand(1, 8, 8, 8)

        # Create new image using new_like with default affine
        new_image = metadata_image.new_like(tensor=new_tensor)

        # Check type preservation and new data, as in the explicit-affine test
        assert isinstance(new_image, MetadataImage)
        assert torch.equal(new_image.data, new_tensor)

        # Check that original affine is used
        assert torch.allclose(
            torch.tensor(new_image.affine), torch.tensor(metadata_image.affine)
        )

        # Check attribute preservation
        assert new_image.metadata == {'id': 123, 'source': 'test'}

    def test_label_map_subclass(self):
        """Test that custom LabelMap subclasses also work."""

        class CustomLabelMap(tio.LabelMap):
            """LabelMap subclass with a required ``labels_info`` argument."""

            def __init__(self, tensor, affine, labels_info, **kwargs):
                super().__init__(tensor=tensor, affine=affine, **kwargs)
                self.labels_info = labels_info

            def new_like(self, tensor, affine=None):
                """Build a CustomLabelMap around ``tensor`` with the same labels_info."""
                return type(self)(
                    tensor=tensor,
                    affine=affine if affine is not None else self.affine,
                    labels_info=self.labels_info,
                    check_nans=self.check_nans,
                    reader=self.reader,
                )

        # Create custom label map with integer labels in {0, 1, 2}
        tensor = torch.randint(0, 3, (1, 8, 8, 8))
        affine = torch.eye(4)
        labels_info = {0: 'background', 1: 'tissue1', 2: 'tissue2'}

        custom_label = CustomLabelMap(
            tensor=tensor, affine=affine, labels_info=labels_info
        )
        subject = tio.Subject(labels=custom_label)

        # Apply transform
        transform = tio.Crop(cropping=1)
        result = transform(subject)

        # Check preservation: class, custom attribute and cropped shape (8 - 2 = 6)
        assert isinstance(result.labels, CustomLabelMap)
        assert result.labels.labels_info == labels_info
        assert result.labels.shape == (1, 6, 6, 6)
225