|
@@ 231-246 (lines=16) @@
|
| 228 |
|
) |
| 229 |
|
] |
| 230 |
|
|
| 231 |
|
def test_build_model(self, model, labeled, backbone):
    """Check the entries and tensor shapes stored in ``model._outputs``.

    The dictionary must always hold ``ddf`` and ``pred_fixed_image``.
    A labeled model additionally holds ``pred_fixed_label``. The
    ``"global"`` backbone additionally holds a ``theta`` tensor of shape
    ``(batch_size, 4, 3)``. No other entries may be present.
    """
    outputs = model._outputs

    n_expected = 2
    if labeled:
        n_expected += 1

    if backbone == "global":
        n_expected += 1
        # Only the global backbone exposes a per-sample 4x3 ``theta`` tensor.
        assert outputs["theta"].shape == (batch_size, 4, 3)
    assert len(outputs) == n_expected

    displacement = outputs["ddf"]
    warped_image = outputs["pred_fixed_image"]
    volume_shape = (batch_size, *fixed_image_size)
    # The dense displacement field carries a trailing 3-component axis.
    assert displacement.shape == (*volume_shape, 3)
    assert warped_image.shape == volume_shape

    if not labeled:
        return
    warped_label = outputs["pred_fixed_label"]
    assert warped_label.shape == volume_shape
| 247 |
|
|
| 248 |
|
def test_build_loss(self, model, labeled, backbone): |
| 249 |
|
expected = 3 if labeled else 2 |
|
@@ 271-284 (lines=14) @@
|
| 268 |
|
) |
| 269 |
|
] |
| 270 |
|
|
| 271 |
|
def test_build_model(self, model, labeled, backbone):
    """Check the entries and tensor shapes stored in ``model._outputs``.

    The dictionary must always hold ``dvf``, ``ddf`` and
    ``pred_fixed_image``. A labeled model additionally holds
    ``pred_fixed_label``. ``backbone`` is accepted but not inspected here.
    """
    outputs = model._outputs
    n_expected = 3 + (1 if labeled else 0)
    assert len(outputs) == n_expected

    dvf, ddf, warped_image = (
        outputs[key] for key in ("dvf", "ddf", "pred_fixed_image")
    )
    volume_shape = (batch_size, *fixed_image_size)
    # Both vector fields share the image grid plus a trailing 3-component axis.
    for vector_field in (dvf, ddf):
        assert vector_field.shape == (*volume_shape, 3)
    assert warped_image.shape == volume_shape

    if labeled:
        assert outputs["pred_fixed_label"].shape == volume_shape
| 285 |
|
|
| 286 |
|
def test_build_loss(self, model, labeled, backbone): |
| 287 |
|
expected = 3 if labeled else 2 |