Passed
Push — master ( c291a8...879ee9 )
by Fernando
59s
created

tests.data.sampler.test_uniform_sampler   A

Complexity

Total Complexity 2

Size/Duplication

Total Lines 26
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 26
rs 10
c 0
b 0
f 0
wmc 2

2 Methods

Rating | Name | Duplication | Size | Complexity
A | TestUniformSampler.test_processed_uniform_probabilities() | 0 | 9 | 1
A | TestUniformSampler.test_uniform_probabilities() | 0 | 5 | 1
1
import torch
2
import torchio
3
import numpy as np
4
from torchio.data import UniformSampler
5
from ...utils import TorchioTestCase
6
7
8
class TestUniformSampler(TorchioTestCase):
    """Tests for `UniformSampler` class."""

    def test_uniform_probabilities(self):
        """The raw probability map from ``get_probability_map`` is all ones."""
        patch_size = 5
        sampler = UniformSampler(patch_size)
        probabilities = sampler.get_probability_map(self.sample)
        fixtures = torch.ones_like(probabilities)
        self.assertTrue(torch.all(probabilities.eq(fixtures)))

    def test_processed_uniform_probabilities(self):
        """The processed map is uniform inside the border margin and sums to 1.

        Voxels within ``patch_size // 2`` of any border are expected to be
        zero; every remaining voxel is expected to share the same value.
        """
        # Odd patch size, so the excluded border is exactly patch_size // 2
        # voxels wide on each side; deriving it keeps the two in sync.
        patch_size = 5
        margin = patch_size // 2
        sampler = UniformSampler(patch_size)
        probabilities = sampler.get_probability_map(self.sample)
        probabilities = sampler.process_probability_map(probabilities)
        fixtures = np.zeros_like(probabilities)
        # Other positions cannot be patch centers
        fixtures[margin:-margin, margin:-margin, margin:-margin] = (
            probabilities[margin, margin, margin]
        )
        self.assertAlmostEqual(probabilities.sum(), 1)
        # Reports the mismatching positions on failure, unlike a bare
        # ``np.equal(...).all()``.
        np.testing.assert_array_equal(probabilities, fixtures)
26