Passed
Pull Request — main (#605)
by
unknown
03:40
created

deepreg.loss.util.NegativeLossMixin.call()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 9
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Provide helper functions or classes for defining loss or metrics."""
2
import tensorflow as tf
3
4
5
class NegativeLossMixin(tf.keras.losses.Loss):
    """
    Mixin that flips the sign of the value computed by the next class.

    Presumably combined with a score class via multiple inheritance, so that
    ``super()`` resolves to that class; the resulting object reports the
    negated value and carries a ``"Loss"`` suffix on its name.
    """

    def __init__(self, **kwargs):
        """
        Forward every keyword argument up the MRO, then suffix the name.

        :param kwargs: additional arguments, passed through unchanged.
        """
        super().__init__(**kwargs)
        self.name += "Loss"

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Evaluate the next ``call`` in the MRO and return its negation.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: negated value of ``super().call(y_true, y_pred)``.
        """
        value = super().call(y_true=y_true, y_pred=y_pred)
        return -value
26
27
28
# Keras' global fuzz factor (tf.keras.backend.epsilon()), read once at import.
# Presumably used by losses elsewhere to guard divisions/logs near zero —
# confirm at the call sites.
EPS = tf.keras.backend.epsilon()
29
30
31
def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
    """
    Return the 1-D rectangular (box) kernel for separable convolution.

    Applying it along each of the three spatial axes is equivalent to a 3-D
    rectangular kernel, e.g. for LocalNormalizedCrossCorrelation.

    :param kernel_size: scalar, size of the 1-D kernel
    :return: kernel_weights, float32 ones of shape (kernel_size, )
    """
    kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32)
    return kernel
42
43
44
def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
    """
    Return a symmetric 1-D triangular kernel for separable convolution.

    The weights are ``[1, 2, ..., m, ..., 2, 1]`` with
    ``m = (kernel_size + 1) // 2``, i.e. the full convolution of two box
    filters of length ``m``. The weight scale of the kernel does not matter
    as LNCC will normalize it.

    :param kernel_size: scalar, size of the 1-D kernel; must be odd and >= 3
    :return: kernel_weights, float32 of shape (kernel_size, )
    """
    assert kernel_size >= 3
    assert kernel_size % 2 != 0

    # Closed form of box(m) * box(m): the weight drops by one per step away
    # from the centre, reaching 1 at both ends. Computing it directly avoids
    # a conv1d whose input layout must be (batch, width, channels).
    center = kernel_size // 2
    grid = tf.range(kernel_size, dtype=tf.float32)
    kernel = (center + 1) - tf.abs(grid - center)
    return kernel
76
77
78
def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
    """
    Return the 1-D Gaussian filter of a given size for separable convolution.

    Applying it along each of the three spatial axes is equivalent to a 3-D
    Gaussian kernel, e.g. for LocalNormalizedCrossCorrelation. The Gaussian is
    centred on the middle of the window with ``sigma = kernel_size / 3``.
    The weights are not normalized (the peak is 1 when kernel_size is odd).

    :param kernel_size: scalar, size of the 1-D kernel
    :return: filters, float32 of shape (kernel_size, )
    """
    mean = (kernel_size - 1) / 2.0
    sigma = kernel_size / 3

    grid = tf.range(0, kernel_size, dtype=tf.float32)
    # Dividing by a negative denominator is bit-identical to negating the
    # squared tensor, and avoids a unary minus on a tensor that static
    # analysis reports as "bad operand type for unary -".
    filters = tf.exp(tf.square(grid - mean) / (-2 * sigma ** 2))

    return filters
92
93
94
def gaussian_kernel1d_sigma(sigma: int) -> tf.Tensor:
    """
    Build a normalized 1-D Gaussian kernel from its standard deviation.

    The support is truncated at three standard deviations on each side,
    i.e. integer offsets from ``-int(3 * sigma)`` to ``int(3 * sigma)``.

    :param sigma: standard deviation of the Gaussian; must be positive.
    :return: shape = (dim, ), dim = 2 * int(3 * sigma) + 1, weights sum to one.
    """
    assert sigma > 0
    half_width = int(sigma * 3)
    offsets = range(-half_width, half_width + 1)
    exponents = [-0.5 * offset ** 2 / sigma ** 2 for offset in offsets]
    weights = tf.exp(exponents)
    return weights / tf.reduce_sum(weights)
107
108
109
def cauchy_kernel1d(sigma: int) -> tf.Tensor:
    """
    Build a normalized, truncated 1-D Cauchy kernel.

    Each weight is ``1 / ((x / sigma) ** 2 + 1)`` for integer offsets ``x``
    from ``-int(5 * sigma)`` to ``int(5 * sigma)``, then rescaled to sum to one.

    :param sigma: scale parameter of the kernel (a Cauchy distribution has no
        standard deviation); must be positive.
    :return: shape = (dim, ), dim = 2 * int(5 * sigma) + 1.
    """
    assert sigma > 0
    half_width = int(sigma * 5)
    denominators = [
        (offset / sigma) ** 2 + 1 for offset in range(-half_width, half_width + 1)
    ]
    weights = tf.math.reciprocal(denominators)
    return weights / tf.reduce_sum(weights)
121
122
123
def separable_filter(tensor: tf.Tensor, kernel: tf.Tensor) -> tf.Tensor:
    """
    Apply a 1-D kernel along each of the three spatial axes of a 5-D tensor.

    The kernel is reshaped in turn to (k, 1, 1, 1, 1), (1, k, 1, 1, 1) and
    (1, 1, k, 1, 1) and applied with ``tf.nn.conv3d`` (unit strides, "SAME"
    padding), giving a separable 3-D filter. ``tf.nn.conv3d`` takes filters
    shaped (filter_depth, filter_height, filter_width, in_channels,
    out_channels) - the first filter axis is depth, not batch - and inputs
    shaped (batch, in_depth, in_height, in_width, in_channels).

    :param tensor: shape = (batch, dim1, dim2, dim3, 1)
    :param kernel: shape = (dim4,), cast to ``tensor.dtype`` before use
    :return: shape = (batch, dim1, dim2, dim3, 1)
    """
    kernel = tf.cast(kernel, dtype=tensor.dtype)
    strides = [1, 1, 1, 1, 1]
    # depth first, then height, then width
    axis_shapes = ([-1, 1, 1, 1, 1], [1, -1, 1, 1, 1], [1, 1, -1, 1, 1])
    output = tensor
    for shape in axis_shapes:
        output = tf.nn.conv3d(
            output,
            filters=tf.reshape(kernel, shape),
            strides=strides,
            padding="SAME",
        )
    return output
157