Passed
Pull Request — main (#656)
by Yunguan
11:15 queued 50s
created

test.unit.test_backbone.TestUNet.test_call()   A

Complexity

Conditions 1

Size

Total Lines 35
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 35
rs 9.304
c 0
b 0
f 0
cc 1
nop 5
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/backbone
5
"""
6
from test.unit.util import is_equal_tf
7
8
import numpy as np
9
import pytest
10
import tensorflow as tf
11
12
import deepreg.model.backbone as backbone
13
import deepreg.model.backbone.global_net as g
14
import deepreg.model.backbone.local_net as loc
15
import deepreg.model.backbone.u_net as u
16
import deepreg.model.layer as layer
17
18
19
def test_backbone_interface():
    """Check that Backbone.get_config returns the constructor arguments."""
    kwargs = {
        "image_size": (5, 5, 5),
        "out_channels": 3,
        "num_channel_initial": 4,
        "out_kernel_initializer": "zeros",
        "out_activation": "relu",
        "name": "test",
    }
    # The config reported by the model must round-trip the arguments exactly.
    returned = backbone.Backbone(**kwargs).get_config()
    assert returned == kwargs
32
33
34
def test_init_global_net():
    """
    Test that a freshly constructed GlobalNet stores its arguments and
    exposes the expected reference grid, initial transform and conv block.
    """
    global_test = g.GlobalNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        # "softmax" names an activation, not a kernel initializer; use the
        # same initializer as the LocalNet and UNet tests in this module.
        out_kernel_initializer="he_normal",
        out_activation="softmax",
    )

    # White-box checks: the protected attributes are read on purpose to
    # verify what the constructor stored (static analysis flags this access).
    assert global_test._extract_levels == [1, 2, 3]
    assert global_test._extract_max_level == 3

    # The reference grid has shape image_size + [3].
    assert global_test.reference_grid.shape == [1, 2, 3, 3]
    expected_ref_grid = tf.convert_to_tensor(
        [
            [
                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
                [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 2.0]],
            ]
        ],
        dtype=tf.float32,
    )
    assert is_equal_tf(global_test.reference_grid, expected_ref_grid)

    # transform_initial is called as an initializer for 12 values.
    expected_transform_initial = tf.convert_to_tensor(
        [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        dtype=tf.float32,
    )
    global_transform_initial = tf.Variable(global_test.transform_initial(shape=[12]))
    assert is_equal_tf(global_transform_initial, expected_transform_initial)

    # The conv block must be an instance of layer.Conv3dBlock.
    assert isinstance(global_test._conv3d_block, layer.Conv3dBlock)
0 ignored issues
show
Coding Style Best Practice introduced by
It seems like _conv3d_block was declared protected and should not be accessed from this context.

Prefixing a member variable with _ is usually regarded as the equivalent of declaring it with the protected visibility that exists in other languages. Consequently, such a member should only be accessed from the same class or a child class:

class MyParent:
    def __init__(self):
        self._x = 1;
        self.y = 2;

class MyChild(MyParent):
    def some_method(self):
        return self._x    # Ok, since accessed from a child class

class AnotherClass:
    def some_method(self, instance_of_my_child):
        return instance_of_my_child._x   # Would be flagged as AnotherClass is not
                                         # a child class of MyParent
Loading history...
78
79
80
def test_call_global_net():
    """
    Test the output shapes of GlobalNet.call.

    For an all-zero input of shape (batch, *image_size, channels) the call
    must return ``ddf`` of shape (batch, *image_size, 3) and ``theta`` of
    shape (batch, 4, 3).
    """
    out = 3
    im_size = (1, 2, 3)
    batch_size = 5

    global_test = g.GlobalNet(
        image_size=im_size,
        out_channels=out,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        # "softmax" names an activation, not a kernel initializer; use the
        # same initializer as the LocalNet and UNet tests in this module.
        out_kernel_initializer="he_normal",
        out_activation="softmax",
    )

    # Feed an input of all zeros.
    inputs = tf.constant(np.zeros((batch_size, *im_size, out), dtype=np.float32))

    ddf, theta = global_test.call(inputs)
    assert ddf.shape == (batch_size, *im_size, 3)
    assert theta.shape == (batch_size, 4, 3)
107
108
109
class TestLocalNet:
    """
    Test the backbone.local_net.LocalNet class
    """

    @pytest.mark.parametrize("use_additive_upsampling", [True, False])
    @pytest.mark.parametrize(
        "image_size,extract_levels",
        [((11, 12, 13), [1, 2, 3]), ((8, 8, 8), [1, 2, 3])],
    )
    def test_call(
        self, image_size: tuple, extract_levels: list, use_additive_upsampling: bool
    ):
        """
        Test that LocalNet.call returns an output with the input's shape.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: passed to LocalNet as extract_levels.
        :param use_additive_upsampling: passed to LocalNet unchanged.
        """
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=3,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            use_additive_upsampling=use_additive_upsampling,
        )

        # Feed an input of all zeros with batch size 5 and 3 channels.
        inputs = tf.constant(np.zeros((5, *image_size, 3), dtype=np.float32))

        output = network.call(inputs)

        # Expected shape is (5, *image_size, 3). Compare the shapes directly:
        # zipping the dims would silently pass on an output of lower rank.
        assert output.shape == inputs.shape
143
144
145
class TestUNet:
    """
    Test the backbone.u_net.UNet class
    """

    @pytest.mark.parametrize(
        "image_size,depth",
        [((11, 12, 13), 5), ((8, 8, 8), 3)],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call(
        self,
        image_size: tuple,
        depth: int,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Test that UNet.call returns an output with the input's shape.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param depth: input is at level 0, bottom is at level depth
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        out_ch = 3
        network = u.UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )
        # The input has as many channels as the network outputs, so the
        # output shape is expected to equal the input shape.
        inputs = tf.ones(shape=(5, *image_size, out_ch))
        output = network.call(inputs)

        # Compare the shapes directly: zipping the dims would silently pass
        # on an output of lower rank.
        assert output.shape == inputs.shape
185