Issues (32)

test/unit/test_backbone_local_net.py (1 issue)

1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/backbone/local_net.py
5
"""
6
from typing import Tuple
7
8
import pytest
9
import tensorflow as tf
10
11
from deepreg.model.backbone.local_net import AdditiveUpsampling, LocalNet
12
13
14
def test_additive_up_sampling():
    """
    Check AdditiveUpsampling output shape and its config round-trip.

    A tensor of ones with twice as many channels as ``filters`` is fed
    through the layer; the result must have each spatial dimension doubled
    and ``filters`` channels. The config returned by ``get_config`` must
    equal the constructor arguments plus the ``trainable`` and ``dtype``
    entries.
    """
    batch_size = 3
    num_filters = 4
    spatial_in = (4, 5, 6)
    # The layer is configured to double every spatial dimension.
    spatial_out = tuple(2 * dim for dim in spatial_in)

    layer_kwargs = {
        "filters": num_filters,
        "output_padding": (1, 1, 1),
        "kernel_size": 3,
        "padding": "same",
        "strides": 2,
        "output_shape": spatial_out,
        "name": "TestAdditiveUpsampling",
    }
    up_layer = AdditiveUpsampling(**layer_kwargs)

    # Input carries twice as many channels as the layer's filters.
    ones = tf.ones(shape=(batch_size, *spatial_in, num_filters * 2))
    result = up_layer.call(ones)
    assert result.shape == (batch_size, *spatial_out, num_filters)

    # The serialized config is the constructor kwargs plus two extra entries.
    expected_config = {"trainable": True, "dtype": "float32"}
    expected_config.update(layer_kwargs)
    assert up_layer.get_config() == expected_config
38
39
40
class TestLocalNet:
    """
    Tests for the backbone.local_net.LocalNet class.
    """

    @pytest.mark.parametrize(
        "image_size,extract_levels,depth",
        [((11, 12, 13), (0, 1, 2, 4), 4), ((8, 8, 8), (0, 1, 2), 3)],
    )
    @pytest.mark.parametrize("use_additive_upsampling", [True, False])
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call(
        self,
        image_size: tuple,
        extract_levels: Tuple[int, ...],
        depth: int,
        use_additive_upsampling: bool,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Check that the network output has the same shape as its input.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: from which depths the output will be built.
        :param depth: input is at level 0, bottom is at level depth
        :param use_additive_upsampling: whether use additive up-sampling layer
            for decoding.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        num_channels = 3
        net_kwargs = {
            "image_size": image_size,
            "num_channel_initial": 2,
            "extract_levels": extract_levels,
            "depth": depth,
            "out_kernel_initializer": "he_normal",
            "out_activation": "softmax",
            "out_channels": num_channels,
            "use_additive_upsampling": use_additive_upsampling,
            "pooling": pooling,
            "concat_skip": concat_skip,
        }
        net = LocalNet(**net_kwargs)

        # The input uses the same channel count as the network output so
        # that the two shapes can be compared directly.
        ones = tf.ones(shape=(5, *image_size, num_channels))
        result = net.call(ones)
        assert ones.shape == result.shape

    def test_get_config(self):
        """
        Check that get_config returns exactly the constructor arguments.
        """
        expected = {
            "image_size": (4, 5, 6),
            "out_channels": 3,
            "num_channel_initial": 2,
            "depth": 2,
            "extract_levels": (0, 1),
            "out_kernel_initializer": "he_normal",
            "out_activation": "softmax",
            "pooling": False,
            "concat_skip": False,
            "use_additive_upsampling": True,
            "encode_kernel_sizes": [7, 3, 3],
            "decode_kernel_sizes": 3,
            "encode_num_channels": (2, 4, 8),
            "decode_num_channels": (2, 4, 8),
            "strides": 2,
            "padding": "same",
            "name": "Test",
        }
        net = LocalNet(**expected)
        assert net.get_config() == expected
112