|
@@ 218-233 (lines=16) @@
|
| 215 |
|
) |
| 216 |
|
] |
| 217 |
|
|
| 218 |
|
def test_build_model(self, model, labeled, backbone):
    """Check how many outputs the built model has and the shape of each.

    Args:
        model: object under test; only its ``_outputs`` mapping is read.
        labeled: when truthy, a ``pred_fixed_label`` output is expected too.
        backbone: backbone name; for ``"global"`` an additional ``theta``
            output of shape ``(batch_size, 4, 3)`` is expected.
    """
    # NOTE(review): batch_size and fixed_image_size are not defined in this
    # method; presumably module-level test constants -- confirm.
    outputs = model._outputs

    # "ddf" and "pred_fixed_image" are always expected.
    n_expected = 2
    if labeled:
        n_expected += 1
    if backbone == "global":
        # The global backbone is expected to contribute one more entry.
        n_expected += 1
        assert outputs["theta"].shape == (batch_size, 4, 3)
    assert len(outputs) == n_expected

    field, warped_image = outputs["ddf"], outputs["pred_fixed_image"]
    image_shape = (batch_size, *fixed_image_size)
    # The ddf has the image shape plus a trailing axis of size 3.
    assert field.shape == (*image_shape, 3)
    assert warped_image.shape == image_shape

    if not labeled:
        return
    assert outputs["pred_fixed_label"].shape == image_shape
| 234 |
|
|
| 235 |
|
def test_build_loss(self, model, labeled, backbone): |
| 236 |
|
expected = 3 if labeled else 2 |
|
@@ 258-271 (lines=14) @@
|
| 255 |
|
) |
| 256 |
|
] |
| 257 |
|
|
| 258 |
|
def test_build_model(self, model, labeled, backbone):
    """Check how many outputs the built model has and the shape of each.

    Args:
        model: object under test; only its ``_outputs`` mapping is read.
        labeled: when truthy, a ``pred_fixed_label`` output is expected too.
        backbone: accepted as part of the test signature; not read here.
    """
    # NOTE(review): batch_size and fixed_image_size are not defined in this
    # method; presumably module-level test constants -- confirm.
    outputs = model._outputs

    # "dvf", "ddf" and "pred_fixed_image" are always expected.
    n_expected = 3
    if labeled:
        n_expected += 1
    assert len(outputs) == n_expected

    # Fetch every entry up front so a missing key surfaces before any
    # shape comparison runs.
    velocity, displacement, warped_image = (
        outputs[key] for key in ("dvf", "ddf", "pred_fixed_image")
    )
    image_shape = (batch_size, *fixed_image_size)
    # Both fields have the image shape plus a trailing axis of size 3.
    field_shape = (*image_shape, 3)
    assert velocity.shape == field_shape
    assert displacement.shape == field_shape
    assert warped_image.shape == image_shape

    if not labeled:
        return
    assert outputs["pred_fixed_label"].shape == image_shape
| 272 |
|
|
| 273 |
|
def test_build_loss(self, model, labeled, backbone): |
| 274 |
|
expected = 3 if labeled else 2 |