Completed: push to main (de3728...ca54a2) by Yunguan

deepreg.loss.util.MultiScaleMixin.get_config()    A

Complexity
    Conditions: 1

Size
    Total Lines: 6
    Code Lines: 5

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
eloc      5
dl        0
loc       6
rs        10
c         0
b         0
f         0
cc        1
nop       1

"""Provide helper functions or classes for defining loss or metrics."""

from typing import List, Optional, Union

import tensorflow as tf

from deepreg.loss.kernel import cauchy_kernel1d
from deepreg.loss.kernel import gaussian_kernel1d_sigma as gaussian_kernel1d


class MultiScaleMixin(tf.keras.losses.Loss):
    """
    Mixin class for multi-scale loss.

    It applies the loss at different scales (gaussian or cauchy smoothing).
    It is assumed that loss values are between 0 and 1.
    """

    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[Union[List, float, int]] = None,
        kernel: str = "gaussian",
        **kwargs,
    ):
        """
        Init.

        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        if kernel not in self.kernel_fn_dict:
            raise ValueError(
                f"Kernel {kernel} is not supported. "
                f"Supported kernels are {list(self.kernel_fn_dict.keys())}."
            )
        if scales is not None and not isinstance(scales, list):
            scales = [scales]
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use super().call to calculate loss at different scales.

        :param y_true: ground-truth tensor, shape = (batch, dim1, dim2, dim3).
        :param y_pred: predicted tensor, shape = (batch, dim1, dim2, dim3).
        :return: multi-scale loss, shape = (batch, ).
        """
        if self.scales is None:
            return super().call(y_true=y_true, y_pred=y_pred)
        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        for s in self.scales:
            if s == 0:
                # no smoothing
                losses.append(
                    super().call(
                        y_true=y_true,
                        y_pred=y_pred,
                    )
                )
            else:
                losses.append(
                    super().call(
                        y_true=separable_filter(
                            tf.expand_dims(y_true, axis=4), kernel_fn(s)
                        )[..., 0],
                        y_pred=separable_filter(
                            tf.expand_dims(y_pred, axis=4), kernel_fn(s)
                        )[..., 0],
                    )
                )
        loss = tf.add_n(losses)
        loss = loss / len(self.scales)
        return loss

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config
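

# Illustrative sketch, not part of the module: one way to use MultiScaleMixin
# is to combine it with a concrete per-sample loss through multiple
# inheritance, so that the mixin's `call` smooths both inputs at each scale
# and delegates to the base loss via `super().call`. The MeanSquaredDifference
# base class and the combined class below are hypothetical and only show the
# pattern.
class MeanSquaredDifference(tf.keras.losses.Loss):
    """Hypothetical per-sample mean squared difference, in [0, 1] for inputs in [0, 1]."""

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        return tf.reduce_mean(tf.square(y_true - y_pred), axis=[1, 2, 3])


class MultiScaleMeanSquaredDifference(MultiScaleMixin, MeanSquaredDifference):
    """Mean squared difference averaged over smoothed copies of the inputs."""


# e.g. MultiScaleMeanSquaredDifference(scales=[0, 1, 2], kernel="gaussian")
# averages the loss computed on the raw inputs (scale 0) and on copies
# smoothed with Gaussian kernels of sigma 1 and 2.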


class NegativeLossMixin(tf.keras.losses.Loss):
    """Mixin class to revert the sign of the loss value."""

    def __init__(self, **kwargs):
        """
        Init without required arguments.

        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self.name = self.name + "Loss"

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Revert the sign of loss.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: negated loss.
        """
        return -super().call(y_true=y_true, y_pred=y_pred)
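

# Illustrative sketch, not part of the module: NegativeLossMixin turns a
# similarity measure (larger is better) into a minimizable loss by flipping
# its sign and appending "Loss" to the instance name. The DummySimilarity
# base class below is hypothetical; it passes an explicit `name` so that the
# mixin's renaming works.
class DummySimilarity(tf.keras.losses.Loss):
    """Hypothetical per-sample similarity in [0, 1] for inputs in [0, 1]."""

    def __init__(self, name: str = "DummySimilarity", **kwargs):
        super().__init__(name=name, **kwargs)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        # similarity = 1 - mean absolute difference
        return 1.0 - tf.reduce_mean(tf.abs(y_true - y_pred), axis=[1, 2, 3])


class DummySimilarityLoss(NegativeLossMixin, DummySimilarity):
    """Negated similarity; the instance name becomes "DummySimilarityLoss"."""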


def separable_filter(tensor: tf.Tensor, kernel: tf.Tensor) -> tf.Tensor:
    """
    Create a 3d separable filter.

    Here `tf.nn.conv3d` accepts the `filters` argument of shape
    (filter_depth, filter_height, filter_width, in_channels, out_channels),
    where the first axis of `filters` is the depth not batch,
    and the input to `tf.nn.conv3d` is of shape
    (batch, in_depth, in_height, in_width, in_channels).

    :param tensor: shape = (batch, dim1, dim2, dim3, 1)
    :param kernel: shape = (dim4,)
    :return: shape = (batch, dim1, dim2, dim3, 1)
    """
    strides = [1, 1, 1, 1, 1]
    kernel = tf.cast(kernel, dtype=tensor.dtype)

    # apply the same 1-D kernel along each spatial axis in turn:
    # depth (dim1), then height (dim2), then width (dim3)
    tensor = tf.nn.conv3d(
        tf.nn.conv3d(
            tf.nn.conv3d(
                tensor,
                filters=tf.reshape(kernel, [-1, 1, 1, 1, 1]),
                strides=strides,
                padding="SAME",
            ),
            filters=tf.reshape(kernel, [1, -1, 1, 1, 1]),
            strides=strides,
            padding="SAME",
        ),
        filters=tf.reshape(kernel, [1, 1, -1, 1, 1]),
        strides=strides,
        padding="SAME",
    )
    return tensor
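

# Illustrative usage, not part of the module: smoothing a random volume with
# the separable Gaussian filter. Because a 3-D Gaussian kernel is the outer
# product of 1-D Gaussians, applying the 1-D kernel along each spatial axis
# in turn is equivalent (up to boundary handling) to a full 3-D convolution,
# at a much lower cost. This is what MultiScaleMixin.call does to y_true and
# y_pred at each non-zero scale, after adding and then dropping a channel axis.
if __name__ == "__main__":
    volume = tf.random.uniform((2, 16, 16, 16, 1))  # (batch, dim1, dim2, dim3, 1)
    smoothed = separable_filter(volume, gaussian_kernel1d(2))  # sigma = 2
    print(smoothed.shape)  # (2, 16, 16, 16, 1), same shape as the input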