Passed
Pull Request — main (#656)
by Yunguan
02:46
created

test.unit.test_backbone.TestUNet.test_init()   (rated A)

Complexity
  Conditions: 2

Size
  Total Lines: 50
  Code Lines: 29

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric Value
eloc 29
dl 0
loc 50
rs 9.184
c 0
b 0
f 0
cc 2
nop 4
# coding=utf-8

"""
Tests for deepreg/model/backbone
"""
from test.unit.util import is_equal_tf

import numpy as np
import pytest
import tensorflow as tf

import deepreg.model.backbone as backbone
import deepreg.model.backbone.global_net as g
import deepreg.model.backbone.local_net as loc
import deepreg.model.backbone.u_net as u
import deepreg.model.layer as layer


def test_backbone_interface():
    """Test the get_config of the interface"""
    config = dict(
        image_size=(5, 5, 5),
        out_channels=3,
        num_channel_initial=4,
        out_kernel_initializer="zeros",
        out_activation="relu",
        name="test",
    )
    model = backbone.Backbone(**config)
    got = model.get_config()
    assert got == config


def test_init_global_net():
    """
    Test that GlobalNet initialisation builds the expected attributes.
    """
    # initialising GlobalNet instance
    global_test = g.GlobalNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )

    # asserting initialised var for extract_levels is the same - Pass
    assert global_test._extract_levels == [1, 2, 3]
Coding Style / Best Practice issue: it seems like _extract_levels was declared protected and should not be accessed from this context.

Prefixing a member variable with _ is usually regarded as the equivalent of declaring it with the protected visibility that exists in other languages. Consequently, such a member should only be accessed from the same class or a child class:

class MyParent:
    def __init__(self):
        self._x = 1
        self.y = 2

class MyChild(MyParent):
    def some_method(self):
        return self._x    # OK, since accessed from a child class

class AnotherClass:
    def some_method(self, instance_of_my_child):
        return instance_of_my_child._x   # Would be flagged, as AnotherClass is not
                                         # a child class of MyParent

(The same finding, with the same explanation, applies to every later assert that reads a protected member: _extract_max_level and _conv3d_block here, and the TestLocalNet checks on _extract_levels, _extract_max_level, _extract_min_level, _downsample_convs, _upsample_blocks and _extract_layers.)
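A possible alternative keeps the assertion on the public API, mirroring the get_config() check in test_backbone_interface above. This is only a sketch: it assumes GlobalNet.get_config() echoes extract_levels back to the caller, which this diff does not show.

    # inside test_init_global_net, instead of reading the protected attribute
    got = global_test.get_config()  # hypothetical: config must expose extract_levels
    assert got["extract_levels"] == [1, 2, 3]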
    # asserting initialised var for extract_max_level is the same - Pass
    assert global_test._extract_max_level == 3

    # self reference grid
    # assert global_test.reference_grid correct shape, Pass
    assert global_test.reference_grid.shape == [1, 2, 3, 3]
    # assert correct reference grid returned, Pass
    expected_ref_grid = tf.convert_to_tensor(
        [
            [
                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
                [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 2.0]],
            ]
        ],
        dtype=tf.float32,
    )
    assert is_equal_tf(global_test.reference_grid, expected_ref_grid)

    # assert correct initial transform is returned
    expected_transform_initial = tf.convert_to_tensor(
        [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        dtype=tf.float32,
    )
    global_transform_initial = tf.Variable(global_test.transform_initial(shape=[12]))
    assert is_equal_tf(global_transform_initial, expected_transform_initial)

    # assert conv3dBlock type is correct, Pass
    assert isinstance(global_test._conv3d_block, layer.Conv3dBlock)


def test_call_global_net():
    """
    Assert that the output shapes of the GlobalNet call method
    are correct.
    """
    out = 3
    im_size = (1, 2, 3)
    batch_size = 5
    # initialising GlobalNet instance
    global_test = g.GlobalNet(
        image_size=im_size,
        out_channels=out,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )
    # pass an input of all zeros
    inputs = tf.constant(
        np.zeros(
            (batch_size, im_size[0], im_size[1], im_size[2], out), dtype=np.float32
        )
    )
    # get outputs by calling
    ddf, theta = global_test.call(inputs)
    assert ddf.shape == (batch_size, *im_size, 3)
    assert theta.shape == (batch_size, 4, 3)


class TestLocalNet:
    """
    Test the backbone.local_net.LocalNet class
    """

    @pytest.mark.parametrize(
        "image_size,extract_levels,control_points",
        [((10, 20, 30), [1, 2, 3], None), ((8, 8, 8), [1, 2, 3], (2, 2, 2))],
    )
Issue: missing function or method docstring.
    def test_init(self, image_size, extract_levels, control_points):
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=3,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            control_points=control_points,
        )

        # asserting initialised var for extract_levels is the same - Pass
        assert network._extract_levels == extract_levels
        # asserting initialised var for extract_max_level is the same - Pass
        assert network._extract_max_level == max(extract_levels)
        # asserting initialised var for extract_min_level is the same - Pass
        assert network._extract_min_level == min(extract_levels)

        # assert number of downsample blocks is correct (== max level), Pass
        assert len(network._downsample_convs) == max(extract_levels)

        # assert upsample blocks type is correct, Pass
        assert all(
            isinstance(item, layer.LocalNetUpSampleResnetBlock)
            for item in network._upsample_blocks
        )
        # assert number of upsample blocks is correct (== max level - min level), Pass
        assert len(network._upsample_blocks) == max(extract_levels) - min(
            extract_levels
        )

        # assert number of extract layers is correct (== len(extract_levels)), Pass
        assert len(network._extract_layers) == len(extract_levels)

        if control_points is None:
            assert network.resize is False
        else:
            assert isinstance(network.resize, layer.ResizeCPTransform)
            assert isinstance(network.interpolate, layer.BSplines3DTransform)

    @pytest.mark.parametrize(
        "image_size,extract_levels,control_points",
        [((64, 65, 66), [1, 2, 3, 4], None), ((8, 8, 8), [1, 2, 3], (2, 2, 2))],
    )
Issue: missing function or method docstring.
    def test_call(self, image_size, extract_levels, control_points):
        # initialising LocalNet instance
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=3,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            control_points=control_points,
        )

        # pass an input of all zeros
        inputs = tf.constant(
            np.zeros(
                (5, image_size[0], image_size[1], image_size[2], 3), dtype=np.float32
            )
        )
        # get outputs by calling
        output = network.call(inputs)
        # output shape should match the input shape
        assert all(x == y for x, y in zip(inputs.shape, output.shape))


class TestUNet:
    """
    Test the backbone.u_net.UNet class
    """

    @pytest.mark.parametrize(
        "image_size,depth,control_points",
        [((32, 33, 34), 5, None), ((8, 8, 8), 3, (2, 2, 2))],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    def test_call_unet(
        self, image_size: tuple, depth: int, control_points: tuple, pooling: bool
    ):
        """

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param depth: input is at level 0, bottom is at level depth
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param control_points: specify the distance between control points (in voxels).
        """
        out_ch = 3
        network = u.UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            control_points=control_points,
        )
        inputs = tf.ones(shape=(5, image_size[0], image_size[1], image_size[2], out_ch))
        output = network.call(inputs)
        assert all(x == y for x, y in zip(inputs.shape, output.shape))
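Both "missing function or method docstring" findings refer to the parametrized TestLocalNet methods. A one-line docstring on each would satisfy the check; the wording below is illustrative and not taken from the PR:

    def test_init(self, image_size, extract_levels, control_points):
        """Test that LocalNet builds the expected layers for a given parametrisation."""
        ...

    def test_call(self, image_size, extract_levels, control_points):
        """Test that LocalNet's call output keeps the input spatial shape."""
        ...

To rerun the module locally, something like "python -m pytest test/unit/test_backbone.py" should work, assuming the repository layout implied by the test.unit.test_backbone path.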