Passed
Pull Request — master (#182)
by Fernando
01:01
created

TestInference.test_inference()   A

Complexity

Conditions 3

Size

Total Lines 22
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 19
nop 1
dl 0
loc 22
rs 9.45
c 0
b 0
f 0
1
import torch
2
from torch.utils.data import DataLoader
3
from tqdm import tqdm
4
from torchio import LOCATION, DATA
5
from torchio.data.inference import GridSampler, GridAggregator
6
from ...utils import TorchioTestCase
7
8
9
class TestInference(TorchioTestCase):
    """Tests for patch-based inference with ``GridSampler``/``GridAggregator``."""

    def test_inference(self):
        """Aggregated output is ``-5`` everywhere when every patch is set to ``-5``.

        Patches are drawn from ``self.sample`` by a ``GridSampler``, batched
        through a ``DataLoader``, passed through a dummy model that overwrites
        them in place, and stitched back together by a ``GridAggregator``.
        """
        def model(batch):
            # Dummy "network": overwrite the input in place and hand it back.
            batch.fill_(-5)
            return batch

        sampler = GridSampler(
            self.sample,
            (10, 15, 27),  # patch size
            (4, 6, 8),  # patch overlap
        )
        aggregator = GridAggregator(sampler)
        loader = DataLoader(sampler, batch_size=6)

        with torch.no_grad():
            for batch in tqdm(loader):
                image = batch['t1'][DATA]
                where = batch[LOCATION]
                aggregator.add_batch(model(image), where)

        reconstructed = aggregator.get_output_tensor()
        assert (reconstructed == -5).all()
33