Passed
Pull Request — master (#182)
by Fernando
58s
created

GridAggregator.crop_batch()   A

Complexity

Conditions 2

Size

Total Lines 30
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 26
nop 4
dl 0
loc 30
rs 9.256
c 0
b 0
f 0
1
import warnings
2
from typing import Tuple
3
import torch
4
import numpy as np
5
from ...torchio import TypeData, CHANNELS_DIMENSION
6
from .grid_sampler import GridSampler
7
8
9
class GridAggregator:
    r"""Aggregate patches for dense inference.

    This class is typically used to build a volume made of batches after
    inference of patches extracted by a :py:class:`~torchio.data.GridSampler`.

    Args:
        sampler: Instance of :py:class:`~torchio.data.GridSampler` used to
            extract the patches.

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information.
    """
    def __init__(self, sampler: GridSampler):
        sample = sampler.sample
        # Spatial shape of the full volume being reconstructed.
        self.spatial_shape = sample.spatial_shape
        # Allocated lazily by initialize_output_tensor() from the first batch.
        self._output_tensor = None
        # Overlap between neighbouring patches; crop_batch() trims half of it
        # from each side of every patch (except at the volume border).
        self.patch_overlap = sampler.patch_overlap

    def crop_batch(
            self,
            batch: torch.Tensor,
            location: np.ndarray,
            overlap: np.ndarray,
            ) -> Tuple[Tuple[torch.Tensor, ...], np.ndarray]:
        """Trim the overlapping margins from every patch in a batch.

        Each patch loses ``overlap // 2`` voxels on both sides of every axis,
        except on the sides that touch the border of the volume, which are
        kept so that the whole volume is covered. The crop is computed per
        patch, so a batch may freely mix interior and border patches.

        If the spatial shape of ``batch`` is smaller than the extent of the
        input patches, each output is treated as centred on its input patch
        (offset ``floor((output - input) / 2)`` per axis).

        Args:
            batch: Tensor of shape ``(B, C, X, Y, Z)``.
            location: Array of shape ``(B, 6)`` holding
                ``[i_ini, j_ini, k_ini, i_fin, j_fin, k_fin]`` for each patch,
                in volume coordinates.
            overlap: Per-axis overlap between neighbouring patches.

        Returns:
            A pair ``(cropped_patches, crop_locations)``: a tuple with one
            cropped ``(C, x, y, z)`` tensor per patch, and an integer array of
            shape ``(B, 6)`` with the volume region each cropped patch covers.

        Raises:
            ValueError: If an output patch is too small to contain the region
                that has to be kept.
        """
        border = np.asarray(overlap).astype(int) // 2  # overlap is even in grid sampler
        crop_locations = location.astype(int)  # astype already returns a copy
        # Views into crop_locations: the in-place updates below modify it.
        indices_ini = crop_locations[:, :3]
        indices_fin = crop_locations[:, 3:]
        # Extent of each input patch, taken before the locations are trimmed.
        patch_sizes = indices_fin - indices_ini
        num_locations = len(crop_locations)
        border_ini = np.tile(border, (num_locations, 1))
        border_fin = border_ini.copy()

        # Do not crop patches at the border of the volume
        border_ini[indices_ini == 0] = 0
        border_fin[indices_fin == np.asarray(self.spatial_shape)] = 0

        indices_ini += border_ini
        indices_fin -= border_fin
        kept_sizes = indices_fin - indices_ini

        output_shape = np.array(batch.shape[2:])  # ignore batch and channels dim
        # Offset of each output relative to its input patch (0 when the
        # network preserves the spatial shape).
        offsets = (output_shape - patch_sizes) // 2
        starts = border_ini + offsets
        stops = starts + kept_sizes
        if (starts < 0).any() or (stops > output_shape).any():
            message = (
                f'Output patches with spatial shape {tuple(output_shape)} are'
                ' too small to contain the regions to be kept after removing'
                f' the overlap {tuple(border * 2)}'
            )
            raise ValueError(message)

        cropped_patches = tuple(
            patch[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
            for patch, (i_ini, j_ini, k_ini), (i_fin, j_fin, k_fin)
            in zip(batch, starts, stops)
        )
        return cropped_patches, crop_locations

    def initialize_output_tensor(self, batch: torch.Tensor) -> None:
        """Allocate the zero-filled output volume on the first call.

        The output has shape ``(C, *spatial_shape)``, where ``C`` and the
        dtype are taken from ``batch``. Later calls do nothing.
        """
        if self._output_tensor is not None:
            return
        num_channels = batch.shape[CHANNELS_DIMENSION]
        self._output_tensor = torch.zeros(
            num_channels,
            *self.spatial_shape,
            dtype=batch.dtype,
        )

    def add_batch(self, batch: torch.Tensor, locations: TypeData) -> None:
        """Write a batch of (inferred) patches into the output volume.

        Args:
            batch: Tensor of shape ``(B, C, X, Y, Z)``; moved to the CPU.
            locations: Tensor or array of shape ``(B, 6)`` with the volume
                region ``[i_ini, j_ini, k_ini, i_fin, j_fin, k_fin]`` each
                patch was extracted from.
        """
        batch = batch.cpu()
        if isinstance(locations, torch.Tensor):
            locations = locations.cpu().numpy()
        locations = np.asarray(locations)
        self.initialize_output_tensor(batch)
        cropped_patches, crop_locations = self.crop_batch(
            batch,
            locations,
            self.patch_overlap,
        )
        for patch, location in zip(cropped_patches, crop_locations):
            i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
            self._output_tensor[
                :,
                i_ini:i_fin,
                j_ini:j_fin,
                k_ini:k_fin] = patch

    def get_output_tensor(self) -> torch.Tensor:
        """Return the aggregated volume.

        An int64 volume is cast to int32 (with a warning), because ITK does
        not support int64.

        Raises:
            RuntimeError: If no batch has been added yet.
        """
        if self._output_tensor is None:
            raise RuntimeError(
                'No batch has been added yet; call add_batch() first')
        if self._output_tensor.dtype == torch.int64:
            message = (
                'Medical image frameworks such as ITK do not support int64.'
                ' Casting to int32...'
            )
            warnings.warn(message)
            self._output_tensor = self._output_tensor.type(torch.int32)
        return self._output_tensor
97