| @@ 89-111 (lines=23) @@ | ||
| 86 | output = network.call(inputs) |
|
| 87 | assert inputs.shape == output.shape |
|
| 88 | ||
def test_get_config(self):
    """Check that ``LocalNet.get_config`` round-trips its constructor kwargs.

    The network is built from a fixed keyword mapping. The test asserts that
    ``get_config()`` returns a dict equal to that mapping. Any missing, extra,
    or altered key makes the equality (and the test) fail.
    """
    expected_config = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (0, 1),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "use_additive_upsampling": True,
        # Encoder kernel sizes are given as a per-level list, decoder as a scalar.
        "encode_kernel_sizes": [7, 3, 3],
        "decode_kernel_sizes": 3,
        "encode_num_channels": (2, 4, 8),
        "decode_num_channels": (2, 4, 8),
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    model = LocalNet(**expected_config)
    assert model.get_config() == expected_config
|
| 112 | ||
| @@ 113-134 (lines=22) @@ | ||
| 110 | output = network.call(inputs) |
|
| 111 | assert inputs.shape == output.shape |
|
| 112 | ||
def test_get_config(self):
    """Check that ``UNet.get_config`` round-trips its constructor kwargs.

    The network is built from a fixed keyword mapping. The test asserts that
    ``get_config()`` returns a dict equal to that mapping. Any missing, extra,
    or altered key makes the equality (and the test) fail.
    """
    expected_config = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (0, 1),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "encode_kernel_sizes": 3,
        "decode_kernel_sizes": 3,
        "encode_num_channels": (2, 4, 8),
        "decode_num_channels": (2, 4, 8),
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    model = UNet(**expected_config)
    assert model.get_config() == expected_config
|
| 135 | ||
| @@ 93-114 (lines=22) @@ | ||
| 90 | ) |
|
| 91 | assert "GlobalNet requires `depth` or `extract_levels`" in str(err_info.value) |
|
| 92 | ||
def test_get_config(self):
    """Check that ``GlobalNet.get_config`` round-trips its constructor kwargs.

    The network is built from a fixed keyword mapping. The test asserts that
    ``get_config()`` returns a dict equal to that mapping. Any missing, extra,
    or altered key makes the equality (and the test) fail.
    """
    expected_config = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (2,),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        # Encoder kernel sizes are given as a per-level list, decoder as a scalar.
        "encode_kernel_sizes": [7, 3, 3],
        "decode_kernel_sizes": 3,
        # Channel counts are passed as lists here, so equality also checks the
        # container type that get_config hands back.
        "encode_num_channels": [2, 4, 8],
        "decode_num_channels": [2, 4, 8],
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    model = GlobalNet(**expected_config)
    assert model.get_config() == expected_config
|
| 115 | ||