Completed
Pull Request — master (#246)
by Fernando
92:34 queued 91:37
created

tests.utils.TorchioTestCase.assertTensorNotEqual()   A

Complexity

Conditions 2

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import copy
2
import shutil
3
import random
4
import tempfile
5
import unittest
6
from pathlib import Path
7
8
import torch
9
import numpy as np
10
import nibabel as nib
11
from torchio.datasets import IXITiny
12
from torchio import INTENSITY, LABEL, DATA, Image, ImagesDataset, Subject
13
14
15
class TorchioTestCase(unittest.TestCase):
    """Base test case that builds on-disk random images and dataset fixtures.

    ``setUp`` writes randomly generated images into a temporary directory,
    wraps them in four ``Subject`` objects with different combinations of
    images, builds an ``ImagesDataset`` from them and keeps its last item as
    ``self.sample``. ``tearDown`` deletes the temporary directory.
    """

    def setUp(self):
        """Create the temp directory, seed the RNGs and build the fixtures.

        Sets ``self.dir``, ``self.subjects_list``, ``self.dataset`` and
        ``self.sample`` (the last item of ``self.dataset``).
        """
        # NOTE(review): this is a fixed, shared path; concurrent test
        # processes would write into (and tearDown would delete) the same
        # directory.
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        # Seed both RNGs used by get_image_path (data values, file suffix,
        # Path-vs-str choice) so the generated fixtures are reproducible.
        random.seed(42)
        np.random.seed(42)

        # Homogeneous 4x4 matrix: translation of 10 along the first axis and
        # a 1.2 scale on the third axis. It is passed as ``pre_affine`` to
        # subject_d's t1 image; how it is applied is up to torchio's Image.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = Subject(
            t1=Image(self.get_image_path('t1_a'), INTENSITY),
        )
        subject_b = Subject(
            t1=Image(self.get_image_path('t1_b'), INTENSITY),
            label=Image(self.get_image_path('label_b', binary=True), LABEL),
        )
        subject_c = Subject(
            label=Image(self.get_image_path('label_c', binary=True), LABEL),
        )
        subject_d = Subject(
            t1=Image(
                self.get_image_path('t1_d'),
                INTENSITY,
                pre_affine=registration_matrix,
            ),
            t2=Image(self.get_image_path('t2_d'), INTENSITY),
            label=Image(self.get_image_path('label_d', binary=True), LABEL),
        )
        self.subjects_list = [
            subject_a,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = ImagesDataset(self.subjects_list)
        self.sample = self.dataset[-1]

    def make_2d(self, sample):
        """Return a deep copy of ``sample`` with every image cut to size 1 on dim 1.

        The input sample is not modified. Only index 0 of dimension 1 of each
        image's data is kept; ``0:1`` (rather than ``0``) keeps that
        dimension with length 1 instead of dropping it.
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            # DATA is imported from torchio at module level; static-analysis
            # warnings that it is undefined here are false positives.
            image[DATA] = image[DATA][:, 0:1, ...]
        return sample

    def make_4d(self, sample):
        """Return a deep copy of ``sample`` with each image's data repeated 4x on dim 0.

        The input sample is not modified. ``torch.cat`` joins four references
        to the same tensor along its default dimension 0.
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            # DATA is imported from torchio at module level (see note above
            # in make_2d's loop body for the same false-positive warning).
            image[DATA] = torch.cat(4 * (image[DATA],))
        return sample

    def get_inconsistent_sample(self):
        """Return a sample whose images have different shapes.

        Shapes: t1 (10, 20, 30) (the get_image_path default), t2 (10, 20, 31),
        label (8, 17, 25) and label2 (18, 17, 25). The sample is the single
        item of a one-subject ``ImagesDataset``.
        """
        subject = Subject(
            t1=Image(self.get_image_path('t1_inc'), INTENSITY),
            t2=Image(
                self.get_image_path('t2_inc', shape=(10, 20, 31)), INTENSITY),
            label=Image(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
                LABEL,
            ),
            label2=Image(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
                LABEL,
            ),
        )
        subjects_list = [subject]
        dataset = ImagesDataset(subjects_list)
        return dataset[0]

    def get_reference_image_and_path(self):
        """Write a reference intensity image and return ``(image, path)``.

        The image has shape (10, 20, 31) and spacing (1, 1, 2).
        """
        path = self.get_image_path('ref', shape=(10, 20, 31), spacing=(1, 1, 2))
        image = Image(path, INTENSITY)
        return image, path

    def tearDown(self):
        """Delete the temporary directory created in ``setUp``."""
        print('Deleting', self.dir)
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return an ``IXITiny`` dataset rooted at ``<tempdir>/torchio/ixi_tiny``.

        ``download=True`` is passed to ``IXITiny``; presumably this fetches
        the data on first use (requires network access — TODO confirm).
        """
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return IXITiny(root_dir, download=True)

    def get_image_path(
            self,
            stem,
            binary=False,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            ):
        """Write a randomly filled image into ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: If ``True``, threshold the uniform random data at 0.5 and
                store it as ``uint8``.
            shape: Shape of the generated data array.
            spacing: Values placed on the diagonal of the affine, which is
                ``diag(*spacing, 1)``.

        Returns:
            The saved file's path, with an extension picked at random from
            ``.nii.gz``, ``.nii``, ``.nrrd`` and ``.img``. It is randomly
            returned either as a ``pathlib.Path`` or as a ``str``.
        """
        data = np.random.rand(*shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
        affine = np.diag((*spacing, 1))
        suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img'))
        path = self.dir / f'{stem}{suffix}'
        # Randomly hand back a str instead of a Path, presumably so callers
        # are exercised with both path types.
        if np.random.rand() > 0.5:
            path = str(path)
        image = Image(tensor=data, affine=affine)
        image.save(path)
        return path

    def assertTensorNotEqual(self, a, b):
        """Assert that tensors ``a`` and ``b`` are not equal.

        Tensors with different shapes count as different, so the assertion
        passes without comparing values. For equal shapes, at least one
        element must differ.

        ``self.assertFalse`` is used instead of a bare ``assert`` so the check
        still runs under ``python -O`` and a failure carries a message.

        Args:
            a: First tensor.
            b: Second tensor.

        Raises:
            self.failureException (``AssertionError`` by default): if ``a``
                and ``b`` have the same shape and all elements compare equal.
        """
        if a.shape != b.shape:
            return
        all_equal = bool(torch.all(torch.eq(a, b)))
        self.assertFalse(all_equal, 'Tensors have the same shape and values')
136