Passed
Pull Request — main (#740)
by Yunguan
01:28
created

deepreg.loss.util   A

Complexity

Total Complexity 10

Size/Duplication

Total Lines 144
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 10
eloc 68
dl 0
loc 144
rs 10
c 0
b 0
f 0

5 Methods

Rating   Name   Duplication   Size   Complexity  
A NegativeLossMixin.__init__() 0 8 1
A MultiScaleMixin.get_config() 0 6 1
A NegativeLossMixin.call() 0 9 1
A MultiScaleMixin.call() 0 35 4
A MultiScaleMixin.__init__() 0 21 2

1 Function

Rating   Name   Duplication   Size   Complexity  
A separable_filter() 0 34 1
1
"""Provide helper functions or classes for defining loss or metrics."""
2
3
from typing import List, Optional
4
5
import tensorflow as tf
6
7
from deepreg.loss.kernel import cauchy_kernel1d
8
from deepreg.loss.kernel import gaussian_kernel1d_sigma as gaussian_kernel1d
9
10
11
class MultiScaleMixin(tf.keras.losses.Loss):
    """
    Mixin class for multi-scale loss.

    It applies the loss at different scales (gaussian or cauchy smoothing)
    and returns the average of the per-scale losses.
    It is assumed that loss values are between 0 and 1.
    """

    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        **kwargs,
    ):
        """
        Init.

        :param scales: list of scalars or None, if None, do not apply any scaling.
            A scale equal to 0 computes the loss on the unsmoothed inputs.
            An empty list is rejected, as averaging over zero scales is undefined.
        :param kernel: gaussian or cauchy.
        :param kwargs: additional arguments.
        :raises ValueError: if the kernel is not supported or scales is empty.
        """
        super().__init__(**kwargs)
        if kernel not in self.kernel_fn_dict:
            raise ValueError(
                f"Kernel {kernel} is not supported. "
                f"Supported kernels are {list(self.kernel_fn_dict.keys())}"
            )
        if scales is not None and len(scales) == 0:
            # call() would otherwise fail inside tf.add_n([]) and divide by zero
            raise ValueError(
                "scales must be None or a non-empty list of scalars, got an empty list."
            )
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use super().call to calculate loss at different scales.

        :param y_true: ground-truth tensor, shape = (batch, dim1, dim2, dim3).
        :param y_pred: predicted tensor, shape = (batch, dim1, dim2, dim3).
        :return: multi-scale loss, shape = (batch, ), i.e. the mean of the
            losses computed at every scale in ``self.scales``.
        """
        if self.scales is None:
            return super().call(y_true=y_true, y_pred=y_pred)
        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        # An explicit loop (not a comprehension) is used because zero-argument
        # super() does not work inside a comprehension scope on Python < 3.12.
        for s in self.scales:
            if s == 0:
                # no smoothing
                losses.append(super().call(y_true=y_true, y_pred=y_pred))
                continue
            # Build the 1D kernel once and share it between both inputs.
            kernel_1d = kernel_fn(s)
            # Add a channel axis for conv3d, smooth, then drop the channel axis.
            smoothed_true = separable_filter(
                tf.expand_dims(y_true, axis=4), kernel_1d
            )[..., 0]
            smoothed_pred = separable_filter(
                tf.expand_dims(y_pred, axis=4), kernel_1d
            )[..., 0]
            losses.append(super().call(y_true=smoothed_true, y_pred=smoothed_pred))
        loss = tf.add_n(losses)
        loss = loss / len(self.scales)
        return loss

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config
85
86
87
class NegativeLossMixin(tf.keras.losses.Loss):
    """Mixin class to negate the value returned by the next class's ``call``."""

    def __init__(self, **kwargs):
        """
        Init without required arguments.

        Appends the "Loss" suffix to the name unless it is already present.
        Keras stores the (suffixed) name in ``get_config()``, and re-creating
        the object from that config passes it back here. Appending
        unconditionally would then produce a doubled suffix ("...LossLoss").

        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        if not self.name.endswith("Loss"):
            self.name = self.name + "Loss"

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Negate the loss computed by the next class in the MRO.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: negated loss.
        """
        return -super().call(y_true=y_true, y_pred=y_pred)
108
109
110
def separable_filter(tensor: tf.Tensor, kernel: tf.Tensor) -> tf.Tensor:
    """
    Apply a 1D kernel along each spatial axis of a 3D volume.

    The three 1D convolutions (depth, then height, then width) together
    form a separable 3D filter. `tf.nn.conv3d` expects `filters` of shape
    (filter_depth, filter_height, filter_width, in_channels, out_channels),
    where the first axis is the filter depth rather than the batch, and an
    input of shape (batch, in_depth, in_height, in_width, in_channels).

    :param tensor: shape = (batch, dim1, dim2, dim3, 1)
    :param kernel: shape = (dim4,)
    :return: shape = (batch, dim1, dim2, dim3, 1)
    """
    kernel = tf.cast(kernel, dtype=tensor.dtype)
    unit_strides = [1] * 5
    # The kernel is laid out along one spatial axis at a time, in the
    # order depth -> height -> width.
    per_axis_shapes = (
        [-1, 1, 1, 1, 1],
        [1, -1, 1, 1, 1],
        [1, 1, -1, 1, 1],
    )
    output = tensor
    for filter_shape in per_axis_shapes:
        output = tf.nn.conv3d(
            output,
            filters=tf.reshape(kernel, filter_shape),
            strides=unit_strides,
            padding="SAME",
        )
    return output
144