Passed
Pull Request — main (#656)
by Yunguan, created 02:48

GlobalNet.__init__() — rated B

Complexity

Conditions 3

Size

Total Lines 75
Code Lines 47

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric                             Value
eloc (executable lines of code)    47
dl (duplicated lines)              0
loc (lines of code)                75
rs                                 8.7345
c                                  0
b                                  0
f                                  0
cc (cyclomatic complexity)         3
nop (number of parameters)         9

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that you should extract the commented part into a new method and use the comment as a starting point for the new method's name.

Commonly applied refactorings include Extract Method; a self-contained sketch is shown below.
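
The sketch below applies Extract Method to a deliberately generic example; every name in it is illustrative and unrelated to DeepReg.

# A minimal, self-contained sketch of Extract Method (illustrative names only).

# Before: one method mixes cleaning, statistics, and formatting.
def summarize(values):
    # clean out missing entries
    cleaned = [v for v in values if v is not None]
    # compute statistics
    mean = sum(cleaned) / len(cleaned) if cleaned else 0.0
    # format a report line
    return f"n={len(cleaned)}, mean={mean:.2f}"

# After: each commented step becomes a small, well-named method.
def drop_missing(values):
    return [v for v in values if v is not None]

def mean_of(values):
    return sum(values) / len(values) if values else 0.0

def format_report(values):
    return f"n={len(values)}, mean={mean_of(values):.2f}"

def summarize_refactored(values):
    return format_report(drop_missing(values))

print(summarize([1, 2, None, 3]))             # n=3, mean=2.00
print(summarize_refactored([1, 2, None, 3]))  # identical output

Each helper now carries the name that used to live in a comment, and the top-level method reads as a summary of the steps.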

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoiding long parameter lists, for example Introduce Parameter Object or Preserve Whole Object; a sketch of the former is shown below.
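The flag here is on GlobalNet.__init__() (nop 9), so one option is to group the configuration into a single object. The sketch below is illustrative only; BackboneConfig and build_backbone are hypothetical names, not part of the DeepReg API.

# A minimal sketch of Introduce Parameter Object (hypothetical names).
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class BackboneConfig:
    image_size: Tuple[int, int, int]
    out_channels: int
    num_channel_initial: int
    extract_levels: List[int]
    out_kernel_initializer: str = "glorot_uniform"
    out_activation: str = ""

def build_backbone(config: BackboneConfig) -> str:
    # one cohesive object replaces a long positional parameter list
    return f"levels={config.extract_levels}, out_channels={config.out_channels}"

print(build_backbone(BackboneConfig(
    image_size=(16, 16, 16),
    out_channels=3,
    num_channel_initial=4,
    extract_levels=[0, 1, 2],
)))

Grouping related values this way also gives you one place to attach defaults and validation instead of repeating them at every call site.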

Two inline issues were flagged in this file: "Missing module docstring" at the top of the module, and "mask, training" missing in parameter type documentation on the definition of call().

# coding=utf-8

from typing import List

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer, layer_util
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY


@REGISTRY.register_backbone(name="global")
class GlobalNet(Backbone):
    """
    Build GlobalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "GlobalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually from level 0 to E.
        Then, a densely-connected layer outputs an affine
        transformation.

        :param image_size: tuple, such as (dim1, dim2, dim3)
        :param out_channels: int, number of channels for the output
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract
        :param out_kernel_initializer: not used
        :param out_activation: not used
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert out_channels == 3
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self.reference_grid = layer_util.get_reference_grid(image_size)
        self.transform_initial = tf.constant_initializer(
            value=list(np.eye(4, 3).reshape((-1)))
        )
        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        self._downsample_convs = [
            tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[i],
                        kernel_size=7 if i == 0 else 3,
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[i],
                        kernel_size=7 if i == 0 else 3,
                        padding="same",
                    ),
                ]
            )
            for i in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._downsample_pools = [
            tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            for _ in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E
        self._flatten = tfkl.Flatten()
        self._dense_layer = tfkl.Dense(
            units=12, bias_initializer=self.transform_initial
        )

    def call(
        self, inputs: tf.Tensor, training=None, mask=None
    ) -> (tf.Tensor, tf.Tensor):
        """
        Build GlobalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return:
            ddf shape = (batch, dim1, dim2, dim3, 3)
            theta shape = (batch, 4, 3)
        """
        # down sample from level 0 to E
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip)
        h_out = self._conv3d_block(
            inputs=h_in, training=training
        )  # level E of encoding

        # predict affine parameters theta of shape = (batch, 4, 3)
        theta = self._dense_layer(self._flatten(h_out))
        theta = tf.reshape(theta, shape=(-1, 4, 3))
        # warp the reference grid with affine parameters to output a ddf
        grid_warped = layer_util.warp_grid(self.reference_grid, theta)
        ddf = grid_warped - self.reference_grid
        return ddf, theta
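
For reference, a hedged usage sketch of GlobalNet follows; the module path is inferred from the imports above, and the image size, channel counts, and extract levels are illustrative assumptions, not values from this PR. Note that the dense layer's bias is initialized to the flattened identity affine np.eye(4, 3), which biases the predicted theta toward the identity transform early in training.

# Usage sketch (assumes deepreg is installed; all sizes are illustrative).
import tensorflow as tf
from deepreg.model.backbone.global_net import GlobalNet

image_size = (16, 16, 16)
net = GlobalNet(
    image_size=image_size,
    out_channels=3,                           # __init__ asserts out_channels == 3
    num_channel_initial=2,
    extract_levels=[0, 1, 2],                 # E = 2, so the encoder pools twice
    out_kernel_initializer="glorot_uniform",  # documented as "not used"
    out_activation="",                        # documented as "not used"
)

# e.g. a moving/fixed image pair stacked along the channel axis (illustrative)
images = tf.random.uniform((1, *image_size, 2))
ddf, theta = net(images)
print(ddf.shape)    # (1, 16, 16, 16, 3)
print(theta.shape)  # (1, 4, 3)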