Passed: Pull Request — master (#307), by Fernando, created 01:14

GridAggregator.add_batch()   B

Complexity
    Conditions: 5

Size
    Total Lines: 44
    Code Lines: 33

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
cc        5
eloc      33
nop       3
dl        0
loc       44
rs        8.6213
c         0
b         0
f         0
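
Per-function figures like cc (cyclomatic complexity, matching the 5 conditions above) and loc can be reproduced with a static analyzer such as radon. The exact toolchain behind this report is not shown, so the following Python sketch is only one plausible way to obtain comparable numbers; the aggregator.py path is a placeholder.

    # Sketch only: radon is one tool that produces cc/loc-style figures;
    # whether this report used it is unknown. 'aggregator.py' is a placeholder.
    from radon.complexity import cc_visit
    from radon.raw import analyze

    source = open('aggregator.py').read()

    for block in cc_visit(source):           # one entry per function/method
        print(block.name, block.complexity)  # e.g. add_batch -> 5

    raw = analyze(source)                    # module-level line counts
    print(raw.loc, raw.lloc, raw.sloc)       # total / logical / source lines

The function itself, with its module context, is listed below.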
import warnings
from typing import Tuple

import torch
import numpy as np

from ...torchio import TypeData, CHANNELS_DIMENSION
from .grid_sampler import GridSampler


class GridAggregator:
    r"""Aggregate patches for dense inference.

    This class is typically used to build a volume made of patches after
    inference of batches extracted by a :py:class:`~torchio.data.GridSampler`.

    Args:
        sampler: Instance of :py:class:`~torchio.data.GridSampler` used to
            extract the patches.
        overlap_mode: If ``'crop'``, the overlapping borders of adjacent
            patches are cropped before writing; if ``'average'``, values in
            overlapping regions are averaged.

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch-based sampling.
    """
    def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
        sample = sampler.sample
        self.volume_padded = sampler.padding_mode is not None
        self.spatial_shape = sample.spatial_shape
        self._output_tensor = None
        self.patch_overlap = sampler.patch_overlap
        self.overlap_mode = overlap_mode
        self._avgmask_tensor = None

    def crop_batch(
            self,
            batch: torch.Tensor,
            locations: np.ndarray,
            overlap: np.ndarray,
            ) -> Tuple[TypeData, np.ndarray]:
        border = np.array(overlap) // 2  # overlap is even in grid sampler
        crop_locations = locations.astype(int).copy()
        indices_ini, indices_fin = crop_locations[:, :3], crop_locations[:, 3:]
        num_locations = len(crop_locations)

        border_ini = np.tile(border, (num_locations, 1))
        border_fin = border_ini.copy()
        # Do not crop patches at the border of the volume, unless the volume
        # is padded in the grid sampler; in that case it does not matter,
        # because the output volume will be cropped anyway
        if not self.volume_padded:
            mask_border_ini = indices_ini == 0
            border_ini[mask_border_ini] = 0
            for axis, size in enumerate(self.spatial_shape):
                mask_border_fin = indices_fin[:, axis] == size
                border_fin[mask_border_fin, axis] = 0

        indices_ini += border_ini
        indices_fin -= border_fin

        crop_shapes = indices_fin - indices_ini
        patch_shape = batch.shape[2:]  # ignore batch and channels dim
        cropped_patches = []
        for patch, crop_shape in zip(batch, crop_shapes):
            diff = patch_shape - crop_shape
            left = (diff / 2).astype(int)
            i_ini, j_ini, k_ini = left
            i_fin, j_fin, k_fin = left + crop_shape
            cropped_patch = patch[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
            cropped_patches.append(cropped_patch)
        return cropped_patches, crop_locations

    def initialize_output_tensor(self, batch: torch.Tensor) -> None:
        if self._output_tensor is not None:
            return
        num_channels = batch.shape[CHANNELS_DIMENSION]
        self._output_tensor = torch.zeros(
            num_channels,
            *self.spatial_shape,
            dtype=batch.dtype,
        )

    def initialize_avgmask_tensor(self, batch: torch.Tensor) -> None:
        if self._avgmask_tensor is not None:
            return
        num_channels = batch.shape[CHANNELS_DIMENSION]
        self._avgmask_tensor = torch.zeros(
            num_channels,
            *self.spatial_shape,
            dtype=batch.dtype,
        )

    def add_batch(
            self,
            batch_tensor: torch.Tensor,
            locations: torch.Tensor,
            ) -> None:
        """Add a batch processed by a CNN to the output prediction volume.

        Args:
            batch_tensor: 5D tensor, typically the output of a convolutional
                neural network, e.g. ``batch['image'][torchio.DATA]``.
            locations: 2D tensor with shape :math:`(B, 6)` representing the
                patch indices in the original image. They are typically
                extracted using ``batch[torchio.LOCATION]``.
        """
        batch = batch_tensor.cpu()
        locations = locations.cpu().numpy()
        self.initialize_output_tensor(batch)
        if self.overlap_mode == 'crop':
            cropped_patches, crop_locations = self.crop_batch(
                batch,
                locations,
                self.patch_overlap,
            )
            for patch, crop_location in zip(cropped_patches, crop_locations):
                i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = crop_location
                self._output_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin] = patch
        elif self.overlap_mode == 'average':
            self.initialize_avgmask_tensor(batch)
            for patch, location in zip(batch, locations):
                i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
                self._output_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin] += patch
                self._avgmask_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin] += 1

    def get_output_tensor(self) -> torch.Tensor:
        """Get the aggregated volume after dense inference."""
        if self._output_tensor.dtype == torch.int64:
            message = (
                'Medical image frameworks such as ITK do not support int64.'
                ' Casting to int32...'
            )
            warnings.warn(message)
            self._output_tensor = self._output_tensor.type(torch.int32)
        if self.overlap_mode == 'average':
            output = self._output_tensor / self._avgmask_tensor
        else:
            output = self._output_tensor
        if self.volume_padded:
            from ...transforms import Crop
            border = self.patch_overlap // 2
            cropping = border.repeat(2)
            crop = Crop(cropping)
            return crop(output)
        else:
            return output
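
For context, the aggregator is driven by a patch loader during inference. Below is a minimal usage sketch, not part of the report: the ``GridSampler`` constructor arguments, the ``'image'`` key, and ``subject`` and ``model`` are assumptions about the surrounding API, grounded only in the docstrings above (which reference ``batch['image'][torchio.DATA]`` and ``batch[torchio.LOCATION]``).

    # Usage sketch. Assumptions: `subject` is a torchio sample, `model` is a
    # trained CNN, and GridSampler accepts (sample, patch_size, patch_overlap).
    import torch
    import torchio

    sampler = torchio.data.GridSampler(subject, 64, patch_overlap=4)
    aggregator = torchio.data.GridAggregator(sampler, overlap_mode='average')
    loader = torch.utils.data.DataLoader(sampler, batch_size=4)

    model.eval()
    with torch.no_grad():
        for batch in loader:
            inputs = batch['image'][torchio.DATA]  # 5D input: (B, C, D, H, W)
            locations = batch[torchio.LOCATION]    # (B, 6) patch indices
            aggregator.add_batch(model(inputs), locations)

    prediction = aggregator.get_output_tensor()    # aggregated (C, D, H, W)

With ``overlap_mode='average'``, overlapping predictions are summed in ``add_batch`` and divided by the per-voxel hit count in ``get_output_tensor``; with the default ``'crop'`` mode, half of the patch overlap is trimmed from each patch before it is written into the output volume.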