Passed
Pull Request — master (#182)
by Fernando
58s
created

tests.data.inference.test_inference   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 33
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 27
dl 0
loc 33
rs 10
c 0
b 0
f 0
wmc 3

1 Method

Rating   Name   Duplication   Size   Complexity  
A TestInference.test_inference() 0 22 3
1
import torch
2
from torch.utils.data import DataLoader
3
from tqdm import tqdm
4
from torchio import LOCATION, DATA
5
from torchio.data.inference import GridSampler, GridAggregator
6
from ...utils import TorchioTestCase
7
8
9
class TestInference(TorchioTestCase):
    """Exercise grid-based patch sampling and aggregation together."""

    def test_inference(self):
        """Run a constant-valued stand-in model over every grid patch.

        Each batch of patches is overwritten in place with a fixed value
        by a dummy model and then passed to the aggregator together with
        the patch locations. The aggregated output tensor must contain
        that fixed value at every position.
        """
        fill_value = -5

        def constant_model(batch):
            # Stand-in for a real network: overwrite the input in place.
            batch[:] = fill_value
            return batch

        patch_size = (10, 15, 27)
        patch_overlap = (4, 6, 8)
        batch_size = 6

        # NOTE(review): `self.sample` is presumably provided by
        # TorchioTestCase -- confirm against the base class.
        sampler = GridSampler(self.sample, patch_size, patch_overlap)
        aggregator = GridAggregator(sampler)
        loader = DataLoader(sampler, batch_size=batch_size)

        with torch.no_grad():
            for batch in tqdm(loader):
                locations = batch[LOCATION]
                predictions = constant_model(batch['t1'][DATA])
                aggregator.add_batch(predictions, locations)

        aggregated = aggregator.get_output_tensor()
        matches = aggregated == fill_value
        assert matches.all()
33