Conditions | 12 |
Total Lines | 243 |
Code Lines | 172 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include: Extract Method.
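If many parameters/temporary variables are present: Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.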
Complex classes and functions like test.unit.test_interface.test_generator_data_loader() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # coding=utf-8 |
||
def test_generator_data_loader(caplog):
    """
    Test the functions in GeneratorDataLoader.

    Each functionality is checked by a dedicated private helper below so that
    every section stays short and independently readable.

    :param caplog: pytest fixture, used to check warning messages.
    """
    generator = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")

    # test properties: a freshly constructed loader has no sub-loader attached
    assert generator.loader_moving_image is None
    assert generator.loader_fixed_image is None
    assert generator.loader_moving_label is None
    assert generator.loader_fixed_label is None

    # not implemented properties / functions
    with pytest.raises(NotImplementedError):
        generator.sample_index_generator()

    dummy_array = np.random.random(size=(100, 100, 100)).astype(np.float32)

    # implemented functions
    _check_get_dataset(dummy_array, labeled=True)
    _check_get_dataset(dummy_array, labeled=False)

    generator = _check_data_generator(dummy_array)
    _check_validate_images_and_labels(generator, dummy_array, caplog)
    _check_sample_image_label(generator, dummy_array)


def _assert_arrays_match(got, expected):
    """
    Assert that every entry of ``expected`` equals the same-key entry of ``got``.

    Keys present only in ``got`` are not checked.

    :param got: dict produced by the loader under test.
    :param expected: dict of expected values, compared with is_equal_np.
    """
    for key, value in expected.items():
        # the key is the assertion message so a failure names the bad field
        assert is_equal_np(got[key], value), key


def _make_mock_sample(dummy_array, labeled):
    """
    Build one sample dict as yielded by a mocked data generator.

    :param dummy_array: array used for every image (and label, if labeled).
    :param labeled: whether to include moving_label and fixed_label.
    :return: dict with image (and optionally label) arrays and indices=[1].
    """
    sample = dict(moving_image=dummy_array, fixed_image=dummy_array)
    if labeled:
        sample.update(moving_label=dummy_array, fixed_label=dummy_array)
    sample["indices"] = [1]
    return sample


def _check_get_dataset(dummy_array, labeled, num_samples=3):
    """
    Check GeneratorDataLoader.get_dataset against a mocked data_generator.

    :param dummy_array: array used for every image/label in the mocked samples.
    :param labeled: whether the loader and the mocked samples are labeled.
    :param num_samples: number of samples yielded by the mocked generator.
    """
    generator = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
    sequence = [_make_mock_sample(dummy_array, labeled) for _ in range(num_samples)]

    # closes over this call's own `sequence`, which is never reassigned
    def mock_generator():
        yield from sequence

    generator.data_generator = mock_generator
    # inputs, no error means passed
    dataset = generator.get_dataset()

    outputs = list(dataset.as_numpy_iterator())
    # guard against a vacuous pass if the dataset yielded nothing
    assert len(outputs) == num_samples
    expected = _make_mock_sample(dummy_array, labeled)
    for got in outputs:
        _assert_arrays_match(got, expected)


def _check_data_generator(dummy_array):
    """
    Check GeneratorDataLoader.data_generator with mocked sub-loaders.

    :param dummy_array: array returned by every mocked sub-loader.
    :return: the labeled GeneratorDataLoader (sample_label="all") used here.
    """

    class MockDataLoader:
        """Stand-in sub-loader whose get_data ignores the index."""

        @staticmethod
        def get_data(index):
            return dummy_array

    def mock_sample_index_generator():
        # indices of a single sample; the triple's meaning is defined by
        # data_generator (presumably moving, fixed, image indices — confirm)
        return [[[1], [1], [1]]]

    generator = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
    generator.sample_index_generator = mock_sample_index_generator
    # the class itself is attached; get_data is a staticmethod so no instance
    # is needed
    generator.loader_moving_image = MockDataLoader
    generator.loader_fixed_image = MockDataLoader
    generator.loader_moving_label = MockDataLoader
    generator.loader_fixed_label = MockDataLoader

    got = next(generator.data_generator())
    expected = dict(
        moving_image=normalize_array(dummy_array),
        fixed_image=normalize_array(dummy_array),
        moving_label=dummy_array,
        fixed_label=dummy_array,
        indices=np.asarray([1] + [0], dtype=np.float32),
    )
    _assert_arrays_match(got, expected)
    return generator


def _check_validate_images_and_labels(generator, dummy_array, caplog):
    """
    Check the errors and warnings of validate_images_and_labels.

    :param generator: loader whose validate_images_and_labels is tested.
    :param dummy_array: expected to be a valid 3D array with values in [0, 1).
    :param caplog: pytest fixture, used to check warning messages.
    """
    # (keyword arguments, expected substring of the ValueError message)
    error_cases = [
        (
            dict(
                fixed_image=None,
                moving_image=dummy_array,
                moving_label=None,
                fixed_label=None,
            ),
            "moving image and fixed image must not be None",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array,
                moving_label=dummy_array,
                fixed_label=None,
            ),
            "moving label and fixed label must be both None or non-None",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array + 1.0,  # values outside [0, 1]
                moving_label=None,
                fixed_label=None,
            ),
            "Sample [1]'s moving_image's values are not between [0, 1]",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=np.random.random(size=(100, 100)),  # 2D image
                moving_label=None,
                fixed_label=None,
            ),
            "Sample [1]'s moving_image' shape should be 3D. ",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array,
                moving_label=np.random.random(size=(100, 100)),  # 2D label
                fixed_label=dummy_array,
            ),
            "Sample [1]'s moving_label' shape should be 3D or 4D. ",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array,
                moving_label=np.random.random(size=(100, 100, 100, 3)),
                fixed_label=np.random.random(size=(100, 100, 100, 4)),
            ),
            "Sample [1]'s moving image and fixed image have "
            "different numbers of labels.",
        ),
    ]
    for kwargs, message in error_cases:
        with pytest.raises(ValueError) as err_info:
            generator.validate_images_and_labels(image_indices=[1], **kwargs)
        assert message in str(err_info.value)

    # (keyword arguments, expected substring of the logged warning)
    warning_cases = [
        (
            dict(
                moving_label=np.random.random(size=(100, 100, 90)),
                fixed_label=dummy_array,
            ),
            "Sample [1]'s moving image and label have different shapes. ",
        ),
        (
            dict(
                moving_label=dummy_array,
                fixed_label=np.random.random(size=(100, 100, 90)),
            ),
            "Sample [1]'s fixed image and label have different shapes. ",
        ),
    ]
    for kwargs, message in warning_cases:
        caplog.clear()  # only inspect the log of this call
        generator.validate_images_and_labels(
            fixed_image=dummy_array,
            moving_image=dummy_array,
            image_indices=[1],
            **kwargs,
        )
        assert message in caplog.text


def _check_sample_image_label(generator, dummy_array):
    """
    Check sample_image_label for unlabeled, single-label and multi-label data.

    :param generator: labeled loader built with sample_label="all".
    :param dummy_array: 3D array used as images and as a single label.
    """
    # unlabeled input data: expected label index is -1
    got = next(
        generator.sample_image_label(
            fixed_image=dummy_array,
            moving_image=dummy_array,
            moving_label=None,
            fixed_label=None,
            image_indices=[1],
        )
    )
    _assert_arrays_match(
        got,
        dict(
            moving_image=dummy_array,
            fixed_image=dummy_array,
            indices=np.asarray([1] + [-1], dtype=np.float32),
        ),
    )

    # data with one label: expected label index is 0
    got = next(
        generator.sample_image_label(
            fixed_image=dummy_array,
            moving_image=dummy_array,
            moving_label=dummy_array,
            fixed_label=dummy_array,
            image_indices=[1],
        )
    )
    _assert_arrays_match(
        got,
        dict(
            moving_image=dummy_array,
            fixed_image=dummy_array,
            moving_label=dummy_array,
            fixed_label=dummy_array,
            indices=np.asarray([1] + [0], dtype=np.float32),
        ),
    )

    # data with multiple labels stacked on the last axis: one sample is
    # expected per label, in label order
    dummy_labels = np.random.random(size=(100, 100, 100, 3))
    samples = generator.sample_image_label(
        fixed_image=dummy_array,
        moving_image=dummy_array,
        moving_label=dummy_labels,
        fixed_label=dummy_labels,
        image_indices=[1],
    )
    for label_index in range(dummy_labels.shape[3]):
        got = next(samples)
        expected = dict(
            moving_image=dummy_array,
            fixed_image=dummy_array,
            moving_label=dummy_labels[..., label_index],
            fixed_label=dummy_labels[..., label_index],
            indices=np.asarray([1] + [label_index], dtype=np.float32),
        )
        _assert_arrays_match(got, expected)
515 |