|
1
|
|
|
"""This script provides an example of using a custom backbone for training."""
|
2
|
|
|
import tensorflow as tf |
|
3
|
|
|
|
|
4
|
|
|
from deepreg.model.backbone import Backbone |
|
5
|
|
|
from deepreg.registry import REGISTRY |
|
6
|
|
|
from deepreg.train import train |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
@REGISTRY.register_backbone(name="custom_backbone")
class CustomBackbone(Backbone):
    """
    A dummy custom backbone for demonstration purposes only.

    It is registered in ``REGISTRY`` under the name "custom_backbone".
    The forward pass is two convolutions:

    - a 3x3x3 ``Conv3D`` with ``num_channel_initial`` filters and no
      explicit activation;
    - a 1x1x1 ``Conv3D`` projecting to ``out_channels`` filters, using
      ``out_kernel_initializer`` and ``out_activation``.

    Both use "same" padding, so the spatial dimensions of the input are kept.
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "CustomBackbone",
        **kwargs,
    ):
        """
        Init.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of filters of the first convolution
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param name: name of the backbone
        :param kwargs: additional arguments, forwarded unchanged to ``Backbone``.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # Feature-extraction layer: 3x3x3 kernel, no activation argument given.
        self.conv1 = tf.keras.layers.Conv3D(
            filters=num_channel_initial, kernel_size=3, padding="same"
        )
        # Output layer: 1x1x1 kernel mapping features to ``out_channels``,
        # using the configured output initializer and activation.
        self.conv2 = tf.keras.layers.Conv3D(
            filters=out_channels,
            kernel_size=1,
            kernel_initializer=out_kernel_initializer,
            activation=out_activation,
            padding="same",
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: unused; accepted for compatibility with the Keras
            ``call`` signature.
        :param mask: unused; accepted for compatibility with the Keras
            ``call`` signature.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        out = self.conv1(inputs)
        out = self.conv2(out)
        return out
|
70
|
|
|
|
|
71
|
|
|
|
|
72
|
|
|
# Launch training from the example YAML config, which presumably selects the
# "custom_backbone" registered in this file (verify against the YAML).
# The config path is relative, so it resolves against the current working
# directory -- presumably this script is meant to be run from the repo root.
config_path = "examples/config_custom_backbone.yaml"
train(
    config_path=config_path,
    gpu="",
    gpu_allow_growth=True,
    ckpt_path="",
)
|
79
|
|
|
|