| Conditions | 3 |
| Total Lines | 56 |
| Code Lines | 37 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Also, the smaller a method is, the easier it usually is to find a good name for it.
For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method; use the comment as a starting point when choosing a good name for the new method.
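Commonly applied refactorings include:
- Extract Method

If many parameters/temporary variables are present:
- Replace Method with Method Object
- Replace Temp with Query
- Introduce Parameter Object
- Preserve Whole Object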
| 1 | # coding=utf-8 |
||
@pytest.mark.parametrize(
    "depth,encode_num_channels,decode_num_channels",
    [
        (2, (4, 8, 16), (4, 8, 16)),
        (2, (4, 8, 8), (4, 8, 8)),
        (2, (4, 8, 8), (8, 8, 8)),
    ],
)
@pytest.mark.parametrize("pooling", [True, False])
@pytest.mark.parametrize("concat_skip", [True, False])
def test_channels(
    self,
    depth: int,
    encode_num_channels: Tuple,
    decode_num_channels: Tuple,
    pooling: bool,
    concat_skip: bool,
):
    """
    Test unet with custom encode/decode channels.

    When skip tensors are added (``concat_skip=False``) and the encode and
    decode channels differ, constructing the UNet must raise ``ValueError``.
    Every other combination must build a network whose output has the same
    shape as its input.

    :param depth: input is at level 0, bottom is at level depth
    :param encode_num_channels: filters/channels for down-sampling,
        by default it is doubled at each layer during down-sampling
    :param decode_num_channels: filters/channels for up-sampling,
        by default it is the same as encode_num_channels
    :param pooling: for down-sampling, use non-parameterized
        pooling if true, otherwise use conv3d
    :param concat_skip: concatenate skip tensors if true, otherwise add them
    """
    # in case of adding skip tensors, the channels should match
    expect_err = (not concat_skip) and encode_num_channels != decode_num_channels

    image_size = (5, 6, 7)
    out_ch = 3
    unet_kwargs = dict(
        image_size=image_size,
        out_channels=out_ch,
        num_channel_initial=None,
        encode_num_channels=encode_num_channels,
        decode_num_channels=decode_num_channels,
        depth=depth,
        out_kernel_initializer="he_normal",
        out_activation="softmax",
        pooling=pooling,
        concat_skip=concat_skip,
    )

    if expect_err:
        # Require the error explicitly: merely tolerating it would let a
        # constructor that silently accepts mismatched channels pass the test.
        with pytest.raises(ValueError):
            UNet(**unet_kwargs)
        return

    network = UNet(**unet_kwargs)

    # The input channel count is set to out_ch so that the output shape can be
    # compared directly against the input shape.
    batch_size = 5
    inputs = tf.ones(shape=(batch_size, *image_size, out_ch))
    output = network.call(inputs)
    assert inputs.shape == output.shape
| 75 | |||
| 135 |