Passed
Push — master ( fca4b7...25d6cf )
by Fernando
01:33
created

tests.data.sampler.test_weighted_sampler   A

Complexity

Total Complexity 2

Size/Duplication

Total Lines 26
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 26
rs 10
c 0
b 0
f 0
wmc 2

2 Methods

Rating   Name   Duplication   Size   Complexity  
A TestWeightedSampler.get_sample() 0 10 1
A TestWeightedSampler.test_weighted_sampler() 0 5 1
1
import torch
2
import torchio
3
from torchio.data import WeightedSampler
4
from ...utils import TorchioTestCase
5
6
7
class TestWeightedSampler(TorchioTestCase):
    """Tests for `WeightedSampler` class."""

    def test_weighted_sampler(self):
        """Sample one patch guided by a one-hot probability map.

        The image is 7x7x7 and the probability map is zero everywhere
        except voxel ``(3, 3, 3)``. With a patch size of 5, a patch centred
        on that voxel starts at ``3 - 5 // 2 = 1`` along every axis. That
        is the ``index_ini`` this test expects from the sampler.
        """
        sample = self.get_sample((7, 7, 7))
        sampler = WeightedSampler(5, 'prob')
        # Only the first patch is inspected; the sampler's iterator is
        # consumed once via next().
        patch = next(iter(sampler(sample)))
        self.assertEqual(tuple(patch['index_ini']), (1, 1, 1))

    def get_sample(self, image_shape, hot_voxel=(3, 3, 3)):
        """Build a one-subject dataset item with a one-hot probability map.

        Args:
            image_shape: Shape passed to ``torch.rand`` for the ``t1`` image;
                the ``prob`` map has the same shape.
            hot_voxel: Index of the single voxel set to 1 in ``prob``; all
                other voxels are 0. Defaults to ``(3, 3, 3)``, the value the
                original helper hard-coded.

        Returns:
            Item 0 of a ``torchio.ImagesDataset`` built from one
            ``torchio.Subject`` holding the images ``t1`` and ``prob``.

        Raises:
            ValueError: If ``hot_voxel`` does not have one index per
                dimension of ``image_shape`` (otherwise the assignment
                would silently fill a whole slice instead of one voxel).
        """
        hot_voxel = tuple(hot_voxel)
        if len(hot_voxel) != len(image_shape):
            raise ValueError(
                f'hot_voxel {hot_voxel} must have one index per dimension '
                f'of image_shape {tuple(image_shape)}'
            )
        t1 = torch.rand(*image_shape)
        prob = torch.zeros_like(t1)
        prob[hot_voxel] = 1
        subject = torchio.Subject(
            t1=torchio.Image(tensor=t1),
            prob=torchio.Image(tensor=prob),
        )
        sample = torchio.ImagesDataset([subject])[0]
        return sample
26