Passed
Pull Request — master (#182)
by Fernando
01:18
created

TestInference.test_inference()   A

Complexity

Conditions 4

Size

Total Lines 23
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 4     |
| eloc   | 20    |
| nop    | 1     |
| dl     | 0     |
| loc    | 23    |
| rs     | 9.4   |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import torch
2
from torch.utils.data import DataLoader
3
from tqdm import tqdm
4
from torchio import LOCATION, DATA
5
from torchio.data.inference import GridSampler, GridAggregator
6
from ...utils import TorchioTestCase
7
8
9
class TestInference(TorchioTestCase):
    """Tests for `inference` module."""

    def test_inference(self):
        """Check that grid sampling followed by aggregation covers the volume.

        A dummy "model" maps every patch to the constant ``-5``. All patches
        produced by ``GridSampler`` are passed through it in batches and
        written back with ``GridAggregator.add_batch``. The aggregated output
        tensor must then equal ``-5`` everywhere, so the check fails if the
        aggregator leaves any voxel with a different value.

        The check is repeated for two patch sizes that differ only in their
        last dimension (17 and 27).
        """
        def model(tensor):
            # Stand-in for a network: a constant output with the same shape,
            # dtype and device as the input. Unlike an in-place fill
            # (``tensor[:] = -5``), this leaves the loaded batch untouched,
            # as a real model would.
            return torch.full_like(tensor, -5)

        # Overlap and batch size are the same for every configuration.
        patch_overlap = 4, 6, 8
        batch_size = 6
        for last_dim in 17, 27:
            patch_size = 10, 15, last_dim
            # NOTE(review): self.sample is presumably built by
            # TorchioTestCase (not visible here) and must contain a 't1'
            # image, as read below; confirm in the test utilities.
            grid_sampler = GridSampler(self.sample, patch_size, patch_overlap)
            aggregator = GridAggregator(grid_sampler)
            patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
            with torch.no_grad():
                for patches_batch in tqdm(patch_loader):
                    input_tensor = patches_batch['t1'][DATA]
                    locations = patches_batch[LOCATION]
                    logits = model(input_tensor)
                    aggregator.add_batch(logits, locations)

            output = aggregator.get_output_tensor()
            # Name the failing configuration, since two patch sizes share
            # this assertion.
            assert (output == -5).all(), (
                f'Aggregated output is not uniformly -5 for '
                f'patch_size={patch_size}'
            )
34