AffineHead.__init__()   A
last analyzed

Complexity
    Conditions: 1

Size
    Total Lines: 18
    Code Lines: 10

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
eloc      10
dl        0
loc       18
rs        9.9
c         0
b         0
f         0
cc        1
nop       3
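
Per-function figures like the ones above (cyclomatic complexity, line counts, number of parameters, an A-F rank) can be reproduced with standard Python static-analysis tooling. The report does not name the tool it used, so the sketch below uses the radon package as an assumed stand-in, and the file path is a placeholder.

# Hedged sketch: reproduce comparable per-function metrics with radon.
# The analysis tool behind this report is not identified; radon is an assumption,
# and `source_path` is a placeholder.
from radon.complexity import cc_rank, cc_visit
from radon.raw import analyze

source_path = "global_net.py"  # placeholder path to the analyzed module
with open(source_path) as f:
    code = f.read()

# Raw size metrics for the whole module: loc, lloc, sloc, comments, blank, ...
print(analyze(code))

# Cyclomatic complexity per function/method, plus an A-F rank like the one
# shown next to AffineHead.__init__() above.
for block in cc_visit(code):
    name = f"{block.classname}.{block.name}" if block.classname else block.name
    print(name, block.complexity, cc_rank(block.complexity))
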
# coding=utf-8

from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer_util
from deepreg.model.backbone.u_net import UNet
from deepreg.registry import REGISTRY


class AffineHead(tfkl.Layer):
    def __init__(
        self,
        image_size: tuple,
        name: str = "AffineHead",
    ):
        """
        Init.

        :param image_size: such as (dim1, dim2, dim3)
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.reference_grid = layer_util.get_reference_grid(image_size)
        # the dense layer's bias is initialised with the flattened 4x3 identity matrix
        self.transform_initial = tf.constant_initializer(
            value=list(np.eye(4, 3).reshape((-1)))
        )
        self._flatten = tfkl.Flatten()
        self._dense = tfkl.Dense(units=12, bias_initializer=self.transform_initial)

    def call(
        self, inputs: Union[tf.Tensor, List], **kwargs
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Predict the affine parameters theta and the corresponding dense
        displacement field (ddf).

        :param inputs: a tensor or a list of tensors with length 1
        :param kwargs: additional args
        :return: ddf and theta

            - ddf has shape (batch, dim1, dim2, dim3, 3)
            - theta has shape (batch, 4, 3)
        """
        if isinstance(inputs, list):
            inputs = inputs[0]
        theta = self._dense(self._flatten(inputs))
        theta = tf.reshape(theta, shape=(-1, 4, 3))
        # warp the reference grid with affine parameters to output a ddf
        grid_warped = layer_util.warp_grid(self.reference_grid, theta)
        ddf = grid_warped - self.reference_grid
        return ddf, theta

    def get_config(self):
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(image_size=self.reference_grid.shape[:3])
        return config
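
For illustration only (not part of the module listed here), a hedged usage sketch of AffineHead: the feature-map shape is an arbitrary example, the import path is an assumption, and the expected output shapes follow the call docstring above.

# Illustrative usage of AffineHead; shapes and import path are assumptions.
import tensorflow as tf
from deepreg.model.backbone.global_net import AffineHead  # assumed module path

head = AffineHead(image_size=(16, 16, 16))
features = tf.random.normal((2, 2, 2, 2, 8))  # (batch, f1, f2, f3, channels)
ddf, theta = head(features)
print(ddf.shape)    # expected (2, 16, 16, 16, 3) per the docstring
print(theta.shape)  # expected (2, 4, 3)
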
@REGISTRY.register_backbone(name="global")
class GlobalNet(UNet):
    """
    Build GlobalNet for image registration.

    GlobalNet is a special UNet where the decoder for up-sampling is skipped.
    The network's outputs come directly from the bottom layer of the encoder.

    Reference:

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Optional[Tuple[int, ...]] = None,
        depth: Optional[int] = None,
        name: str = "GlobalNet",
        **kwargs,
    ):
        """
        The image is encoded gradually from level 0 to level E,
        where E is the depth of the encoder. Then a densely-connected
        layer outputs an affine transformation.

        :param image_size: tuple, such as (dim1, dim2, dim3)
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract, deprecated.
            If depth is not given, depth = max(extract_levels) will be used.
        :param depth: depth of the encoder. If given, extract_levels is not used.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        if depth is None:
            if extract_levels is None:
                raise ValueError(
                    "GlobalNet requires `depth` or `extract_levels` "
                    "to define the depth of the encoder. "
                    "If `depth` is not given, "
                    "the maximum value of `extract_levels` will be used. "
                    "However, the argument `extract_levels` is deprecated "
                    "and will be removed in a future release."
                )
            depth = max(extract_levels)
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=(depth,),  # only the bottom encoder level is used
            name=name,
            **kwargs,
        )

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of length 1.
        The output consists of two tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: not used
        :param out_channels: not used
        :param out_kernel_initializer: not used
        :param out_activation: not used
        :return: a block consisting of one or multiple layers
        """
        return AffineHead(image_size=image_size)
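
Again for illustration only, a hedged end-to-end sketch of the registered backbone. The out_* keyword arguments are forwarded to the UNet parent via **kwargs (their names are taken from build_output_block above, but the exact parent signature is an assumption), and all sizes are arbitrary placeholders.

# Illustrative forward pass through GlobalNet; the UNet keyword arguments
# forwarded via **kwargs and all sizes are assumptions, not values from the report.
import tensorflow as tf
from deepreg.model.backbone.global_net import GlobalNet  # assumed module path

net = GlobalNet(
    image_size=(16, 16, 16),
    num_channel_initial=2,
    depth=3,
    out_channels=3,                           # not used by AffineHead (see docstring)
    out_kernel_initializer="glorot_uniform",  # not used by AffineHead
    out_activation="sigmoid",                 # not used by AffineHead
)
# moving and fixed images concatenated along the channel axis
images = tf.random.normal((1, 16, 16, 16, 2))
ddf, theta = net(images)  # expected shapes (1, 16, 16, 16, 3) and (1, 4, 3)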