1
|
|
|
"""This script provides an example of using custom backbone for training.""" |
2
|
|
|
import tensorflow as tf |
3
|
|
|
|
4
|
|
|
from deepreg.registry import REGISTRY |
5
|
|
|
from deepreg.train import train |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
@REGISTRY.register_loss(name="lp_norm")
class LPNorm(tf.keras.losses.Loss):
    """
    L^p norm between y_true and y_pred, p = 1 or 2.

    y_true and y_pred have to be at least 1d tensor, including batch axis.
    """

    def __init__(
        self,
        p: int,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LPNorm",
    ):
        """
        Init.

        :param p: order of the norm, 1 or 2.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss.
        :raises ValueError: if p is not 1 or 2.
        """
        super().__init__(reduction=reduction, name=name)
        if p not in [1, 2]:
            # Bug fix: the message previously claimed "p must be 0 or 1",
            # contradicting the actual check above which accepts 1 or 2.
            raise ValueError(f"For LPNorm, p must be 1 or 2, got {p}.")
        self.p = p

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        diff = y_true - y_pred
        # Collapse all non-batch axes so tf.norm reduces over a single axis,
        # yielding one scalar norm per sample.
        diff = tf.keras.layers.Flatten()(diff)
        loss = tf.norm(diff, axis=-1, ord=self.p)
        return loss
47
|
|
|
|
48
|
|
|
|
49
|
|
|
# YAML config that registers the custom "lp_norm" loss into the training pipeline.
config_path = "examples/config_custom_parameterized_image_label_loss.yaml"

# Train on CPU (no GPU id) from scratch (no checkpoint to resume from).
train_kwargs = dict(
    gpu="",
    config_path=config_path,
    gpu_allow_growth=True,
    ckpt_path="",
)
train(**train_kwargs)
56
|
|
|
|