@@ 218-233 (lines=16) @@ | ||
215 | ) |
|
216 | ] |
|
217 | ||
def test_build_model(self, model, labeled, backbone):
    """Check the number and shapes of the outputs of the built model.

    The model must always expose ``ddf`` and ``pred_fixed_image``. When
    ``labeled`` is true it must also expose ``pred_fixed_label``. When the
    backbone is ``"global"`` it must also expose ``theta``. No other outputs
    are allowed.
    """
    outputs = model._outputs
    image_shape = (batch_size, *fixed_image_size)

    # Expected output name -> expected shape; optional entries added below.
    expected_shapes = {
        "ddf": (*image_shape, 3),
        "pred_fixed_image": image_shape,
    }
    if labeled:
        expected_shapes["pred_fixed_label"] = image_shape
    if backbone == "global":
        # Only the "global" backbone produces theta, so it is looked up
        # (and shape-checked) for that backbone only.
        expected_shapes["theta"] = (batch_size, 4, 3)

    # Every expected name is looked up below, so matching the count also
    # rules out any unexpected extra outputs.
    assert len(outputs) == len(expected_shapes)
    for output_name, expected_shape in expected_shapes.items():
        assert outputs[output_name].shape == expected_shape
|
234 | ||
235 | def test_build_loss(self, model, labeled, backbone): |
|
236 | expected = 3 if labeled else 2 |
|
@@ 258-271 (lines=14) @@ | ||
255 | ) |
|
256 | ] |
|
257 | ||
def test_build_model(self, model, labeled, backbone):
    """Check the number and shapes of the outputs of the built model.

    The model must always expose ``dvf``, ``ddf`` and ``pred_fixed_image``.
    When ``labeled`` is true it must also expose ``pred_fixed_label``. No
    other outputs are allowed. ``backbone`` is accepted (presumably supplied
    by test parametrization) but does not affect these expectations.
    """
    outputs = model._outputs
    image_shape = (batch_size, *fixed_image_size)
    field_shape = (*image_shape, 3)

    # Expected output name -> expected shape; optional entry added below.
    expected_shapes = {
        "dvf": field_shape,
        "ddf": field_shape,
        "pred_fixed_image": image_shape,
    }
    if labeled:
        expected_shapes["pred_fixed_label"] = image_shape

    # Every expected name is looked up below, so matching the count also
    # rules out any unexpected extra outputs.
    assert len(outputs) == len(expected_shapes)
    for output_name, expected_shape in expected_shapes.items():
        assert outputs[output_name].shape == expected_shape
|
272 | ||
273 | def test_build_loss(self, model, labeled, backbone): |
|
274 | expected = 3 if labeled else 2 |