1 | # coding=utf-8 |
||
2 | |||
3 | """ |
||
4 | Tests for deepreg/model/backbone/local_net.py |
||
5 | """ |
||
6 | from typing import Tuple |
||
7 | |||
8 | import pytest |
||
9 | import tensorflow as tf |
||
10 | |||
11 | from deepreg.model.backbone.local_net import AdditiveUpsampling, LocalNet |
||
12 | |||
13 | |||
def test_additive_up_sampling():
    """
    Test AdditiveUpsampling: output shape and config round trip.

    The layer is fed ``2 * filters`` channels on a (4, 5, 6) grid. The
    result must carry ``filters`` channels on a grid twice as large along
    every spatial axis. ``get_config`` must return the constructor
    arguments plus ``trainable=True`` and ``dtype="float32"``.
    """
    n_samples, n_filters = 3, 4
    spatial_in = (4, 5, 6)
    spatial_out = tuple(2 * dim for dim in spatial_in)

    layer_kwargs = {
        "filters": n_filters,
        "output_padding": (1, 1, 1),
        "kernel_size": 3,
        "padding": "same",
        "strides": 2,
        "output_shape": spatial_out,
        "name": "TestAdditiveUpsampling",
    }
    upsampler = AdditiveUpsampling(**layer_kwargs)

    # Twice as many input channels as the number of output filters.
    features = tf.ones(shape=(n_samples, *spatial_in, 2 * n_filters))
    upsampled = upsampler.call(features)
    assert upsampled.shape == (n_samples, *spatial_out, n_filters)

    expected_config = dict(trainable=True, dtype="float32")
    expected_config.update(layer_kwargs)
    assert upsampler.get_config() == expected_config
||
38 | |||
39 | |||
class TestLocalNet:
    """
    Test the backbone.local_net.LocalNet class
    """

    @pytest.mark.parametrize(
        "image_size,extract_levels,depth",
        [((11, 12, 13), (0, 1, 2, 4), 4), ((8, 8, 8), (0, 1, 2), 3)],
    )
    @pytest.mark.parametrize("use_additive_upsampling", [True, False])
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call(
        self,
        image_size: tuple,
        extract_levels: Tuple[int, ...],
        depth: int,
        use_additive_upsampling: bool,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Check that LocalNet.call keeps the spatial size and emits out_channels.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: from which depths the output will be built.
        :param depth: input is at level 0, bottom is at level depth
        :param use_additive_upsampling: whether use additive up-sampling layer
            for decoding.
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        batch_size = 5
        out_ch = 3
        network = LocalNet(
            image_size=image_size,
            num_channel_initial=2,
            extract_levels=extract_levels,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            out_channels=out_ch,
            use_additive_upsampling=use_additive_upsampling,
            pooling=pooling,
            concat_skip=concat_skip,
        )
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        output = network.call(inputs)
        # Pin the expected output shape explicitly rather than comparing with
        # inputs.shape. The input happens to also use out_ch channels, so a
        # shape comparison would not tell input and output channels apart.
        assert output.shape == (batch_size, *image_size, out_ch)

    def test_get_config(self):
        """
        Check that LocalNet.get_config returns exactly the constructor arguments.
        """
        config = dict(
            image_size=(4, 5, 6),
            out_channels=3,
            num_channel_initial=2,
            depth=2,
            extract_levels=(0, 1),
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=False,
            concat_skip=False,
            use_additive_upsampling=True,
            encode_kernel_sizes=[7, 3, 3],
            decode_kernel_sizes=3,
            encode_num_channels=(2, 4, 8),
            decode_num_channels=(2, 4, 8),
            strides=2,
            padding="same",
            name="Test",
        )
        network = LocalNet(**config)
        got = network.get_config()
        assert got == config
||
112 |