Issues (32)

test/unit/test_backbone_unet.py (1 issue)

1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/backbone/u_net.py
5
"""
6
from typing import Tuple
7
8
import pytest
9
import tensorflow as tf
10
11
from deepreg.model.backbone.u_net import UNet
12
13
14
class TestUNet:
    """
    Test the backbone.u_net.UNet class
    """

    @pytest.mark.parametrize(
        "depth,encode_num_channels,decode_num_channels",
        [
            (2, (4, 8, 16), (4, 8, 16)),
            (2, (4, 8, 8), (4, 8, 8)),
            (2, (4, 8, 8), (8, 8, 8)),
        ],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_channels(
        self,
        depth: int,
        encode_num_channels: Tuple,
        decode_num_channels: Tuple,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Test unet with custom encode/decode channels.

        When skip tensors are added rather than concatenated, mismatched
        encode/decode channels are expected to make the constructor raise
        ValueError. In every other case the network must build and return
        an output of shape (batch, *image_size, out_channels).

        :param depth: input is at level 0, bottom is at level depth
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        # in case of adding skip tensors, the channels should match
        expect_err = (not concat_skip) and encode_num_channels != decode_num_channels

        batch_size = 5
        image_size = (5, 6, 7)
        out_ch = 3
        unet_kwargs = dict(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=0,
            encode_num_channels=encode_num_channels,
            decode_num_channels=decode_num_channels,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )

        if expect_err:
            # Fail loudly if the mismatch is NOT rejected; the previous
            # try/except silently continued when no error was raised.
            with pytest.raises(ValueError):
                UNet(**unet_kwargs)
            return

        network = UNet(**unet_kwargs)
        # The input reuses out_ch as its channel count, as the original test did.
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        output = network.call(inputs)
        assert tuple(output.shape) == (batch_size, *image_size, out_ch)

    @pytest.mark.parametrize(
        "image_size,depth",
        [((11, 12, 13), 5), ((8, 8, 8), 3)],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call(
        self,
        image_size: Tuple,
        depth: int,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Test unet call function.

        The output must keep the batch and spatial dims of the input and
        have out_channels channels.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param depth: input is at level 0, bottom is at level depth
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        batch_size = 5
        out_ch = 3
        network = UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        output = network.call(inputs)
        assert tuple(output.shape) == (batch_size, *image_size, out_ch)

    def test_get_config(self):
        """
        Test that get_config returns exactly the kwargs used to build the UNet,
        so the network can be re-created from its config.
        """
        config = dict(
            image_size=(4, 5, 6),
            out_channels=3,
            num_channel_initial=2,
            depth=2,
            extract_levels=(0, 1),
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=False,
            concat_skip=False,
            encode_kernel_sizes=3,
            decode_kernel_sizes=3,
            encode_num_channels=(2, 4, 8),
            decode_num_channels=(2, 4, 8),
            strides=2,
            padding="same",
            name="Test",
        )
        network = UNet(**config)
        got = network.get_config()
        assert got == config
135