| Total Complexity | 2 |
| Total Lines | 51 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | """This script provides an example of using a custom loss for training.""" |
||
| 2 | import tensorflow as tf |
||
| 3 | |||
| 4 | from deepreg.registry import REGISTRY |
||
| 5 | from deepreg.train import train |
||
| 6 | |||
| 7 | |||
@REGISTRY.register_loss(name="root_mean_square")
class RootMeanSquaredDifference(tf.keras.losses.Loss):
    """
    Square root of the mean of squared distance between y_true and y_pred.

    The mean is taken over all non-batch axes, giving one value per sample.
    y_true and y_pred have to be at least 1d tensor, including batch axis.

    The square root is computed so that its gradient stays finite when a
    sample's mean squared difference is exactly zero (a perfect match);
    the loss value itself is identical to ``tf.math.sqrt`` of the mean.
    """

    def __init__(
        self,
        name: str = "RootMeanSquaredDifference",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the loss
        :param kwargs: additional arguments forwarded unchanged to
            ``tf.keras.losses.Loss.__init__``.
        """
        super().__init__(name=name, **kwargs)
        # Collapses all non-batch axes into one, so the mean in `call` is
        # taken per sample whatever the input rank.
        self.flatten = tf.keras.layers.Flatten()

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        loss = self.flatten(loss)
        loss = tf.reduce_mean(loss, axis=1)  # shape = (batch,)
        return self._safe_sqrt(loss)

    @staticmethod
    def _safe_sqrt(x: tf.Tensor) -> tf.Tensor:
        """
        Element-wise square root whose gradient is finite at zero.

        ``tf.math.sqrt`` has an infinite derivative at 0. Chained with the
        zero derivative of the squared difference at a perfect match, this
        produces 0 * inf = NaN gradients. Exact zeros are therefore routed
        through a constant branch (value 0, gradient 0). The sqrt branch is
        fed a harmless placeholder of 1 at those positions, because a NaN
        gradient in either branch of ``tf.where`` would still leak through.
        Forward values equal ``tf.math.sqrt(x)``. NaN and negative inputs
        are not zero, so they still go through ``tf.math.sqrt`` unchanged.

        :param x: tensor of (expected non-negative) mean squared differences
        :return: sqrt(x), same shape and dtype as x
        """
        is_zero = tf.equal(x, tf.zeros_like(x))
        safe_x = tf.where(is_zero, tf.ones_like(x), x)
        return tf.where(is_zero, tf.zeros_like(x), tf.math.sqrt(safe_x))
||
| 43 | |||
| 44 | |||
config_path = "examples/config_custom_image_label_loss.yaml"

# Start training with the loss class registered above. The config path is
# relative, so it is resolved against the current working directory.
# Presumably the YAML selects the loss by its registry name
# "root_mean_square" -- confirm against the config file.
# NOTE(review): this call runs at import time. Consider an
# `if __name__ == "__main__":` guard if the module is ever imported
# only for its registration side effect.
train(
    config_path=config_path,
    ckpt_path="",
    gpu="",
    gpu_allow_growth=True,
)
||
| 52 |