"""This script provides an example of using a custom image label loss for training."""
import tensorflow as tf

from deepreg.registry import REGISTRY
from deepreg.train import train
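

# Registering the loss class below with DeepReg's REGISTRY makes it selectable by its
# registered name from the YAML config, just like the built-in losses. The registration
# runs when this module is executed, i.e. before `train` parses the config.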
@REGISTRY.register_loss(name="root_mean_square")
class RootMeanSquaredDifference(tf.keras.losses.Loss):
    """
    Square root of the mean of squared differences between y_true and y_pred.

    y_true and y_pred have to be at least 1d tensors, including the batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "RootMeanSquaredDifference",
    ):
        """
        Init.

        :param reduction: with the default SUM reduction over the batch axis,
            calling the loss like `loss(y_true, y_pred)` returns a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        # flatten all non-batch axes so the mean is taken over every element per sample
        loss = tf.keras.layers.Flatten()(loss)
        loss = tf.reduce_mean(loss, axis=1)
        loss = tf.math.sqrt(loss)
        return loss
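

# Optional sanity check, not part of the original example: a minimal sketch showing the
# loss on small dummy tensors. `call` returns one value per batch element, while the
# default SUM reduction collapses them into a single scalar.
_y_true = tf.constant([[0.0, 0.0], [1.0, 1.0]])
_y_pred = tf.constant([[3.0, 4.0], [1.0, 1.0]])
_loss_fn = RootMeanSquaredDifference()
print(_loss_fn.call(_y_true, _y_pred))  # per-sample values: ~[3.54, 0.0]
print(_loss_fn(_y_true, _y_pred))  # SUM reduction: ~3.54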
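

# The registered name "root_mean_square" is what the YAML config uses to select this
# loss. As an assumption about the config layout (check the DeepReg docs for your
# version), the relevant part of the config file referenced below might look like:
#
#   train:
#     loss:
#       image:
#         name: "root_mean_square"
#         weight: 1.0
#       label:
#         name: "root_mean_square"
#         weight: 1.0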
config_path = "examples/config_custom_image_label_loss.yaml"
train(
    gpu="",  # empty string selects no GPU, so training runs on CPU
    config_path=config_path,
    gpu_allow_growth=True,
    ckpt_path="",  # no checkpoint path, so training starts from scratch
)