1 | # coding=utf-8 |
||
2 | |||
3 | """ |
||
4 | Tests for deepreg/model/backbone/u_net.py |
||
5 | """ |
||
6 | from typing import Tuple |
||
7 | |||
8 | import pytest |
||
9 | import tensorflow as tf |
||
10 | |||
11 | from deepreg.model.backbone.u_net import UNet |
||
12 | |||
13 | |||
class TestUNet:
    """
    Test the backbone.u_net.UNet class
    """

    @pytest.mark.parametrize(
        "depth,encode_num_channels,decode_num_channels",
        [
            (2, (4, 8, 16), (4, 8, 16)),
            (2, (4, 8, 8), (4, 8, 8)),
            (2, (4, 8, 8), (8, 8, 8)),
        ],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_channels(
        self,
        depth: int,
        encode_num_channels: Tuple[int, ...],
        decode_num_channels: Tuple[int, ...],
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Test unet with custom encode/decode channels.

        When skip tensors are added (``concat_skip=False``) and the encode
        and decode channels differ, constructing the network must raise a
        ValueError. Otherwise the network must build, and calling it must
        preserve the input shape.

        :param depth: input is at level 0, bottom is at level depth
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        image_size = (5, 6, 7)
        out_ch = 3
        batch_size = 5
        # NOTE(review): num_channel_initial=0 presumably has no effect because
        # explicit encode/decode channels are given - confirm against UNet
        unet_kwargs = dict(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=0,
            encode_num_channels=encode_num_channels,
            decode_num_channels=decode_num_channels,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )

        # in case of adding skip tensors, the channels should match,
        # so a mismatch must be rejected when the network is constructed
        expect_err = (not concat_skip) and encode_num_channels != decode_num_channels
        if expect_err:
            # assert the failure explicitly instead of silently passing
            # when no error is raised
            with pytest.raises(ValueError):
                UNet(**unet_kwargs)
            return

        network = UNet(**unet_kwargs)
        # inputs use out_ch channels so input and output shapes are
        # directly comparable
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        output = network.call(inputs)
        assert inputs.shape == output.shape

    @pytest.mark.parametrize(
        "image_size,depth",
        [((11, 12, 13), 5), ((8, 8, 8), 3)],
    )
    @pytest.mark.parametrize("pooling", [True, False])
    @pytest.mark.parametrize("concat_skip", [True, False])
    def test_call(
        self,
        image_size: Tuple[int, ...],
        depth: int,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Test unet call function.

        The output of ``network.call`` must have the same shape as the input.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param depth: input is at level 0, bottom is at level depth
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        out_ch = 3
        batch_size = 5
        network = UNet(
            image_size=image_size,
            out_channels=out_ch,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=pooling,
            concat_skip=concat_skip,
        )
        # inputs use out_ch channels so input and output shapes are
        # directly comparable
        inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
        output = network.call(inputs)
        assert inputs.shape == output.shape

    def test_get_config(self):
        """
        Test that get_config returns exactly the arguments passed to UNet.
        """
        config = dict(
            image_size=(4, 5, 6),
            out_channels=3,
            num_channel_initial=2,
            depth=2,
            extract_levels=(0, 1),
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            pooling=False,
            concat_skip=False,
            encode_kernel_sizes=3,
            decode_kernel_sizes=3,
            encode_num_channels=(2, 4, 8),
            decode_num_channels=(2, 4, 8),
            strides=2,
            padding="same",
            name="Test",
        )
        network = UNet(**config)
        got = network.get_config()
        assert got == config
||
135 |