Passed
Pull Request — master (#307)
by Fernando
01:17
created

GridAggregator.parse_overlap_mode()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
import warnings
2
from typing import Tuple
3
import torch
4
import numpy as np
5
from ...torchio import TypeData, CHANNELS_DIMENSION
6
from .grid_sampler import GridSampler
7
8
9
class GridAggregator:
    r"""Aggregate patches for dense inference.

    This class is typically used to build a volume made of patches after
    inference of batches extracted by a :py:class:`~torchio.data.GridSampler`.

    Args:
        sampler: Instance of :py:class:`~torchio.data.GridSampler` used to
            extract the patches.
        overlap_mode: If ``'crop'``, the overlapping predictions will be
            cropped. If ``'average'``, the predictions in the overlapping areas
            will be averaged with equal weights. See the
            `grid aggregator tests`_ for a raw visualization of both modes.

    .. _grid aggregator tests: https://github.com/fepegar/torchio/blob/master/tests/data/inference/test_aggregator.py

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch-based sampling.
    """
    def __init__(self, sampler: 'GridSampler', overlap_mode: str = 'crop'):
        sample = sampler.sample
        # True if the grid sampler padded the volume; the padding is then
        # removed in get_output_tensor()
        self.volume_padded = sampler.padding_mode is not None
        # Spatial shape of the (possibly padded) sampled volume
        self.spatial_shape = sample.spatial_shape
        # Accumulated output volume, created lazily from the first batch
        self._output_tensor = None
        self.patch_overlap = sampler.patch_overlap
        self.parse_overlap_mode(overlap_mode)
        self.overlap_mode = overlap_mode
        # Per-voxel count of contributing patches ('average' mode only)
        self._avgmask_tensor = None

    @staticmethod
    def parse_overlap_mode(overlap_mode):
        """Raise ``ValueError`` if ``overlap_mode`` is not supported."""
        if overlap_mode not in ('crop', 'average'):
            message = (
                f'Overlap mode must be "crop" or "average" but "{overlap_mode}"'
                ' was passed'
            )
            raise ValueError(message)

    def crop_batch(
            self,
            batch: torch.Tensor,
            locations: np.ndarray,
            overlap: np.ndarray,
            ) -> Tuple[list, np.ndarray]:
        """Crop the overlapping margins of each patch in a batch.

        Each patch is cropped by half the overlap on each side. Sides that
        touch the border of the volume are not cropped, unless the volume was
        padded by the grid sampler. In that case the padding is removed later
        by :meth:`get_output_tensor`, so symmetric cropping is fine.

        Args:
            batch: 5D tensor; the first two dimensions (batch and channels)
                are skipped when reading the patch spatial shape.
            locations: Array with shape :math:`(B, 6)` holding the initial
                and final indices of each patch in the volume.
            overlap: Overlap between patches along each spatial axis.

        Returns:
            A tuple ``(cropped_patches, crop_locations)``. The first item is a
            list of cropped 4D patch tensors. The second is an integer array
            with shape :math:`(B, 6)` holding the location of each cropped
            patch in the volume. The input ``locations`` is not modified.
        """
        border = np.array(overlap) // 2  # overlap is even in grid sampler
        crop_locations = locations.astype(int)  # astype returns a new array
        # These are views into crop_locations, so updating them in place
        # updates the returned locations
        indices_ini, indices_fin = crop_locations[:, :3], crop_locations[:, 3:]
        num_locations = len(crop_locations)

        border_ini = np.tile(border, (num_locations, 1))
        border_fin = border_ini.copy()
        # Do not crop patches at the border of the volume,
        # unless we're padding the volume in the grid sampler. In that case,
        # it doesn't matter if we don't crop patches at the border, because the
        # output volume will be cropped
        if not self.volume_padded:
            border_ini[indices_ini == 0] = 0
            border_fin[indices_fin == np.array(self.spatial_shape)] = 0

        indices_ini += border_ini
        indices_fin -= border_fin

        patch_shape = np.array(batch.shape[2:])  # ignore batch and channels dim
        cropped_patches = []
        for patch, crop_ini, crop_fin in zip(batch, border_ini, border_fin):
            # The crop is not necessarily centered: patches touching the
            # volume border are only cropped on their inner side, so the
            # patch-local slice must use the per-side borders directly
            i_ini, j_ini, k_ini = crop_ini
            i_fin, j_fin, k_fin = patch_shape - crop_fin
            cropped_patch = patch[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
            cropped_patches.append(cropped_patch)
        return cropped_patches, crop_locations

    def _zeros_volume_like(self, batch: torch.Tensor) -> torch.Tensor:
        """Return zeros with the batch channels, volume shape and batch dtype."""
        num_channels = batch.shape[CHANNELS_DIMENSION]
        return torch.zeros(
            num_channels,
            *self.spatial_shape,
            dtype=batch.dtype,
        )

    def initialize_output_tensor(self, batch: torch.Tensor) -> None:
        """Create the output volume from the first batch, if not created yet."""
        if self._output_tensor is not None:
            return
        self._output_tensor = self._zeros_volume_like(batch)

    def initialize_avgmask_tensor(self, batch: torch.Tensor) -> None:
        """Create the averaging count volume, if not created yet."""
        if self._avgmask_tensor is not None:
            return
        self._avgmask_tensor = self._zeros_volume_like(batch)

    def add_batch(
            self,
            batch_tensor: torch.Tensor,
            locations: torch.Tensor,
            ) -> None:
        """Add batch processed by a CNN to the output prediction volume.

        Args:
            batch_tensor: 5D tensor, typically the output of a convolutional
                neural network, e.g. ``batch['image'][torchio.DATA]``.
            locations: 2D tensor with shape :math:`(B, 6)` representing the
                patch indices in the original image. They are typically
                extracted using ``batch[torchio.LOCATION]``.
        """
        batch = batch_tensor.cpu()
        locations = locations.cpu().numpy()
        self.initialize_output_tensor(batch)
        if self.overlap_mode == 'crop':
            cropped_patches, crop_locations = self.crop_batch(
                batch,
                locations,
                self.patch_overlap,
            )
            for patch, crop_location in zip(cropped_patches, crop_locations):
                i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = crop_location
                self._output_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin] = patch
        elif self.overlap_mode == 'average':
            self.initialize_avgmask_tensor(batch)
            # Cast to int as in crop mode, so non-integer location arrays
            # can still be used as slice bounds
            for patch, location in zip(batch, locations.astype(int)):
                i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
                self._output_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin] += patch
                self._avgmask_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin] += 1

    def get_output_tensor(self) -> torch.Tensor:
        """Get the aggregated volume after dense inference.

        An int64 output is cast to int32, with a warning. In ``'average'``
        mode the accumulated values are divided by the per-voxel patch
        counts. If the grid sampler padded the volume, half the patch overlap
        is cropped from each side of the output.

        Raises:
            RuntimeError: If no batch has been added with :meth:`add_batch`.
        """
        if self._output_tensor is None:
            message = (
                'No batches have been added yet. Call add_batch() before'
                ' get_output_tensor()'
            )
            raise RuntimeError(message)
        if self._output_tensor.dtype == torch.int64:
            message = (
                'Medical image frameworks such as ITK do not support int64.'
                ' Casting to int32...'
            )
            warnings.warn(message)
            self._output_tensor = self._output_tensor.type(torch.int32)
        if self.overlap_mode == 'average':
            output = self._output_tensor / self._avgmask_tensor
        else:
            output = self._output_tensor
        if self.volume_padded:
            from ...transforms import Crop
            border = self.patch_overlap // 2
            cropping = border.repeat(2)
            crop = Crop(cropping)
            return crop(output)
        return output
173