Passed
Pull Request — main (#740)
by Yunguan
01:19 created

deepreg.loss.util.MultiScaleMixin.call()   A

Complexity

Conditions 4

Size

Total Lines 35
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 35
rs 9.352
c 0
b 0
f 0
cc 4
nop 3
"""Provide helper functions or classes for defining loss or metrics."""

from typing import List, Optional

import tensorflow as tf

from deepreg.loss.kernel import cauchy_kernel1d
from deepreg.loss.kernel import gaussian_kernel1d_sigma as gaussian_kernel1d


class MultiScaleMixin(tf.keras.losses.Loss):
    """
    Mixin class for multi-scale loss.

    It applies the loss at different scales, smoothing the inputs with a
    Gaussian or Cauchy kernel before each evaluation and averaging the
    per-scale values. Loss values are assumed to lie between 0 and 1.
    """

    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        name: str = "MultiScaleMixin",
        **kwargs,
    ):
        """
        Init.

        :param scales: list of scalars or None; if None, no scaling is applied.
        :param kernel: "gaussian" or "cauchy".
        :param name: name of the loss.
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        if kernel not in self.kernel_fn_dict:
            raise ValueError(
                f"Kernel {kernel} is not supported. "
                f"Supported kernels are {list(self.kernel_fn_dict.keys())}."
            )
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use super().call to calculate the loss at different scales.

        For each scale s, both tensors are smoothed with a 1-D kernel produced
        by the selected kernel function (s == 0 means no smoothing), the base
        loss is evaluated, and the per-scale losses are averaged.

        :param y_true: ground-truth tensor, shape = (batch, dim1, dim2, dim3).
        :param y_pred: predicted tensor, shape = (batch, dim1, dim2, dim3).
        :return: multi-scale loss, shape = (batch, ).
        """
        if self.scales is None:
            return super().call(y_true=y_true, y_pred=y_pred)
        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        for s in self.scales:
            if s == 0:
                # no smoothing
                losses.append(
                    super().call(
                        y_true=y_true,
                        y_pred=y_pred,
                    )
                )
            else:
                # smooth both tensors with the separable filter before
                # evaluating the base loss
                losses.append(
                    super().call(
                        y_true=separable_filter(
                            tf.expand_dims(y_true, axis=4), kernel_fn(s)
                        )[..., 0],
                        y_pred=separable_filter(
                            tf.expand_dims(y_pred, axis=4), kernel_fn(s)
                        )[..., 0],
                    )
                )
        loss = tf.add_n(losses)
        loss = loss / len(self.scales)
        return loss

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config
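

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of this module): MultiScaleMixin
# is meant to be combined with a concrete loss via multiple inheritance, so
# that super().call() inside the mixin resolves to the base loss. The base
# class below is a hypothetical stand-in; DeepReg's real image losses are
# defined elsewhere in deepreg.loss.
# ---------------------------------------------------------------------------
class _SumSquaredDifference(tf.keras.losses.Loss):
    """Hypothetical base loss: mean squared difference per batch element."""

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        return tf.reduce_mean(tf.square(y_true - y_pred), axis=[1, 2, 3])


class _MultiScaleSSD(MultiScaleMixin, _SumSquaredDifference):
    """Hypothetical multi-scale loss: averages the base loss over scales."""


# Example call, result shape = (batch,):
# loss_fn = _MultiScaleSSD(scales=[0, 1, 2], kernel="gaussian")
# loss = loss_fn.call(
#     y_true=tf.random.uniform((2, 8, 8, 8)),
#     y_pred=tf.random.uniform((2, 8, 8, 8)),
# )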


class NegativeLossMixin(tf.keras.losses.Loss):
    """Mixin class to negate the loss value."""

    def __init__(self, **kwargs):
        """
        Init without required arguments.

        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self.name = self.name + "Loss"

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Negate the loss, so a similarity to be maximised becomes a loss to be
        minimised.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: negated loss.
        """
        return -super().call(y_true=y_true, y_pred=y_pred)
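

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of this module): NegativeLossMixin
# turns a similarity measure, which should be maximised, into a loss, which
# should be minimised, by negating it and appending "Loss" to the name. The
# similarity class below is a hypothetical stand-in.
# ---------------------------------------------------------------------------
class _CosineSimilarity(tf.keras.losses.Loss):
    """Hypothetical similarity: cosine similarity per batch element."""

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        flat_true = tf.reshape(y_true, [tf.shape(y_true)[0], -1])
        flat_pred = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1])
        numerator = tf.reduce_sum(flat_true * flat_pred, axis=1)
        denominator = tf.norm(flat_true, axis=1) * tf.norm(flat_pred, axis=1)
        return numerator / denominator


class _CosineSimilarityLoss(NegativeLossMixin, _CosineSimilarity):
    """Hypothetical loss: negated cosine similarity."""


# Example: _CosineSimilarityLoss(name="CosineSimilarity").name is
# "CosineSimilarityLoss", and its call() returns minus the similarity.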


def separable_filter(tensor: tf.Tensor, kernel: tf.Tensor) -> tf.Tensor:
    """
    Apply a 3-D separable filter: the same 1-D kernel is convolved along each
    spatial axis in turn.

    Here `tf.nn.conv3d` accepts the `filters` argument of shape
    (filter_depth, filter_height, filter_width, in_channels, out_channels),
    where the first axis of `filters` is the depth, not the batch,
    and the input to `tf.nn.conv3d` is of shape
    (batch, in_depth, in_height, in_width, in_channels).

    :param tensor: shape = (batch, dim1, dim2, dim3, 1)
    :param kernel: shape = (dim4,)
    :return: shape = (batch, dim1, dim2, dim3, 1)
    """
    strides = [1, 1, 1, 1, 1]
    kernel = tf.cast(kernel, dtype=tensor.dtype)

    tensor = tf.nn.conv3d(
        tf.nn.conv3d(
            tf.nn.conv3d(
                tensor,
                filters=tf.reshape(kernel, [-1, 1, 1, 1, 1]),
                strides=strides,
                padding="SAME",
            ),
            filters=tf.reshape(kernel, [1, -1, 1, 1, 1]),
            strides=strides,
            padding="SAME",
        ),
        filters=tf.reshape(kernel, [1, 1, -1, 1, 1]),
        strides=strides,
        padding="SAME",
    )
    return tensor
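

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of this module): smoothing a random
# volume with a Gaussian kernel. Because a 3-D Gaussian is the outer product of
# three identical 1-D kernels, three 1-D convolutions (one per spatial axis)
# match the full 3-D convolution at a cost of roughly 3 * k instead of k ** 3
# multiplications per voxel for a kernel of width k. The example assumes
# gaussian_kernel1d accepts a scalar sigma, as in its use within this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    volume = tf.random.uniform((1, 16, 16, 16, 1))  # (batch, dim1, dim2, dim3, 1)
    smoothed = separable_filter(volume, gaussian_kernel1d(3))
    assert smoothed.shape == (1, 16, 16, 16, 1)  # "SAME" padding keeps the shape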