custom_backbone   A
last analyzed

Complexity

Total Complexity 2

Size/Duplication

Total Lines 78
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 2
eloc 42
dl 0
loc 78
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
A CustomBackbone.__init__() 0 41 1
A CustomBackbone.call() 0 12 1
1
"""This script provides an example of using a custom backbone for training."""
2
import tensorflow as tf
3
4
from deepreg.model.backbone import Backbone
5
from deepreg.registry import REGISTRY
6
from deepreg.train import train
7
8
9
@REGISTRY.register_backbone(name="custom_backbone")
class CustomBackbone(Backbone):
    """
    A dummy custom backbone for demonstration purposes only.

    Registered in ``REGISTRY`` under the name ``"custom_backbone"``. The network
    is two 3D convolutions with "same" padding: a 3x3x3 convolution to
    ``num_channel_initial`` channels (no activation given, so Keras' default
    linear activation applies), followed by a 1x1x1 convolution to
    ``out_channels`` channels that uses the configured output kernel
    initializer and activation.
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "CustomBackbone",
        **kwargs,
    ):
        """
        Init.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of channels produced by the first
            convolution
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param name: name of the backbone
        :param kwargs: additional arguments, forwarded unchanged to
            ``Backbone.__init__``.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # Feature extraction: 3x3x3 conv, spatial size preserved by "same".
        self.conv1 = tf.keras.layers.Conv3D(
            filters=num_channel_initial, kernel_size=3, padding="same"
        )
        # Output head: 1x1x1 conv mapping to the requested output channels.
        self.conv2 = tf.keras.layers.Conv3D(
            filters=out_channels,
            kernel_size=1,
            kernel_initializer=out_kernel_initializer,
            activation=out_activation,
            padding="same",
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: unused here; kept for the Keras ``call`` signature.
        :param mask: unused here; kept for the Keras ``call`` signature.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        out = self.conv1(inputs)
        out = self.conv2(out)
        return out
70
71
72
# Training config for this example. A relative path resolves against the
# current working directory, so this presumably expects the script to be
# launched from the repository root (TODO confirm).
config_path = "examples/config_custom_backbone.yaml"

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. only to get CustomBackbone
    # registered) does not start a training run as a side effect.
    train(
        gpu="",  # presumably "" selects no GPU / CPU only — confirm in deepreg.train
        config_path=config_path,
        gpu_allow_growth=True,
        ckpt_path="",  # presumably "" means no checkpoint to resume from — confirm
    )
79