Passed
Pull Request — master (#182)
by Fernando
01:16
created

TestInference.test_inference()   A

Complexity

Conditions 3

Size

Total Lines 19
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 3     |
| eloc   | 17    |
| nop    | 1     |
| dl     | 0     |
| loc    | 19    |
| rs     | 9.55  |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import torch
2
import torch.nn as nn
3
from torch.utils.data import DataLoader
4
from tqdm import tqdm
5
from torchio import LOCATION, DATA, CHANNELS_DIMENSION
6
from torchio.data.inference import GridSampler, GridAggregator
7
from ...utils import TorchioTestCase
8
9
10
class TestInference(TorchioTestCase):
    """Tests for `inference` module."""

    def test_inference(self):
        """Run patch-based inference end to end and aggregate the result.

        The volume in ``self.sample`` (presumably provided by
        ``TorchioTestCase``) is split into overlapping patches by
        ``GridSampler``. Each batch of patches is passed through a toy model,
        and the per-patch predictions are stitched back together by
        ``GridAggregator``. The test checks that the pipeline runs and that
        an aggregated output is produced.
        """
        # Toy stand-in for a real network: one input channel, one output
        # channel, 3x3x3 kernel, no padding.
        model = nn.Conv3d(1, 1, 3)
        patch_size = 10, 15, 27
        patch_overlap = 4, 6, 8
        batch_size = 6

        grid_sampler = GridSampler(self.sample, patch_size, patch_overlap)
        aggregator = GridAggregator(grid_sampler)
        patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
        with torch.no_grad():  # inference only; no autograd graph needed
            for patches_batch in tqdm(patch_loader):
                input_tensor = patches_batch['t1'][DATA]
                locations = patches_batch[LOCATION]
                logits = model(input_tensor)
                # Reduce to a label map and keep a singleton channel axis.
                # Assuming CHANNELS_DIMENSION indexes the conv's single output
                # channel, every label is 0. The goal is to exercise the
                # sampling/aggregation pipeline, not to check predictions.
                labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True)
                aggregator.add_batch(labels, locations)

        output = aggregator.get_output_tensor()
        # Require that aggregation actually produced an output, rather than
        # only checking that nothing above raised.
        self.assertIsNotNone(output)
31