Passed
Pull Request — main (#656)
by Yunguan
02:48
created

test.unit.test_backbone.TestUNet.test_init()   A

Complexity

Conditions 1

Size

Total Lines 43
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 43
rs 9.304
c 0
b 0
f 0
cc 1
nop 3
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/backbone
5
"""
6
from test.unit.util import is_equal_tf
7
8
import numpy as np
9
import pytest
10
import tensorflow as tf
11
12
import deepreg.model.backbone as backbone
13
import deepreg.model.backbone.global_net as g
14
import deepreg.model.backbone.local_net as loc
15
import deepreg.model.backbone.u_net as u
16
import deepreg.model.layer as layer
17
18
19
def test_backbone_interface():
    """Check that Backbone.get_config returns exactly the constructor kwargs."""
    expected_config = {
        "image_size": (5, 5, 5),
        "out_channels": 3,
        "num_channel_initial": 4,
        "out_kernel_initializer": "zeros",
        "out_activation": "relu",
        "name": "test",
    }
    assert backbone.Backbone(**expected_config).get_config() == expected_config
32
33
34
def test_init_global_net():
    """
    Testing init of GlobalNet is built as expected.

    Checks the stored extraction levels, the reference grid, the initial
    transform values and the type of the `_conv3d_block` layer.
    """
    # This test deliberately inspects GlobalNet internals.
    # pylint: disable=protected-access
    global_test = g.GlobalNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        # "softmax" is an activation, not a kernel initializer; use a valid
        # initializer name, consistent with the LocalNet / UNet tests.
        out_kernel_initializer="he_normal",
        out_activation="softmax",
    )

    # extraction levels are stored as given, max level is derived from them
    assert global_test._extract_levels == [1, 2, 3]
    assert global_test._extract_max_level == 3

    # reference grid has shape (*image_size, 3); each voxel holds its own
    # (i, j, k) index, as spelled out in the expected tensor below
    assert global_test.reference_grid.shape == [1, 2, 3, 3]
    expected_ref_grid = tf.convert_to_tensor(
        [
            [
                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
                [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 2.0]],
            ]
        ],
        dtype=tf.float32,
    )
    assert is_equal_tf(global_test.reference_grid, expected_ref_grid)

    # the 12 initial transform values (presumably a flattened identity affine)
    expected_transform_initial = tf.convert_to_tensor(
        [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        dtype=tf.float32,
    )
    global_transform_initial = tf.Variable(global_test.transform_initial(shape=[12]))
    assert is_equal_tf(global_transform_initial, expected_transform_initial)

    # the conv block is built from the project's Conv3dBlock layer
    assert isinstance(global_test._conv3d_block, layer.Conv3dBlock)
78
79
80
def test_call_global_net():
    """
    Asserting that output shape of globalnet Call method
    is correct.

    GlobalNet.call is expected to return a dense displacement field of shape
    (batch, *image_size, 3) and parameters theta of shape (batch, 4, 3).
    """
    out = 3
    im_size = (1, 2, 3)
    batch_size = 5
    global_test = g.GlobalNet(
        image_size=im_size,
        out_channels=out,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        # "softmax" is an activation, not a kernel initializer; use a valid
        # initializer name, consistent with the LocalNet / UNet tests.
        out_kernel_initializer="he_normal",
        out_activation="softmax",
    )
    # all-zeros input; the channel count reuses `out` as before
    inputs = tf.constant(np.zeros((batch_size, *im_size, out), dtype=np.float32))
    ddf, theta = global_test.call(inputs)
    assert ddf.shape == (batch_size, *im_size, 3)
    assert theta.shape == (batch_size, 4, 3)
107
108
109
class TestLocalNet:
    """
    Test the backbone.local_net.LocalNet class
    """

    # These tests deliberately inspect LocalNet internals.
    # pylint: disable=protected-access

    @pytest.mark.parametrize(
        "image_size,extract_levels",
        [((11, 12, 13), [1, 2, 3]), ((8, 8, 8), [1, 2, 3])],
    )
    def test_init(self, image_size, extract_levels):
        """
        Check the attributes LocalNet derives from its extraction levels.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: extraction levels passed to LocalNet.
        """
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=3,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
        )

        # levels are stored as given; max / min are derived from them
        assert network._extract_levels == extract_levels
        assert network._extract_max_level == max(extract_levels)
        assert network._extract_min_level == min(extract_levels)

        # one downsampling conv per level from 0 up to (excluding) max level
        assert len(network._downsample_convs) == max(extract_levels)

        # one extract layer per requested extraction level
        assert len(network._extract_layers) == len(extract_levels)

    @pytest.mark.parametrize(
        "image_size,extract_levels",
        [((11, 12, 13), [1, 2, 3]), ((8, 8, 8), [1, 2, 3])],
    )
    def test_call(self, image_size, extract_levels):
        """
        Check that LocalNet.call returns (batch, *image_size, out_channels).

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: extraction levels passed to LocalNet.
        """
        batch_size = 5
        in_channels = 3
        out_channels = 3
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
        )

        # pass an input of all zeros
        inputs = tf.constant(
            np.zeros((batch_size, *image_size, in_channels), dtype=np.float32)
        )
        output = network.call(inputs)
        # compare whole shapes: zip() would silently ignore a rank mismatch
        assert output.shape == (batch_size, *image_size, out_channels)
166
167
168
class TestUNet:
169
    """
170
    Test the backbone.u_net.UNet class
171
    """
172
173
    @pytest.mark.parametrize(
        "image_size,depth",
        [((11, 12, 13), 5), ((8, 8, 8), 3)],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call_unet(
        self,
        image_size: tuple,
        depth: int,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Check that UNet.call returns (batch, *image_size, out_channels).

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param depth: input is at level 0, bottom is at level depth
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        batch_size = 5
        out_ch = 3
        network = u.UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        output = network.call(inputs)
        # compare whole shapes: zip() would silently ignore a rank mismatch
        assert output.shape == (batch_size, *image_size, out_ch)
208