| @@ 111-128 (lines=18) @@ | ||
| 108 | got = parse_image_loss(loss_config=loss_config) |
|
| 109 | assert got == expected |
|
| 110 | ||
def test_parse_multiple_loss(self):
    """Check that a config holding a list of image losses is returned unchanged.

    Exercises ``parse_image_loss`` (the parser this test class covers) with
    two image losses given as a list, and expects the config to pass through
    as-is.
    """
    loss_config = {
        "image": [
            {
                "name": "lncc",
                "weight": 0.5,
                "kernel_size": 9,
                "kernel_type": "rectangular",
            },
            {
                "name": "ssd",
                "weight": 0.5,
            },
        ],
    }

    # Must call the image-loss parser; calling parse_reg_loss here would not
    # test the image-loss handling at all.
    got = parse_image_loss(loss_config=loss_config)
    assert got == loss_config
|
| 129 | ||
| 130 | ||
| 131 | class TestParseLabelLoss: |
|
| @@ 176-191 (lines=16) @@ | ||
| 173 | got = parse_label_loss(loss_config=loss_config) |
|
| 174 | assert got == expected_config |
|
| 175 | ||
def test_parse_multiple_loss(self):
    """Check that a config holding a list of label losses is returned unchanged.

    Exercises ``parse_label_loss`` (the parser this test class covers) with
    two label losses given as a list, and expects the config to pass through
    as-is.
    """
    loss_config = {
        "label": [
            {
                "name": "dice",
                "weight": 1.0,
            },
            {
                "name": "cross-entropy",
                "weight": 1.0,
            },
        ],
    }

    # Must call the label-loss parser; calling parse_reg_loss here would not
    # test the label-loss handling at all.
    got = parse_label_loss(loss_config=loss_config)
    assert got == loss_config
|
| 192 | ||
| 193 | ||
| 194 | class TestParseRegularizationLoss: |
|