Completed
Push — master ( 040be1...23b5e5 )
by Fernando
02:09
created

TestLabelSampler.test_inconsistent_shape()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import torch
2
import torchio
3
from torchio.data import LabelSampler
4
from ...utils import TorchioTestCase
5
6
7
class TestLabelSampler(TorchioTestCase):
    """Tests for `LabelSampler` class."""

    def test_label_sampler(self):
        """Every sampled patch must have label value 1 at its center."""
        sampler = LabelSampler(5)
        # With a patch size of 5, index 2 along each spatial axis is the
        # central voxel of the patch.
        # NOTE(review): ``self.sample`` is presumably a fixture provided by
        # ``TorchioTestCase`` -- confirm.
        for patch in sampler(self.sample, num_patches=10):
            patch_center = patch['label'][torchio.DATA][0, 2, 2, 2]
            self.assertEqual(patch_center, 1)

    def test_label_probabilities(self):
        """Check the probability map built from `label_probabilities`."""
        labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, 1, -1)
        subject = torchio.Subject(
            label=torchio.Image(tensor=labels, type=torchio.LABEL),
        )
        sample = torchio.SubjectsDataset([subject])[0]
        # Label 3 is given a weight although it does not occur in ``labels``.
        probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
        sampler = LabelSampler(5, 'label', label_probabilities=probs_dict)
        probabilities = sampler.get_probability_map(sample)
        # Expected per-voxel values: 0 for label 0, 2/12 for each of the
        # three voxels of label 1 and 3/12 for the single voxel of label 2.
        fixture = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
        # Compare with a tolerance: the expected values are non-terminating
        # binary fractions, so exact equality is fragile against rounding.
        result = probabilities.squeeze().to(fixture.dtype)
        self.assertTrue(torch.allclose(result, fixture))

    def test_inconsistent_shape(self):
        """Sampling must work when images differ in number of channels."""
        # https://github.com/fepegar/torchio/issues/234#issuecomment-675029767
        sample = torchio.Subject(
            im1=torchio.ScalarImage(tensor=torch.rand(2, 4, 5, 6)),
            im2=torchio.LabelMap(tensor=torch.rand(1, 4, 5, 6)),
        )
        patch_size = 2
        sampler = LabelSampler(patch_size, 'im2')
        # Smoke test: drawing a single patch must not raise.
        next(sampler(sample))

    def test_multichannel_label_sampler(self):
        """Check the probability map for a two-channel label map."""
        sample = torchio.Subject(
            label=torchio.LabelMap(
                tensor=torch.tensor(
                    [
                        [[[1, 1]]],
                        [[[0, 1]]]
                    ]
                )
            )
        )
        patch_size = 1
        sampler = LabelSampler(
            patch_size,
            'label',
            label_probabilities={0: 1, 1: 1}
        )
        # There are 2 voxels in the image, channels have same probabilities,
        # 1st voxel has probability 0.5 * 0.5 + 0 * 0.5 of being chosen while
        # 2nd voxel has probability 0.5 * 0.5 + 1 * 0.5 of being chosen.
        probabilities = sampler.get_probability_map(sample)
        fixture = torch.Tensor((1 / 4, 3 / 4))
        # Tolerant comparison, cast to the fixture dtype so the check does
        # not depend on the dtype returned by the sampler.
        result = probabilities.squeeze().to(fixture.dtype)
        self.assertTrue(torch.allclose(result, fixture))
61