Passed
Pull Request — master (#394)
by Fernando
01:18
created

tests.data.sampler.test_label_sampler   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 78
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 57
dl 0
loc 78
rs 10
c 0
b 0
f 0
wmc 9

6 Methods

Rating   Name   Duplication   Size   Complexity  
A TestLabelSampler.test_inconsistent_shape() 0 9 1
A TestLabelSampler.test_multichannel_label_sampler() 0 23 1
A TestLabelSampler.test_label_sampler() 0 5 2
A TestLabelSampler.test_label_probabilities() 0 11 1
A TestLabelSampler.test_no_labelmap() 0 6 2
A TestLabelSampler.test_empty_map() 0 10 2
1
import torch
2
import torchio as tio
3
from ...utils import TorchioTestCase
4
5
6
class TestLabelSampler(TorchioTestCase):
    """Tests for `LabelSampler` class."""

    def _assert_probability_map(self, probabilities, expected):
        """Assert that a sampler probability map matches ``expected``.

        Singleton dimensions of ``probabilities`` are squeezed out so the map
        can be compared against a flat fixture. The shape is checked
        explicitly because element-wise comparison would otherwise broadcast
        and could hide a wrongly shaped map. Values are compared with a
        tolerance instead of exact float equality. ``expected`` is cast to
        the dtype/device of ``probabilities`` because ``torch.allclose``
        requires matching dtypes.
        """
        probabilities = probabilities.squeeze()
        self.assertEqual(probabilities.shape, expected.shape)
        self.assertTrue(
            torch.allclose(probabilities, expected.to(probabilities)),
            msg=f'{probabilities} != {expected}',
        )

    def test_label_sampler(self):
        patch_size = 5
        sampler = tio.LabelSampler(patch_size)
        center = patch_size // 2
        for patch in sampler(self.sample_subject, num_patches=10):
            # Every sampled patch must be centered on a labeled voxel
            patch_center = patch['label'][tio.DATA][0, center, center, center]
            self.assertEqual(patch_center, 1)

    def test_label_probabilities(self):
        labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, 1, -1)
        subject = tio.Subject(
            label=tio.Image(tensor=labels, type=tio.LABEL),
        )
        subject = tio.SubjectsDataset([subject])[0]
        # Label 3 is absent from the map, so its share is not assigned to
        # any voxel
        probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
        sampler = tio.LabelSampler(5, 'label', label_probabilities=probs_dict)
        probabilities = sampler.get_probability_map(subject)
        # Label 1: 0.5 spread over 3 voxels -> 2/12 each
        # Label 2: 0.25 on 1 voxel -> 3/12
        fixture = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
        self._assert_probability_map(probabilities, fixture)

    def test_inconsistent_shape(self):
        # https://github.com/fepegar/torchio/issues/234#issuecomment-675029767
        subject = tio.Subject(
            im1=tio.ScalarImage(tensor=torch.rand(2, 4, 5, 6)),
            im2=tio.LabelMap(tensor=torch.rand(1, 4, 5, 6)),
        )
        patch_size = 2
        sampler = tio.LabelSampler(patch_size, 'im2')
        patch = next(sampler(subject))
        # Sampling must succeed and yield a patch of the requested size
        spatial_shape = tuple(patch['im2'][tio.DATA].shape[1:])
        self.assertEqual(spatial_shape, (patch_size,) * 3)

    def test_multichannel_label_sampler(self):
        subject = tio.Subject(
            label=tio.LabelMap(
                tensor=torch.tensor(
                    [
                        [[[1, 1]]],
                        [[[0, 1]]],
                    ]
                )
            )
        )
        patch_size = 1
        sampler = tio.LabelSampler(
            patch_size,
            'label',
            label_probabilities={0: 1, 1: 1},
        )
        # There are 2 voxels in the image and both channels have the same
        # probability. The 1st voxel has probability 0.5 * 0.5 + 0 * 0.5 of
        # being chosen, while the 2nd voxel has 0.5 * 0.5 + 1 * 0.5.
        probabilities = sampler.get_probability_map(subject)
        fixture = torch.Tensor((1 / 4, 3 / 4))
        self._assert_probability_map(probabilities, fixture)

    def test_no_labelmap(self):
        im = tio.ScalarImage(tensor=torch.rand(1, 1, 1, 1))
        subject = tio.Subject(image=im, no_label=im)
        sampler = tio.LabelSampler(1)
        # No LabelMap in the subject, so sampling cannot proceed
        with self.assertRaises(RuntimeError):
            next(sampler(subject))

    def test_empty_map(self):
        # https://github.com/fepegar/torchio/issues/392
        im = tio.ScalarImage(tensor=torch.rand(1, 6, 6, 6))
        label = torch.zeros(1, 6, 6, 6)
        label[..., 0] = 1  # voxels far from center
        label_im = tio.LabelMap(tensor=label)
        subject = tio.Subject(image=im, label=label_im)
        sampler = tio.LabelSampler(4)
        # With patch size 4 in a 6^3 volume, no labeled voxel can be a patch
        # center, so the probability map is empty and sampling must fail
        with self.assertRaises(RuntimeError):
            next(sampler(subject))
78