| @@ 89-109 (lines=21) @@ | ||
| 86 | output = network.call(inputs) |
|
| 87 | assert inputs.shape == output.shape |
|
| 88 | ||
def test_get_config(self):
    """Check that LocalNet.get_config() echoes its constructor arguments.

    A LocalNet is built from a fixed set of keyword arguments and the
    dictionary returned by ``get_config()`` must compare equal to them.
    """
    expected = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (0, 1),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "use_additive_upsampling": True,
        "encode_kernel_sizes": [7, 3, 3],
        "decode_kernel_sizes": 3,
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    model = LocalNet(**expected)
    # The reported config must match the arguments used for construction.
    assert model.get_config() == expected
|
| 110 | ||
| @@ 73-92 (lines=20) @@ | ||
| 70 | assert ddf.shape == inputs.shape |
|
| 71 | assert theta.shape == (batch_size, 4, 3) |
|
| 72 | ||
def test_get_config(self):
    """Check that GlobalNet.get_config() echoes its constructor arguments.

    A GlobalNet is built from a fixed set of keyword arguments and the
    dictionary returned by ``get_config()`` must compare equal to them.
    """
    expected = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (2,),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "encode_kernel_sizes": [7, 3, 3],
        "decode_kernel_sizes": 3,
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    model = GlobalNet(**expected)
    # The reported config must match the arguments used for construction.
    assert model.get_config() == expected
|
| 93 | ||
| @@ 53-72 (lines=20) @@ | ||
| 50 | output = network.call(inputs) |
|
| 51 | assert inputs.shape == output.shape |
|
| 52 | ||
def test_get_config(self):
    """Check that UNet.get_config() echoes its constructor arguments.

    A UNet is built from a fixed set of keyword arguments and the
    dictionary returned by ``get_config()`` must compare equal to them.
    """
    expected = {
        "image_size": (4, 5, 6),
        "out_channels": 3,
        "num_channel_initial": 2,
        "depth": 2,
        "extract_levels": (0, 1),
        "out_kernel_initializer": "he_normal",
        "out_activation": "softmax",
        "pooling": False,
        "concat_skip": False,
        "encode_kernel_sizes": 3,
        "decode_kernel_sizes": 3,
        "strides": 2,
        "padding": "same",
        "name": "Test",
    }
    model = UNet(**expected)
    # The reported config must match the arguments used for construction.
    assert model.get_config() == expected
|
| 73 | ||