Passed
Pull Request — master (#182)
by Fernando
57s
created

tests.data.inference.test_inference   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 46
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 38
dl 0
loc 46
rs 10
c 0
b 0
f 0
wmc 7

3 Methods

Rating   Name   Duplication   Size   Complexity  
A TestInference.test_inference_no_padding() 0 2 1
A TestInference.test_inference_padding() 0 2 1
A TestInference.try_inference() 0 25 4

1 Function

Rating   Name   Duplication   Size   Complexity  
A model() 0 3 1
1
import torch
2
from torch.utils.data import DataLoader
3
from tqdm import tqdm
4
from torchio import LOCATION, DATA
5
from torchio.data.inference import GridSampler, GridAggregator
6
from ...utils import TorchioTestCase
7
8
9
class TestInference(TorchioTestCase):
    """Tests for the `inference` module (grid sampling and re-aggregation)."""

    def test_inference_no_padding(self):
        self.try_inference(None)

    def test_inference_padding(self):
        self.try_inference(0)

    def try_inference(self, padding_mode):
        """Run a dummy model over a patch grid and check the stitched output.

        For each last-axis patch length in (17, 27), ``self.sample`` is split
        into overlapping patches by ``GridSampler``. Every batch of 6 patches
        is passed through ``model``, which overwrites its input with -5.
        ``GridAggregator`` then reassembles the results. The reassembled
        tensor must be -5 everywhere and have the same shape as
        ``self.sample.t1``.

        Args:
            padding_mode: forwarded unchanged to ``GridSampler``; the tests
                above pass ``None`` and ``0``.
        """
        overlap = 4, 6, 8
        for last_dim in (17, 27):
            sampler = GridSampler(
                self.sample,
                (10, 15, last_dim),
                overlap,
                padding_mode=padding_mode,
            )
            aggregator = GridAggregator(sampler)
            loader = DataLoader(sampler, batch_size=6)
            # Inference only: no autograd graph is needed for the dummy model.
            with torch.no_grad():
                for batch in tqdm(loader):
                    predictions = model(batch['t1'][DATA])
                    aggregator.add_batch(predictions, batch[LOCATION])
            stitched = aggregator.get_output_tensor()
            assert (stitched == -5).all()
            assert stitched.shape == self.sample.t1.shape
42
43
def model(tensor):
    """Stand-in "network" used by the inference tests.

    Overwrites every element of ``tensor`` with -5 in place and returns the
    same tensor object. The aggregated output can then be checked against
    a known constant.

    ``Tensor.fill_`` is used instead of ``tensor[:] = -5`` because slice
    assignment raises ``IndexError`` on 0-dimensional tensors, while
    ``fill_`` works for tensors of any rank.

    Args:
        tensor: tensor to overwrite (mutated in place).

    Returns:
        The same tensor object, now filled with -5.
    """
    return tensor.fill_(-5)
46