Passed
Pull Request — main (#656)
by Yunguan
02:55
created

test.unit.test_backbone_global_net   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 113
Duplicated Lines 17.7 %

Importance

Changes 0
Metric Value
wmc 5
eloc 80
dl 20
loc 113
rs 10
c 0
b 0
f 0

3 Methods

Rating   Name   Duplication   Size   Complexity  
A TestGlobalNet.test_call() 0 36 1
A TestGlobalNet.test_get_config() 20 20 1
A TestGlobalNet.test_err() 0 19 2

1 Function

Rating   Name   Duplication   Size   Complexity  
A test_affine_head() 0 15 1

How to fix   Duplicated Code   

Duplicated Code

Duplicated code is one of the most pungent code smells. A commonly used rule is to restructure code once it is duplicated in three or more places.

Common duplication problems and their corresponding solutions are:

1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/backbone/global_net.py
5
"""
6
from typing import Tuple
7
8
import pytest
9
import tensorflow as tf
10
11
from deepreg.model.backbone.global_net import AffineHead, GlobalNet
12
13
14
def test_affine_head():
    """
    Check the output shapes and the config of AffineHead.
    """
    batch_size = 3
    image_size = (4, 5, 6)
    kwargs = {"image_size": image_size, "name": "TestAffineHead"}
    head = AffineHead(**kwargs)

    # feed a two-channel tensor of ones through the layer
    ddf, theta = head.call(tf.ones(shape=(batch_size, *image_size, 2)))
    assert ddf.shape == (batch_size, *image_size, 3)
    assert theta.shape == (batch_size, 4, 3)

    # the config must hold the constructor arguments plus the two extra keys
    expected_config = {"trainable": True, "dtype": "float32", **kwargs}
    assert head.get_config() == expected_config
29
30
31
class TestGlobalNet:
    """
    Test the backbone.global_net.GlobalNet class
    """

    @staticmethod
    def _full_config(**overrides) -> dict:
        """
        Build a fully specified keyword configuration for GlobalNet.

        A new dict (with a new kernel-size list) is created on every call,
        so tests never share mutable state.

        :param overrides: entries replacing the default values below.
        :return: keyword arguments to pass to the GlobalNet constructor.
        """
        config = dict(
            image_size=(4, 5, 6),
            out_channels=3,
            num_channel_initial=2,
            depth=2,
            extract_levels=(2,),
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=False,
            concat_skip=False,
            encode_kernel_sizes=[7, 3, 3],
            decode_kernel_sizes=3,
            strides=2,
            padding="same",
            name="Test",
        )
        config.update(overrides)
        return config

    @pytest.mark.parametrize(
        "image_size,extract_levels,depth",
        [
            ((11, 12, 13), (0, 1, 2, 4), 4),
            ((11, 12, 13), None, 4),
            ((11, 12, 13), (0, 1, 2, 4), None),
            ((8, 8, 8), (0, 1, 2), 3),
        ],
    )
    def test_call(
        self,
        image_size: Tuple[int, int, int],
        extract_levels: Tuple[int, ...],
        depth: int,
    ):
        """
        Check the shapes of the two outputs of GlobalNet.call.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: from which depths the output will be built,
            None in one of the parametrized cases.
        :param depth: input is at level 0, bottom is at level depth,
            None in one of the parametrized cases.
        """
        batch_size = 5
        out_ch = 3
        network = GlobalNet(
            image_size=image_size,
            num_channel_initial=2,
            extract_levels=extract_levels,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            out_channels=out_ch,
        )
        # the input tensor is built with out_ch channels, so the shape check
        # below expects ddf to have that same number of channels
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        ddf, theta = network.call(inputs)
        assert ddf.shape == inputs.shape
        assert theta.shape == (batch_size, 4, 3)

    def test_err(self):
        """
        Check that a ValueError with the expected message is raised
        when both `depth` and `extract_levels` are None.
        """
        config = self._full_config(depth=None, extract_levels=None)
        with pytest.raises(ValueError) as err_info:
            GlobalNet(**config)
        assert "GlobalNet requires `depth` or `extract_levels`" in str(err_info.value)

    def test_get_config(self):
        """
        Check that get_config returns exactly the constructor arguments.
        """
        config = self._full_config()
        network = GlobalNet(**config)
        got = network.get_config()
        assert got == config
113