Completed
Push — main ( 21e2a1...183d7f )
by Yunguan
22s queued 13s
created

custom_parameterized_image_label_loss   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 55
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 3
eloc 26
dl 0
loc 55
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
A LPNorm.call() 0 12 1
A LPNorm.__init__() 0 18 2
1
"""This script provides an example of using a custom parameterized loss for training."""
2
import tensorflow as tf
3
4
from deepreg.registry import REGISTRY
5
from deepreg.train import train
6
7
8
@REGISTRY.register_loss(name="lp_norm")
class LPNorm(tf.keras.losses.Loss):
    """
    L^p norm between y_true and y_pred, p = 1 or 2.

    The difference y_true - y_pred is flattened per sample and its L^p norm
    is taken, so y_true and y_pred have to be at least 1d tensors,
    including the batch axis.
    """

    def __init__(
        self,
        p: int,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LPNorm",
    ):
        """
        Init.

        :param p: order of the norm, 1 or 2.
        :param reduction: reduction over the batch axis, SUM by default;
            with SUM, calling the loss like `loss(y_true, y_pred)`
            will return a scalar tensor.
        :param name: name of the loss.
        :raises ValueError: if p is not 1 or 2.
        """
        super().__init__(reduction=reduction, name=name)
        if p not in (1, 2):
            raise ValueError(f"For LPNorm, p must be 1 or 2, got {p}.")
        self.p = p

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        diff = y_true - y_pred
        # Collapse every non-batch axis into one with a plain reshape rather
        # than instantiating a fresh keras Flatten layer on every call.
        # A 1d (batch,) input becomes (batch, 1), same as Flatten.
        diff = tf.reshape(diff, [tf.shape(diff)[0], -1])
        return tf.norm(diff, axis=-1, ord=self.p)

    def get_config(self) -> dict:
        """
        Return the config of the loss.

        The base Loss config (reduction, name) is extended with `p` so the
        loss can be re-created via `from_config`, which calls the
        constructor with these keys.

        :return: dict of constructor arguments.
        """
        config = super().get_config()
        config["p"] = self.p
        return config
47
48
49
# Launch training from the example YAML config. The config presumably refers
# to the "lp_norm" loss registered in this file — confirm against the YAML.
config_path = "examples/config_custom_parameterized_image_label_loss.yaml"
train(
    config_path=config_path,
    # NOTE(review): empty gpu string presumably selects no GPU — confirm.
    gpu="",
    gpu_allow_growth=True,
    # NOTE(review): empty ckpt_path presumably means no checkpoint is loaded.
    ckpt_path="",
)
56