Passed
Pull Request — master (#394)
by Fernando
01:18
created

tests.data.sampler.test_weighted_sampler   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 36
Duplicated Lines 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| eloc   | 27    |
| dl     | 0     |
| loc    | 36    |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
| wmc    | 3     |

3 Methods

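| Rating | Name | Duplication | Size | Complexity |
|--------|------|-------------|------|------------|
| A | TestWeightedSampler.get_sample() | 0 | 10 | 1 |
| A | TestWeightedSampler.test_weighted_sampler() | 0 | 5 | 1 |
| A | TestWeightedSampler.test_inconsistent_shape() | 0 | 9 | 1 |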
1
import torch
2
import torchio
3
from torchio.data import WeightedSampler
4
from ...utils import TorchioTestCase
5
6
7
class TestWeightedSampler(TorchioTestCase):
    """Tests for `WeightedSampler` class."""

    def test_weighted_sampler(self):
        """A probability map with a single non-zero voxel fixes the patch."""
        patch_size = 5
        peak = (3, 3, 3)
        subject = self.get_sample((1, 7, 7, 7), peak=peak)
        sampler = WeightedSampler(patch_size, 'prob')
        # iter() keeps this working whether the sampler call returns an
        # iterator or just an iterable.
        patch = next(iter(sampler(subject)))
        # The patch is expected to be centred on the peak, so its first
        # corner sits patch_size // 2 voxels before it on every axis
        # (3 - 5 // 2 == 1 for the values above).
        expected_ini = tuple(p - patch_size // 2 for p in peak)
        self.assertEqual(tuple(patch['index_ini']), expected_ini)

    def get_sample(self, image_shape, peak=(3, 3, 3)):
        """Build a subject whose ``prob`` image has exactly one hot voxel.

        Args:
            image_shape: 4D shape used for both images, indexed as
                ``[channel, i, j, k]`` below (presumably channels-first,
                as torchio expects - confirm).
            peak: Spatial index ``(i, j, k)`` of the only voxel set to 1 in
                channel 0 of ``prob``. It must lie inside ``image_shape``.
                The default matches the original hard-coded location.

        Returns:
            The subject obtained by indexing a one-element
            ``torchio.SubjectsDataset``.
        """
        t1 = torch.rand(*image_shape)
        prob = torch.zeros_like(t1)
        prob[(0, *peak)] = 1
        subject = torchio.Subject(
            t1=torchio.ScalarImage(tensor=t1),
            prob=torchio.ScalarImage(tensor=prob),
        )
        # Go through SubjectsDataset so the sampler receives a subject
        # produced the same way as in a data pipeline.
        subject = torchio.SubjectsDataset([subject])[0]
        return subject

    def test_inconsistent_shape(self):
        """Sampling must work when the images differ in channel count."""
        # https://github.com/fepegar/torchio/issues/234#issuecomment-675029767
        subject = torchio.Subject(
            im1=torchio.ScalarImage(tensor=torch.rand(1, 4, 5, 6)),
            im2=torchio.ScalarImage(tensor=torch.rand(2, 4, 5, 6)),
        )
        patch_size = 2
        # Same imported name and iter() pattern as test_weighted_sampler.
        sampler = WeightedSampler(patch_size, 'im1')
        patch = next(iter(sampler(subject)))
        # Only spatial dimensions are cropped; each image keeps its own
        # number of channels.
        self.assertEqual(tuple(patch['im1'].data.shape), (1, 2, 2, 2))
        self.assertEqual(tuple(patch['im2'].data.shape), (2, 2, 2, 2))