Passed
Pull Request — main (#656)
by Yunguan
02:48
created

deepreg.model.backbone.global_net   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 133
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 5
eloc 71
dl 0
loc 133
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
B GlobalNet.__init__() 0 75 3
A GlobalNet.call() 0 29 2
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional, Tuple

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer, layer_util
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY
12
13
14
@REGISTRY.register_backbone(name="global")
class GlobalNet(Backbone):
    """
    Build GlobalNet for image registration.

    The image is encoded by a stack of convolution + max-pooling levels,
    a dense layer then regresses the 12 parameters of an affine
    transformation, and that transformation is applied to a reference grid
    to produce a dense displacement field (DDF).

    Reference:

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "GlobalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        where E = max(extract_levels).
        Then, a densely-connected layer outputs an affine
        transformation.

        :param image_size: tuple, such as (dim1, dim2, dim3)
        :param out_channels: int, number of channels for the output,
            must be 3.
        :param num_channel_initial: int, number of initial channels;
            level i uses num_channel_initial * 2 ** i channels.
        :param extract_levels: list, which levels from net to extract;
            only its maximum E is used, as the depth of the encoder.
        :param out_kernel_initializer: not used by GlobalNet itself,
            only forwarded to the parent Backbone.
        :param out_activation: not used by GlobalNet itself,
            only forwarded to the parent Backbone.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises AssertionError: if out_channels is not 3.
        :raises ValueError: if extract_levels is empty.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        # the output ddf has exactly 3 channels (see call), so no other
        # value of out_channels is meaningful for this backbone
        assert out_channels == 3, (
            f"GlobalNet requires out_channels == 3, got {out_channels}"
        )
        self._extract_levels = extract_levels
        extract_max_level = max(extract_levels, default=None)
        if extract_max_level is None:
            # same exception type max() would raise, but with a clear message
            raise ValueError(
                f"GlobalNet requires a non-empty extract_levels, "
                f"got {extract_levels}"
            )
        self._extract_max_level = extract_max_level  # E
        self.reference_grid = layer_util.get_reference_grid(image_size)
        # np.eye(4, 3) flattened: a 3x3 identity followed by a zero row, i.e.
        # the affine parameters of the identity transform once reshaped to
        # (4, 3) in call; used as the dense bias initial value
        self.transform_initial = tf.constant_initializer(
            value=list(np.eye(4, 3).reshape((-1)))
        )

        # init layer variables
        # channels double at every level: level 0 to E
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]
        # level 0 to E-1: conv block + residual conv block; the first level
        # uses kernel size 7, deeper levels use kernel size 3
        self._downsample_convs = [
            tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[i],
                        kernel_size=7 if i == 0 else 3,
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[i],
                        kernel_size=7 if i == 0 else 3,
                        padding="same",
                    ),
                ]
            )
            for i in range(self._extract_max_level)
        ]
        # level 0 to E-1: each pooling halves the spatial dimensions
        self._downsample_pools = [
            tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            for _ in range(self._extract_max_level)
        ]
        # level E: bottom of the encoder
        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )
        self._flatten = tfkl.Flatten()
        # 12 = 4 * 3 affine parameters
        self._dense_layer = tfkl.Dense(
            units=12, bias_initializer=self.transform_initial
        )

    def call(
        self,
        inputs: tf.Tensor,
        training: Optional[bool] = None,
        mask: Optional[tf.Tensor] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Build GlobalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to the convolution blocks.
        :param mask: None or tf.Tensor, not used.
        :return: tuple (ddf, theta) where
            ddf shape = (batch, dim1, dim2, dim3, 3)
            theta shape = (batch, 4, 3)
        """
        # down sample from level 0 to E - 1
        h_in = inputs
        for conv, pool in zip(self._downsample_convs, self._downsample_pools):
            skip = conv(inputs=h_in, training=training)
            h_in = pool(inputs=skip)
        # level E of encoding
        h_out = self._conv3d_block(inputs=h_in, training=training)

        # predict affine parameters theta of shape = (batch, 4, 3)
        theta = self._dense_layer(self._flatten(h_out))
        theta = tf.reshape(theta, shape=(-1, 4, 3))
        # warp the reference grid with affine parameters to output a ddf
        grid_warped = layer_util.warp_grid(self.reference_grid, theta)
        ddf = grid_warped - self.reference_grid
        return ddf, theta
133