Passed
Pull Request — main (#656)
by Yunguan
03:21
created

test.unit.test_backbone_unet.TestUNet.test_call()   A

Complexity

Conditions 1

Size

Total Lines 35
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 35
rs 9.304
c 0
b 0
f 0
cc 1
nop 5
1
# coding=utf-8
2
3
"""
Tests for deepreg/model/backbone/u_net.py
"""
6
import pytest
7
import tensorflow as tf
8
9
from deepreg.model.backbone.u_net import UNet
10
11
12
class TestUNet:
    """
    Test the backbone.u_net.UNet class
    """

    @pytest.mark.parametrize(
        "image_size,depth",
        [((11, 12, 13), 5), ((8, 8, 8), 3)],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call(
        self,
        image_size: tuple,
        depth: int,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Check that the network output has the same shape as its input.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param depth: input is at level 0, bottom is at level depth
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        out_ch = 3
        network = UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )
        # The input deliberately has out_ch channels so that the input and
        # output shapes can be compared directly.
        inputs = tf.ones(shape=(5, *image_size, out_ch))
        output = network.call(inputs)
        assert inputs.shape == output.shape

    def test_get_config(self):
        """
        Check that get_config returns exactly the arguments
        the network was constructed with.
        """
        config = dict(
            image_size=(4, 5, 6),
            out_channels=3,
            num_channel_initial=2,
            depth=2,
            extract_levels=(0, 1),
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=False,
            concat_skip=False,
            encode_kernel_sizes=3,
            decode_kernel_sizes=3,
            strides=2,
            padding="same",
            name="Test",
        )
        network = UNet(**config)
        got = network.get_config()
        assert got == config
73