Completed
Push — main ( 21e2a1...183d7f )
by Yunguan
22s queued 13s
created

RootMeanSquaredDifference.call()   A

Complexity

Conditions 1

Size

Total Lines 13
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 13
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""This script provides an example of using custom backbone for training."""
2
import tensorflow as tf
3
4
from deepreg.registry import REGISTRY
5
from deepreg.train import train
6
7
8
@REGISTRY.register_loss(name="root_mean_square")
class RootMeanSquaredDifference(tf.keras.losses.Loss):
    """
    Square root of the mean of squared differences between y_true and y_pred.

    For each sample the mean is taken over all non-batch axes, i.e. the
    per-sample loss is sqrt(mean((y_true - y_pred) ** 2)).

    y_true and y_pred have to be at least 1d tensor, including batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "RootMeanSquaredDifference",
    ):
        """
        Init.

        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        The square root is guarded so that a sample whose prediction matches
        the target exactly gets a loss of 0 with a zero (not NaN) gradient.
        d/dx sqrt(x) is infinite at x = 0; multiplied by the zero gradient of
        the squared difference it would otherwise produce NaN.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        # Collapse every non-batch axis with a plain reshape instead of
        # constructing a new Flatten layer on each call; a rank-1 input
        # becomes (batch, 1), matching what Flatten produced.
        loss = tf.reshape(loss, [tf.shape(loss)[0], -1])
        loss = tf.reduce_mean(loss, axis=1)  # (batch,)
        # "Double where" safe sqrt: feed sqrt a harmless positive value where
        # the mean is 0 so neither branch of the outer tf.where carries an
        # infinite/NaN gradient. Forward values are identical to tf.math.sqrt.
        positive = loss > 0
        safe_loss = tf.where(positive, loss, tf.ones_like(loss))
        return tf.where(positive, tf.math.sqrt(safe_loss), tf.zeros_like(loss))
43
44
45
# Launch training with the example config. The path is relative; presumably
# the script is meant to be run from the repository root — TODO confirm how
# `train` resolves it.
config_path = "examples/config_custom_image_label_loss.yaml"
train_kwargs = {
    "gpu": "",
    "gpu_allow_growth": True,
    "ckpt_path": "",
    "config_path": config_path,
}
train(**train_kwargs)
52