Passed
Pull Request — main (#605)
by
unknown
03:25
created

deepreg.loss.util.rectangular_kernel1d()   A

Complexity

Conditions 1

Size

Total Lines 11
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 11
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Provide helper functions or classes for defining loss or metrics."""
2
import tensorflow as tf
3
4
5
class NegativeLossMixin(tf.keras.losses.Loss):
    """
    Mixin that flips the sign of the value produced by the next ``call`` in the MRO.

    ``call`` delegates to ``super().call`` and negates the result, so this class
    relies on another class later in the MRO supplying the actual computation.
    The instance name receives a ``"Loss"`` suffix.
    """

    def __init__(self, **kwargs):
        """
        Forward every keyword argument to the parent constructor, then suffix
        the instance name with "Loss".

        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self.name += "Loss"

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Compute the parent value and return its negation.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: negated loss.
        """
        value = super().call(y_true=y_true, y_pred=y_pred)
        return -value
26
27
28
# Small fuzz constant read from the Keras backend once, at import time.
# NOTE(review): presumably used by loss/metric code elsewhere to guard
# divisions or logs against zero — confirm against callers.
EPS = tf.keras.backend.epsilon()
29
30
31
def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
    """
    Return the 1-D filter for separable convolution equivalent to a 3-D
    rectangular kernel for LocalNormalizedCrossCorrelation.

    Every weight is one.

    :param kernel_size: scalar, size of the 1-D kernel
    :return: kernel_weights, float32 tensor of shape (kernel_size, ) filled with ones
    """
    return tf.ones(shape=(kernel_size,), dtype=tf.float32)
42
43
44
def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
    """
    Return the 1-D triangular kernel for LocalNormalizedCrossCorrelation.

    The kernel is built by convolving a box of ``kernel_size // 2 + 1`` ones
    (zero-padded to length ``kernel_size``) with a box of ones of the same
    length. For an odd ``kernel_size`` this yields the weights
    ``1, 2, ..., kernel_size // 2 + 1, ..., 2, 1``; for example,
    ``[1, 2, 3, 2, 1]`` for ``kernel_size = 5``.
    The weight scale of the kernel does not matter as LNCC will normalize it.

    :param kernel_size: scalar, odd size (>= 3) of the 1-D kernel
    :return: kernel_weights, float32 tensor of shape (kernel_size, )
    :raises AssertionError: if kernel_size < 3 or kernel_size is even
    """
    assert kernel_size >= 3
    assert kernel_size % 2 != 0

    padding = kernel_size // 2

    # Place `padding` zeros around a box of (kernel_size - padding) ones so the
    # total length is exactly kernel_size. Put ceil(padding / 2) zeros on the
    # left and floor(padding / 2) on the right. For an even-length filter,
    # TF "SAME" padding puts its extra pad on the right, so the extra zero on
    # the left keeps the convolved result symmetric.
    # (Previously padding // 2 + 1 zeros went on the left, which produced
    # kernel_size + 1 weights whenever padding was even.)
    num_left = (padding + 1) // 2
    num_right = padding // 2
    kernel = [0] * num_left + [1] * (kernel_size - padding) + [0] * num_right
    # (kernel_size, )
    kernel = tf.constant(kernel, dtype=tf.float32)

    # (kernel_size - padding, 1, 1), i.e. (padding + 1, 1, 1)
    filters = tf.ones(shape=(kernel_size - padding, 1, 1), dtype=tf.float32)

    # (1, kernel_size, 1)
    kernel = tf.nn.conv1d(
        kernel[None, :, None], filters=filters, stride=[1, 1, 1], padding="SAME"
    )
    return kernel[0, :, 0]
75
76
77
def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
    """
    Return the 1-D filter for separable convolution equivalent to a 3-D
    Gaussian kernel for LocalNormalizedCrossCorrelation.

    The Gaussian is centred at (kernel_size - 1) / 2 and has standard deviation
    kernel_size / 3. The weights are not normalised to sum to one.

    :param kernel_size: scalar, size of the 1-D kernel
    :return: filters, float32 tensor of shape (kernel_size, )
    """
    mean = (kernel_size - 1) / 2.0
    sigma = kernel_size / 3

    grid = tf.range(0, kernel_size, dtype=tf.float32)
    # exp(-(x - mean)^2 / (2 sigma^2)); the minus sign is moved into the
    # Python-float denominator (bit-identical result), which avoids a unary
    # minus on a tensor that pylint flags as invalid-unary-operand-type.
    filters = tf.exp(tf.square(grid - mean) / (-2 * sigma ** 2))
    return filters
91
92
93
def gaussian_kernel1d_sigma(sigma: int) -> tf.Tensor:
    """
    Build a normalised 1-D Gaussian kernel from its standard deviation.

    Samples exp(-0.5 * x^2 / sigma^2) at the integer offsets
    x in [-int(sigma * 3), int(sigma * 3)], then rescales the weights to sum to one.

    :param sigma: number defining standard deviation for
                  gaussian kernel.
    :return: shape = (dim, ), with dim = 2 * int(sigma * 3) + 1
    """
    assert sigma > 0
    half_width = int(sigma * 3)
    offsets = range(-half_width, half_width + 1)
    weights = tf.exp([-0.5 * offset ** 2 / sigma ** 2 for offset in offsets])
    return weights / tf.reduce_sum(weights)
106
107
108
def cauchy_kernel1d(sigma: int) -> tf.Tensor:
    """
    Build a normalised 1-D approximation of a Cauchy kernel.

    Samples 1 / ((x / sigma)^2 + 1) at the integer offsets
    x in [-int(sigma * 5), int(sigma * 5)], then rescales the weights to sum to one.

    :param sigma: int, defining standard deviation of kernel.
    :return: shape = (dim, ), with dim = 2 * int(sigma * 5) + 1
    """
    assert sigma > 0
    half_width = int(sigma * 5)
    denominators = [
        ((offset / sigma) ** 2 + 1) for offset in range(-half_width, half_width + 1)
    ]
    weights = tf.math.reciprocal(denominators)
    return weights / tf.reduce_sum(weights)
120
121
122
def separable_filter(tensor: tf.Tensor, kernel: tf.Tensor) -> tf.Tensor:
    """
    Create a 3d separable filter by applying one 1-D kernel along each spatial axis.

    The kernel is cast to the dtype of ``tensor``. It is then applied three
    times with ``tf.nn.conv3d`` (unit strides, "SAME" padding): along the
    depth axis first, then height, then width.

    ``tf.nn.conv3d`` takes ``filters`` of shape
    (filter_depth, filter_height, filter_width, in_channels, out_channels),
    where the first axis of ``filters`` is the depth, not the batch. Its input
    has shape (batch, in_depth, in_height, in_width, in_channels). Reshaping
    the kernel so its length sits on axis 0, 1 or 2 therefore selects the
    spatial axis being filtered.

    :param tensor: shape = (batch, dim1, dim2, dim3, 1)
    :param kernel: shape = (dim4,)
    :return: shape = (batch, dim1, dim2, dim3, 1)
    """
    kernel = tf.cast(kernel, dtype=tensor.dtype)
    unit_strides = [1] * 5

    # One filter layout per spatial axis, applied in depth -> height -> width order.
    axis_layouts = ([-1, 1, 1, 1, 1], [1, -1, 1, 1, 1], [1, 1, -1, 1, 1])
    for layout in axis_layouts:
        tensor = tf.nn.conv3d(
            tensor,
            filters=tf.reshape(kernel, layout),
            strides=unit_strides,
            padding="SAME",
        )
    return tensor
156