@@ 111-128 (lines=18) @@ | ||
108 | got = parse_image_loss(loss_config=loss_config) |
|
109 | assert got == expected |
|
110 | ||
def test_parse_multiple_loss(self):
    """A list of image losses passed to parse_image_loss should equal the input config."""
    loss_config = {
        "image": [
            {
                "name": "lncc",
                "weight": 0.5,
                "kernel_size": 9,
                "kernel_type": "rectangular",
            },
            {
                "name": "ssd",
                "weight": 0.5,
            },
        ],
    }

    # Exercise parse_image_loss (not parse_reg_loss): this test targets the
    # image-loss parser's handling of a list of image losses.
    got = parse_image_loss(loss_config=loss_config)
    assert got == loss_config
|
129 | ||
130 | ||
131 | class TestParseLabelLoss: |
|
@@ 176-191 (lines=16) @@ | ||
173 | got = parse_label_loss(loss_config=loss_config) |
|
174 | assert got == expected_config |
|
175 | ||
def test_parse_multiple_loss(self):
    """A list of label losses passed to parse_label_loss should equal the input config."""
    loss_config = {
        "label": [
            {
                "name": "dice",
                "weight": 1.0,
            },
            {
                "name": "cross-entropy",
                "weight": 1.0,
            },
        ],
    }

    # Exercise parse_label_loss (not parse_reg_loss): this test targets the
    # label-loss parser's handling of a list of label losses.
    got = parse_label_loss(loss_config=loss_config)
    assert got == loss_config
|
192 | ||
193 | ||
194 | class TestParseRegularizationLoss: |