Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

tests.data.test_inference   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 33
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 27
dl 0
loc 33
rs 10
c 0
b 0
f 0
wmc 3

1 Method

Rating   Name   Duplication   Size   Complexity  
A TestInference.test_inference() 0 21 3
1
import torch
2
import torch.nn as nn
3
from torch.utils.data import DataLoader
4
from tqdm import tqdm
5
from torchio import LOCATION, DATA
6
from torchio.data.inference import GridSampler, GridAggregator
7
from ..utils import TorchioTestCase
8
9
10
class TestInference(TorchioTestCase):
    """Exercise patch-based inference with ``GridSampler`` and ``GridAggregator``."""

    def test_inference(self):
        """Run a dummy model over every grid patch and re-assemble the results.

        ``self.sample`` (presumably prepared by ``TorchioTestCase``) is split
        into overlapping patches by ``GridSampler``. The patches are fed in
        mini-batches of 6 through a single 3D convolution, and the argmax of
        each prediction is handed to ``GridAggregator`` together with the
        patch locations. No assertion is made: the test only checks that the
        sample -> patches -> model -> aggregator round trip runs without
        raising.
        """
        # Untrained stand-in for a segmentation network; built first, as in
        # the original ordering, so weight initialisation happens before
        # anything else touches the RNG.
        net = nn.Conv3d(1, 1, 3)

        patch_size = 10, 15, 27
        patch_overlap = 4, 5, 8
        # Axis collapsed by argmax; presumably the channel axis of a
        # (batch, channels, ...) tensor -- TODO confirm against torchio.
        channel_dim = 1

        sampler = GridSampler(self.sample, patch_size, patch_overlap)
        loader = DataLoader(sampler, batch_size=6)
        aggregator = GridAggregator(self.sample, patch_overlap)

        with torch.no_grad():
            for batch in tqdm(loader):
                # Forward the 't1' image data of the batch and keep the
                # reduced axis so the label map keeps the input's rank.
                prediction = net(batch['t1'][DATA]).argmax(
                    dim=channel_dim,
                    keepdim=True,
                )
                aggregator.add_batch(prediction, batch[LOCATION])

        # Result intentionally unused: building the full output tensor is
        # the final step being exercised.
        aggregator.get_output_tensor()
33