@@ 89-109 (lines=21) @@ | ||
86 | output = network.call(inputs) |
|
87 | assert inputs.shape == output.shape |
|
88 | ||
def test_get_config(self):
    """Check that LocalNet.get_config() returns the constructor arguments.

    A LocalNet is built from a fixed keyword configuration, and the dict
    returned by its get_config() must compare equal to that configuration.
    """
    expected = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (0, 1),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "use_additive_upsampling": True,
        "encode_kernel_sizes": [7, 3, 3],
        "decode_kernel_sizes": 3,
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    # The same dict is used both to build the network and as the reference.
    local_net = LocalNet(**expected)
    assert local_net.get_config() == expected
|
110 |
@@ 73-93 (lines=21) @@ | ||
70 | assert ddf.shape == inputs.shape |
|
71 | assert theta.shape == (batch_size, 4, 3) |
|
72 | ||
def test_get_config(self):
    """Check that GlobalNet.get_config() returns the constructor arguments.

    A GlobalNet is built from a fixed keyword configuration, and the dict
    returned by its get_config() must compare equal to that configuration.
    """
    expected = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (2,),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "use_additive_upsampling": True,
        "encode_kernel_sizes": [7, 3, 3],
        "decode_kernel_sizes": 3,
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    # The same dict is used both to build the network and as the reference.
    global_net = GlobalNet(**expected)
    assert global_net.get_config() == expected
|
94 |
@@ 53-72 (lines=20) @@ | ||
50 | output = network.call(inputs) |
|
51 | assert inputs.shape == output.shape |
|
52 | ||
def test_get_config(self):
    """Check that UNet.get_config() returns the constructor arguments.

    A UNet is built from a fixed keyword configuration, and the dict
    returned by its get_config() must compare equal to that configuration.
    """
    expected = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (0, 1),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "encode_kernel_sizes": 3,
        "decode_kernel_sizes": 3,
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    # The same dict is used both to build the network and as the reference.
    unet = UNet(**expected)
    assert unet.get_config() == expected
|
73 |