Passed
Pull Request — master (#307)
by Fernando
01:17
created

tests.data.inference.test_aggregator   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 50
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 44
dl 0
loc 50
rs 10
c 0
b 0
f 0
wmc 5

3 Methods

Rating   Name   Duplication   Size   Complexity  
A TestAggregator.test_overlap_average() 0 8 1
A TestAggregator.aggregate() 0 22 3
A TestAggregator.test_overlap_crop() 0 8 1
1
import torch
2
import torchio as tio
3
from torchio import LOCATION, DATA, Subject, ScalarImage
4
from ...utils import TorchioTestCase
5
6
7
class TestAggregator(TorchioTestCase):
    """Check how `GridAggregator` stitches overlapping patches back together."""

    def aggregate(self, mode, fixture):
        """Aggregate tagged patches of a 4x4 image and compare with ``fixture``.

        A ``(1, 1, 4, 4)`` tensor of ones is sampled with patch size
        ``(1, 3, 3)`` and overlap ``(0, 2, 2)``. Before aggregation every patch
        is multiplied in place by a factor looked up from elements 1 and 2 of
        its location, so the aggregated output shows which patch contributed
        to each voxel.

        Args:
            mode: ``overlap_mode`` forwarded to ``tio.data.GridAggregator``.
            fixture: Tensor that ``get_output_tensor()`` must equal.
        """
        image_key = 'img'
        ones = torch.ones(1, 1, 4, 4)
        subject = tio.Subject({image_key: tio.ScalarImage(tensor=ones)})
        sampler = tio.data.GridSampler(subject, (1, 3, 3), (0, 2, 2))
        aggregator = tio.data.GridAggregator(sampler, overlap_mode=mode)
        # batch_size=3: presumably the four patches arrive as a batch of 3
        # followed by a partial batch of 1.
        loader = torch.utils.data.DataLoader(sampler, batch_size=3)
        # Scale factor per (location[1], location[2]) pair; any other pair
        # raises KeyError, which would make the test fail loudly.
        scale_by_origin = {
            (0, 0): 0,
            (0, 1): 2,
            (1, 0): 4,
            (1, 1): 6,
        }
        for batch in loader:
            locations = batch[LOCATION]
            patches = batch[IMG_PLACEHOLDER][DATA] if False else batch[image_key][DATA]
            for patch, location in zip(patches, locations):
                # Iterating a tensor yields views, so mul_ rescales the
                # corresponding slice of ``patches`` itself.
                j, k = location[1:3].tolist()
                patch.mul_(scale_by_origin[j, k])
            aggregator.add_batch(patches, locations)
        self.assertTensorEqual(aggregator.get_output_tensor(), fixture)

    def test_overlap_crop(self):
        """'crop' mode: expected output is one constant 2x2 block per patch value."""
        expected = torch.Tensor([
            [0, 0, 2, 2],
            [0, 0, 2, 2],
            [4, 4, 6, 6],
            [4, 4, 6, 6],
        ]).view(1, 1, 4, 4)
        self.aggregate('crop', expected)

    def test_overlap_average(self):
        """'average' mode: overlapping voxels hold the mean of their patch values."""
        expected = torch.Tensor([
            [0, 1, 1, 2],
            [2, 3, 3, 4],
            [2, 3, 3, 4],
            [4, 5, 5, 6],
        ]).view(1, 1, 4, 4)
        self.aggregate('average', expected)
50