@@ 231-246 (lines=16) @@ | ||
228 | ) |
|
229 | ] |
|
230 | ||
def test_build_model(self, model, labeled, backbone):
    """Verify the set of outputs registered on ``model._outputs`` and their shapes.

    The test expects ``ddf`` and ``pred_fixed_image`` in every case. A labeled
    model must additionally expose ``pred_fixed_label``. With the ``"global"``
    backbone, a ``theta`` output of shape ``(batch_size, 4, 3)`` is also expected.
    """
    outputs = model._outputs
    uses_global_backbone = backbone == "global"

    n_expected = 2
    if labeled:
        n_expected += 1
    if uses_global_backbone:
        n_expected += 1
        # Only the global backbone registers a "theta" output.
        assert outputs["theta"].shape == (batch_size, 4, 3)
    assert len(outputs) == n_expected

    ddf_out = outputs["ddf"]
    warped_image = outputs["pred_fixed_image"]
    assert ddf_out.shape == (batch_size, *fixed_image_size, 3)
    assert warped_image.shape == (batch_size, *fixed_image_size)

    if not labeled:
        return
    warped_label = outputs["pred_fixed_label"]
    assert warped_label.shape == (batch_size, *fixed_image_size)
|
247 | ||
248 | def test_build_loss(self, model, labeled, backbone): |
|
249 | expected = 3 if labeled else 2 |
|
@@ 271-284 (lines=14) @@ | ||
268 | ) |
|
269 | ] |
|
270 | ||
def test_build_model(self, model, labeled, backbone):
    """Verify the outputs registered on ``model._outputs`` and their shapes.

    The test expects ``dvf``, ``ddf`` and ``pred_fixed_image`` in every case,
    plus ``pred_fixed_label`` when the model is labeled. ``backbone`` does not
    influence the expectations checked here.
    """
    outputs = model._outputs
    assert len(outputs) == (4 if labeled else 3)

    dvf_out, ddf_out, warped_image = (
        outputs[key] for key in ("dvf", "ddf", "pred_fixed_image")
    )
    # Both field outputs share the same trailing vector dimension of size 3.
    for field in (dvf_out, ddf_out):
        assert field.shape == (batch_size, *fixed_image_size, 3)
    assert warped_image.shape == (batch_size, *fixed_image_size)

    if not labeled:
        return
    warped_label = outputs["pred_fixed_label"]
    assert warped_label.shape == (batch_size, *fixed_image_size)
|
285 | ||
286 | def test_build_loss(self, model, labeled, backbone): |
|
287 | expected = 3 if labeled else 2 |