Passed
Push — master ( fca4b7...25d6cf )
by Fernando
01:33
created

tests.data.sampler.test_label_sampler   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 27
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 21
dl 0
loc 27
rs 10
c 0
b 0
f 0
wmc 3

2 Methods

Rating   Name   Duplication   Size   Complexity  
A TestLabelSampler.test_label_sampler() 0 5 2
A TestLabelSampler.test_label_probabilities() 0 11 1
1
import torch
2
import torchio
3
from torchio.data import LabelSampler
4
from ...utils import TorchioTestCase
5
6
7
class TestLabelSampler(TorchioTestCase):
    """Tests for `LabelSampler` class."""

    def test_label_sampler(self):
        """Check that the central voxel of every sampled patch is labeled 1."""
        sampler = LabelSampler(5)
        for patch in sampler(self.sample, num_patches=10):
            # Index (2, 2, 2) is the middle of a 5-voxel-wide axis; this
            # assumes `LabelSampler(5)` yields 5x5x5 patches -- TODO confirm.
            patch_center = patch['label'][torchio.DATA][0, 2, 2, 2]
            self.assertEqual(patch_center, 1)

    def test_label_probabilities(self):
        """Check the probability map built from per-label weights.

        The fixture expects the weights in ``probs_dict`` to be normalized
        by their sum (100) and each label's share to be split evenly among
        the voxels carrying that label:

        * label 0 has weight 0, so its voxels get 0;
        * label 1 covers three voxels: ``0.5 / 3 == 2/12`` each;
        * label 2 covers one voxel: ``0.25 / 1 == 3/12``;
        * label 3 is absent from the label map, so its weight is unused.
        """
        labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, -1)
        subject = torchio.Subject(
            label=torchio.Image(tensor=labels, type=torchio.LABEL),
        )
        sample = torchio.ImagesDataset([subject])[0]
        probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
        sampler = LabelSampler(5, 'label', label_probabilities=probs_dict)
        probabilities = sampler.get_probability_map(sample)
        fixture = torch.Tensor((0, 0, 2/12, 2/12, 3/12, 2/12, 0))
        # `torch.allclose` does not promote dtypes the way `eq` does, so the
        # observed map is cast to the fixture's dtype before comparing.
        observed = probabilities.squeeze().to(fixture.dtype)
        # Compare with a tolerance: the expected values are fractions that
        # are not exactly representable in floating point, so exact equality
        # would depend on the order of operations in the implementation.
        # A unittest assertion is used instead of a bare `assert`, which
        # would be stripped when Python runs with `-O`.
        self.assertTrue(
            torch.allclose(observed, fixture),
            msg=f'Expected {fixture}, got {observed}',
        )
27