Passed
Pull Request — main (#675)
by Yunguan
05:44 queued 02:42
created

TestGlobalNet.test_call()   A

Complexity

Conditions 1

Size

Total Lines 36
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 26
dl 0
loc 36
rs 9.256
c 0
b 0
f 0
cc 1
nop 4
1
# coding=utf-8

"""
Tests for deepreg/model/backbone/global_net.py
"""
6
from typing import Optional, Tuple

import pytest
import tensorflow as tf

from deepreg.model.backbone.global_net import AffineHead, GlobalNet
12
13
14
def test_affine_head():
    """
    Check AffineHead output shapes and the content of its config.

    The layer is built for a (4, 5, 6) image and called on an all-ones
    tensor with a batch of 3 and 2 channels. The test asserts that
    ``ddf`` has shape (3, 4, 5, 6, 3), that ``theta`` has shape (3, 4, 3),
    and that ``get_config`` returns the constructor arguments together
    with the ``trainable`` and ``dtype`` entries.
    """
    batch_size = 3
    image_size = (4, 5, 6)
    kwargs = {"image_size": image_size, "name": "TestAffineHead"}
    head = AffineHead(**kwargs)

    # Forward pass through ``call`` directly, as the original test does.
    dummy = tf.ones(shape=(batch_size,) + image_size + (2,))
    ddf, theta = head.call(dummy)
    assert ddf.shape == (batch_size,) + image_size + (3,)
    assert theta.shape == (batch_size, 4, 3)

    # The config must hold the constructor arguments plus two extra keys.
    expected_config = {"trainable": True, "dtype": "float32"}
    expected_config.update(kwargs)
    assert head.get_config() == expected_config
29
30
31
class TestGlobalNet:
    """
    Test the backbone.global_net.GlobalNet class
    """

    @pytest.mark.parametrize(
        "image_size,extract_levels,depth",
        [
            ((11, 12, 13), (0, 1, 2, 4), 4),
            ((11, 12, 13), None, 4),
            ((11, 12, 13), (0, 1, 2, 4), None),
            ((8, 8, 8), (0, 1, 2), 3),
        ],
    )
    def test_call(
        self,
        image_size: Tuple[int, int, int],
        extract_levels: Optional[Tuple[int, ...]],
        depth: Optional[int],
    ):
        """
        Test the shapes of the outputs returned by GlobalNet.call.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: from which depths the output will be built,
            None in one of the parametrized cases.
        :param depth: input is at level 0, bottom is at level depth,
            None in one of the parametrized cases.
        """
        batch_size = 5
        out_ch = 3
        network = GlobalNet(
            image_size=image_size,
            num_channel_initial=2,
            extract_levels=extract_levels,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            out_channels=out_ch,
        )
        # The inputs reuse out_ch as their channel count, which is why ddf
        # can be compared against inputs.shape below.
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        ddf, theta = network.call(inputs)
        assert ddf.shape == inputs.shape
        assert theta.shape == (batch_size, 4, 3)

    def test_err(self):
        """
        Test that GlobalNet raises ValueError when both depth and
        extract_levels are None.
        """
        with pytest.raises(ValueError) as err_info:
            GlobalNet(
                image_size=(4, 5, 6),
                out_channels=3,
                num_channel_initial=2,
                depth=None,
                extract_levels=None,
                out_kernel_initializer="he_normal",
                out_activation="softmax",
                pooling=False,
                concat_skip=False,
                encode_kernel_sizes=[7, 3, 3],
                decode_kernel_sizes=3,
                strides=2,
                padding="same",
                name="Test",
            )
        assert "GlobalNet requires `depth` or `extract_levels`" in str(err_info.value)

    def test_get_config(self):
        """
        Test that get_config returns exactly the arguments the network
        was constructed with.
        """
        config = dict(
            image_size=(4, 5, 6),
            out_channels=3,
            num_channel_initial=2,
            depth=2,
            extract_levels=(2,),
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=False,
            concat_skip=False,
            encode_kernel_sizes=[7, 3, 3],
            decode_kernel_sizes=3,
            encode_num_channels=[2, 4, 8],
            decode_num_channels=[2, 4, 8],
            strides=2,
            padding="same",
            name="Test",
        )
        network = GlobalNet(**config)
        got = network.get_config()
        assert got == config
115