Conditions | 3 |
Total Lines | 56 |
Code Lines | 37 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, using the comment as a starting point for the new method's name.
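Commonly applied refactorings include: Extract Method.
If many parameters/temporary variables are present: Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.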
1 | # coding=utf-8 |
||
@pytest.mark.parametrize(
    "depth,encode_num_channels,decode_num_channels",
    [
        (2, (4, 8, 16), (4, 8, 16)),
        (2, (4, 8, 8), (4, 8, 8)),
        (2, (4, 8, 8), (8, 8, 8)),
    ],
)
@pytest.mark.parametrize("pooling", [True, False])
@pytest.mark.parametrize("concat_skip", [True, False])
def test_channels(
    self,
    depth: int,
    encode_num_channels: Tuple,
    decode_num_channels: Tuple,
    pooling: bool,
    concat_skip: bool,
):
    """
    Test unet with custom encode/decode channels.

    When skip tensors are added (``concat_skip=False``) and the encode/decode
    channels differ, constructing ``UNet`` is expected to raise ``ValueError``.
    Otherwise the network must build and ``call`` must return a tensor with
    the same shape as its input.

    :param depth: input is at level 0, bottom is at level depth
    :param encode_num_channels: filters/channels for down-sampling,
        by default it is doubled at each layer during down-sampling
    :param decode_num_channels: filters/channels for up-sampling,
        by default it is the same as encode_num_channels
    :param pooling: for down-sampling, use non-parameterized
        pooling if true, otherwise use conv3d
    :param concat_skip: if concatenate skip or add it
    """
    image_size = (5, 6, 7)
    out_ch = 3
    unet_kwargs = dict(
        image_size=image_size,
        out_channels=out_ch,
        num_channel_initial=None,
        encode_num_channels=encode_num_channels,
        decode_num_channels=decode_num_channels,
        depth=depth,
        out_kernel_initializer="he_normal",
        out_activation="softmax",
        pooling=pooling,
        concat_skip=concat_skip,
    )

    # in case of adding skip tensors, the channels should match
    expect_err = (not concat_skip) and encode_num_channels != decode_num_channels
    if expect_err:
        # pytest.raises fails the test if UNet accepts the mismatched
        # configuration instead of rejecting it.
        with pytest.raises(ValueError):
            UNet(**unet_kwargs)
        return

    network = UNet(**unet_kwargs)
    # The input uses out_ch channels, matching out_channels=out_ch, so the
    # output is expected to have exactly the input's shape.
    inputs = tf.ones(shape=(5, *image_size, out_ch))
    output = network.call(inputs)
    assert inputs.shape == output.shape
135 |