Passed
Pull Request — main (#656)
by Yunguan
created 02:35

test.unit.test_backbone.TestUNet.test_init()   A

Complexity
  Conditions: 2

Size
  Total Lines: 50
  Code Lines: 29

Duplication
  Lines: 0
  Ratio: 0%

Importance
  Changes: 0

Metric   Value
eloc     29
dl       0
loc      50
rs       9.184
c        0
b        0
f        0
cc       2
nop      4
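
For a rough local cross-check of these size and complexity figures, the radon package computes comparable raw-line and cyclomatic-complexity metrics. The sketch below is only an illustration: the file path and the assumption that radon's metric definitions match this report's are unverified.

# Rough local cross-check of the size/complexity metrics above using radon.
# Assumptions: radon is installed, the file lives at test/unit/test_backbone.py,
# and radon's definitions approximate (but may not equal) this report's.
from radon.complexity import cc_visit
from radon.raw import analyze

with open("test/unit/test_backbone.py") as f:
    source = f.read()

raw = analyze(source)  # namedtuple with loc, lloc, sloc, comments, blank, ...
print("loc:", raw.loc, "sloc:", raw.sloc, "comments:", raw.comments)

# per-function / per-method cyclomatic complexity
for block in cc_visit(source):
    print(block.name, block.complexity)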
# coding=utf-8

"""
Tests for deepreg/model/backbone
"""
from test.unit.util import is_equal_tf

import numpy as np
import pytest
import tensorflow as tf

import deepreg.model.backbone as backbone
import deepreg.model.backbone.global_net as g
import deepreg.model.backbone.local_net as loc
import deepreg.model.backbone.u_net as u
import deepreg.model.layer as layer


def test_backbone_interface():
    """Test the get_config method of the Backbone interface."""
    config = dict(
        image_size=(5, 5, 5),
        out_channels=3,
        num_channel_initial=4,
        out_kernel_initializer="zeros",
        out_activation="relu",
        name="test",
    )
    model = backbone.Backbone(**config)
    got = model.get_config()
    assert got == config


def test_init_global_net():
    """
    Test that GlobalNet is initialised and built as expected.
    """
    # initialising GlobalNet instance
    global_test = g.GlobalNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )

    # assert the initialised extract_levels is stored unchanged
    assert global_test._extract_levels == [1, 2, 3]
Coding Style Best Practice: it seems like _extract_levels was declared protected and should not be accessed from this context. (The same warning applies to the other protected-member accesses in this file, e.g. _extract_max_level, _extract_min_level, _conv3d_block, _downsample_convs and _extract_layers.)

Prefixing a member variable with _ is usually regarded as the equivalent of declaring it with the protected visibility that exists in other languages. Consequently, such a member should only be accessed from the same class or a child class:

class MyParent:
    def __init__(self):
        self._x = 1
        self.y = 2

class MyChild(MyParent):
    def some_method(self):
        return self._x    # OK, since accessed from a child class

class AnotherClass:
    def some_method(self, instance_of_my_child):
        return instance_of_my_child._x   # flagged, since AnotherClass is not
                                         # a child class of MyParent
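
If these warnings are worth addressing, one option is to assert against the public get_config() output instead of the protected attributes. The snippet below is only a sketch and assumes, which this report does not confirm, that GlobalNet.get_config() reports extract_levels among its arguments.

# Hypothetical rework of the assertion above that avoids protected members.
# Assumption: GlobalNet.get_config() includes "extract_levels" in its config.
import deepreg.model.backbone.global_net as g

net = g.GlobalNet(
    image_size=[1, 2, 3],
    out_channels=3,
    num_channel_initial=3,
    extract_levels=[1, 2, 3],
    out_kernel_initializer="softmax",
    out_activation="softmax",
)
assert net.get_config()["extract_levels"] == [1, 2, 3]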
    # assert extract_max_level is the maximum of extract_levels
    assert global_test._extract_max_level == 3

    # reference grid
    # assert global_test.reference_grid has the correct shape
    assert global_test.reference_grid.shape == [1, 2, 3, 3]
    # assert the correct reference grid is returned
    expected_ref_grid = tf.convert_to_tensor(
        [
            [
                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
                [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 2.0]],
            ]
        ],
        dtype=tf.float32,
    )
    assert is_equal_tf(global_test.reference_grid, expected_ref_grid)

    # assert the correct initial transform is returned
    expected_transform_initial = tf.convert_to_tensor(
        [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        dtype=tf.float32,
    )
    global_transform_initial = tf.Variable(global_test.transform_initial(shape=[12]))
    assert is_equal_tf(global_transform_initial, expected_transform_initial)

    # assert the conv3d block has the correct type
    assert isinstance(global_test._conv3d_block, layer.Conv3dBlock)


def test_call_global_net():
    """
    Assert that the output shapes of GlobalNet's call method are correct.
    """
    out = 3
    im_size = (1, 2, 3)
    batch_size = 5
    # initialising GlobalNet instance
    global_test = g.GlobalNet(
        image_size=im_size,
        out_channels=out,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )
    # pass an input of all zeros
    inputs = tf.constant(
        np.zeros(
            (batch_size, im_size[0], im_size[1], im_size[2], out), dtype=np.float32
        )
    )
    # get outputs by calling
    ddf, theta = global_test.call(inputs)
    assert ddf.shape == (batch_size, *im_size, 3)
    assert theta.shape == (batch_size, 4, 3)


class TestLocalNet:
    """
    Test the backbone.local_net.LocalNet class.
    """

    @pytest.mark.parametrize(
        "image_size,extract_levels,control_points",
        [((10, 20, 30), [1, 2, 3], None), ((8, 8, 8), [1, 2, 3], (2, 2, 2))],
    )
    def test_init(self, image_size, extract_levels, control_points):
        """Test that LocalNet is initialised as expected."""
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=3,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            control_points=control_points,
        )

        # assert the initialised extract_levels is stored unchanged
        assert network._extract_levels == extract_levels
        # assert extract_max_level is the maximum of extract_levels
        assert network._extract_max_level == max(extract_levels)
        # assert extract_min_level is the minimum of extract_levels
        assert network._extract_min_level == min(extract_levels)

        # assert the number of downsample blocks is correct (== max level)
        assert len(network._downsample_convs) == max(extract_levels)

        # assert the number of extraction layers is correct (== number of extract_levels)
        assert len(network._extract_layers) == len(extract_levels)

        if control_points is None:
            assert network.resize is False
        else:
            assert isinstance(network.resize, layer.ResizeCPTransform)
            assert isinstance(network.interpolate, layer.BSplines3DTransform)

    @pytest.mark.parametrize(
        "image_size,extract_levels,control_points",
        [((64, 65, 66), [1, 2, 3, 4], None), ((8, 8, 8), [1, 2, 3], (2, 2, 2))],
    )
    def test_call(self, image_size, extract_levels, control_points):
        """Assert that the output shape of LocalNet's call method matches the input shape."""
        # initialising LocalNet instance
        network = loc.LocalNet(
            image_size=image_size,
            out_channels=3,
            num_channel_initial=3,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            control_points=control_points,
        )

        # pass an input of all zeros
        inputs = tf.constant(
            np.zeros(
                (5, image_size[0], image_size[1], image_size[2], 3), dtype=np.float32
            )
        )
        # get outputs by calling
        output = network.call(inputs)
        # output shape should match the input shape
        assert all(x == y for x, y in zip(inputs.shape, output.shape))


class TestUNet:
    """
    Test the backbone.u_net.UNet class.
    """

    @pytest.mark.parametrize(
        "image_size,depth,control_points",
        [((32, 33, 34), 5, None), ((8, 8, 8), 3, (2, 2, 2))],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call_unet(
        self,
        image_size: tuple,
        depth: int,
        control_points: tuple,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        :param image_size: (dim1, dim2, dim3), dims of the input image.
        :param depth: input is at level 0, bottom is at level depth.
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d.
        :param concat_skip: if true, concatenate the skip tensor, otherwise add it.
        :param control_points: the distance between control points (in voxels).
        """
        out_ch = 3
        network = u.UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
            control_points=control_points,
        )
        inputs = tf.ones(shape=(5, image_size[0], image_size[1], image_size[2], out_ch))
        output = network.call(inputs)
        assert all(x == y for x, y in zip(inputs.shape, output.shape))
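
For reference, the tests above can be run on their own. This is a minimal sketch, assuming pytest and DeepReg's test dependencies are installed, the working directory is the repository root, and the file lives at test/unit/test_backbone.py.

# Run only this module's tests with verbose output (the path is an assumption).
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["test/unit/test_backbone.py", "-v"]))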
219