Passed
Pull Request — main (#656)
by Yunguan
02:51
created

test.unit.test_backbone.TestLocalNet.test_call()   A

Complexity

Conditions 1

Size

Total Lines 29
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 29
rs 9.45
c 0
b 0
f 0
cc 1
nop 4
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/backbone
5
"""
6
from test.unit.util import is_equal_tf
7
8
import numpy as np
9
import pytest
10
import tensorflow as tf
11
12
import deepreg.model.backbone as backbone
13
import deepreg.model.backbone.global_net as g
14
import deepreg.model.backbone.local_net as loc
15
import deepreg.model.backbone.u_net as u
16
17
18
def test_backbone_interface():
    """Check that Backbone.get_config returns exactly the constructor kwargs."""
    expected_config = {
        "image_size": (5, 5, 5),
        "out_channels": 3,
        "num_channel_initial": 4,
        "out_kernel_initializer": "zeros",
        "out_activation": "relu",
        "name": "test",
    }
    got = backbone.Backbone(**expected_config).get_config()
    assert got == expected_config
31
32
33
def test_init_global_net():
    """
    Testing init of GlobalNet is built as expected.

    Checks the reference grid (shape and values) and the initial transform
    exposed by the network's output block.
    """
    image_size = [1, 2, 3]
    global_test = g.GlobalNet(
        image_size=image_size,
        out_channels=3,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        # NOTE(review): "softmax" is an activation name rather than a usual
        # Keras initializer name; presumably GlobalNet does not use it — confirm.
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )

    # White-box test: the attributes under test live on the protected output
    # block, so fetch it once with an explicit, justified lint suppression
    # instead of tripping the protected-access check on every use.
    output_block = global_test._output_block  # pylint: disable=protected-access

    # the reference grid holds one 3-vector of voxel coordinates per voxel
    assert output_block.reference_grid.shape == [*image_size, 3]

    expected_ref_grid = tf.convert_to_tensor(
        [
            [
                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
                [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 2.0]],
            ]
        ],
        dtype=tf.float32,
    )
    assert is_equal_tf(output_block.reference_grid, expected_ref_grid)

    # the 12 expected values are a 3x3 identity followed by a zero translation
    expected_transform_initial = tf.convert_to_tensor(
        [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        dtype=tf.float32,
    )
    global_transform_initial = tf.Variable(output_block.transform_initial(shape=[12]))
    assert is_equal_tf(global_transform_initial, expected_transform_initial)
71
72
73
def test_call_global_net():
    """
    Check the shapes of the (ddf, theta) pair returned by GlobalNet.call.
    """
    batch_size = 5
    image_size = (1, 2, 3)
    out_channels = 3
    network = g.GlobalNet(
        image_size=image_size,
        out_channels=out_channels,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )

    # all-zero batch; its channel count reuses out_channels
    zero_batch = tf.constant(
        np.zeros((batch_size, *image_size, out_channels), dtype=np.float32)
    )

    ddf, theta = network.call(zero_batch)
    assert ddf.shape == (batch_size, *image_size, 3)
    assert theta.shape == (batch_size, 4, 3)
100
101
102
class TestLocalNet:
    """
    Test the backbone.local_net.LocalNet class
    """

    @pytest.mark.parametrize("use_additive_upsampling", [True, False])
    @pytest.mark.parametrize(
        "image_size,extract_levels",
        [((11, 12, 13), [1, 2, 3]), ((8, 8, 8), [1, 2, 3])],
    )
    def test_call(
        self, image_size: tuple, extract_levels: list, use_additive_upsampling: bool
    ):
        """
        Check that LocalNet.call returns a tensor with the same shape as its input.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: forwarded to LocalNet's ``extract_levels`` argument.
        :param use_additive_upsampling: forwarded to LocalNet's
                                        ``use_additive_upsampling`` argument.
        """
        batch_size = 5
        out_channels = 3
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            use_additive_upsampling=use_additive_upsampling,
        )

        # all-zero input whose channel count equals out_channels, so the whole
        # input shape is expected back from the network
        inputs = tf.constant(
            np.zeros((batch_size, *image_size, out_channels), dtype=np.float32)
        )
        output = network.call(inputs)

        # Compare whole shapes: a pairwise zip() check stops at the shorter
        # shape and would silently accept a rank mismatch.
        assert output.shape == (batch_size, *image_size, out_channels)
136
137
138
class TestUNet:
    """
    Test the backbone.u_net.UNet class
    """

    @pytest.mark.parametrize(
        "image_size,depth",
        [((11, 12, 13), 5), ((8, 8, 8), 3)],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call(
        self,
        image_size: tuple,
        depth: int,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Check that UNet.call returns a tensor with the same shape as its input.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param depth: input is at level 0, bottom is at level depth
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        batch_size = 5
        out_ch = 3
        network = u.UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )

        # input channel count equals out_ch, so the full input shape is expected back
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        output = network.call(inputs)

        # Compare whole shapes: a pairwise zip() check stops at the shorter
        # shape and would silently accept a rank mismatch.
        assert output.shape == (batch_size, *image_size, out_ch)
178