Conditions | 12 |
Total Lines | 243 |
Code Lines | 172 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
- Extract Method
- If many parameters/temporary variables are present:
  - Replace Method with Method Object
  - Introduce Parameter Object
Complex classes like test.unit.test_interface.test_generator_data_loader() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # coding=utf-8 |
||
217 | |||
def get_arr(shape: Tuple = (2, 3, 4), seed: Optional[int] = None) -> np.ndarray:
    """
    Return a random array.

    Values are drawn uniformly from [0, 1) by a local ``np.random.RandomState``,
    so numpy's global random state is left untouched. For a given seed the
    values are the same as after ``np.random.seed(seed)`` followed by
    ``np.random.random(size=shape)``.

    :param shape: shape of array.
    :param seed: random seed; None seeds the generator from fresh entropy.
    :return: random float32 array.
    """
    # A local generator keeps the seed from leaking into numpy's global RNG,
    # which other tests in the same process may also rely on.
    rng = np.random.RandomState(seed)
    return rng.random_sample(size=shape).astype(np.float32)
||
228 | |||
229 | |||
230 | class TestGeneratorDataLoader: |
||
@pytest.mark.parametrize("labeled", [True, False])
def test_get_labeled_dataset(self, labeled: bool):
    """
    Test get_dataset with data loader.

    The loader's ``data_generator`` is replaced by a toy generator that yields
    the same sample three times. Every element produced by
    ``dataset.as_numpy_iterator()`` must match that sample, and at least one
    element must be produced.

    :param labeled: labeled data or not.
    """
    sample = {
        "moving_image": get_arr(),
        "fixed_image": get_arr(),
        "indices": [1],
    }
    if labeled:
        sample = {
            "moving_label": get_arr(),
            "fixed_label": get_arr(),
            **sample,
        }

    def mock_gen():
        """Toy data generator."""
        for _ in range(3):
            yield sample

    loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
    loader.data_generator = mock_gen
    dataset = loader.get_dataset()

    num_checked = 0
    for got in dataset.as_numpy_iterator():
        for key, value in sample.items():
            # report the offending key on mismatch instead of a bare False
            assert is_equal_np(got[key], value), key
        num_checked += 1
    # guard against a vacuous pass when the dataset yields no element
    assert num_checked > 0
||
260 | |||
@pytest.mark.parametrize("labeled", [True, False])
def test_data_generator(self, labeled: bool):
    """
    Test data_generator().

    The sample index generator and the per-modality data loaders are replaced
    by toys. The first generated sample is then compared with arrays built
    from the same seeds.

    :param labeled: labeled data or not.
    """

    class MockDataLoader:
        """Toy data loader returning the same seeded array for every index."""

        def __init__(self, seed: int):
            """
            Init.

            :param seed: random seed passed to get_arr.
            """
            self.seed = seed

        def get_data(self, index: int) -> np.ndarray:
            """
            Return the dummy array regardless of the index.

            :param index: only checked to be an int, otherwise unused.
            :return: dummy array.
            """
            assert isinstance(index, int)
            return get_arr(seed=self.seed)

    def mock_sample_index_generator():
        """Toy sample index generator."""
        return [[1, 1, [1]]]

    loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
    loader.sample_index_generator = mock_sample_index_generator
    loader.loader_moving_image = MockDataLoader(seed=0)
    loader.loader_fixed_image = MockDataLoader(seed=1)
    if labeled:
        loader.loader_moving_label = MockDataLoader(seed=2)
        loader.loader_fixed_label = MockDataLoader(seed=3)

    # check data loader output
    got = next(loader.data_generator())

    expected = {
        "moving_image": normalize_array(get_arr(seed=0)),
        "fixed_image": normalize_array(get_arr(seed=1)),
        # 0 or -1 is the label index
        "indices": np.array([1, 0] if labeled else [1, -1], dtype=np.float32),
    }
    if labeled:
        expected = {
            "moving_label": get_arr(seed=2),
            "fixed_label": get_arr(seed=3),
            **expected,
        }
    for key, value in expected.items():
        # report the offending key on mismatch instead of a bare False
        assert is_equal_np(got[key], value), key
||
319 | |||
def test_sample_index_generator(self):
    """
    Test that sample_index_generator is not implemented on GeneratorDataLoader.

    Calling it on a plain GeneratorDataLoader instance must raise
    NotImplementedError.
    """
    loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
    with pytest.raises(NotImplementedError):
        loader.sample_index_generator()
||
324 | |||
@pytest.mark.parametrize(
    (
        "moving_image_shape",
        "fixed_image_shape",
        "moving_label_shape",
        "fixed_label_shape",
        "err_msg",
    ),
    [
        (
            None,
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            "moving image and fixed image must not be None",
        ),
        (
            (10, 10, 10),
            None,
            (10, 10, 10),
            (10, 10, 10),
            "moving image and fixed image must not be None",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            None,
            (10, 10, 10),
            "moving label and fixed label must be both None or non-None",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            None,
            "moving label and fixed label must be both None or non-None",
        ),
        (
            (10, 10),
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            "Sample [1]'s moving_image's shape should be 3D",
        ),
        (
            (10, 10, 10),
            (10, 10),
            (10, 10, 10),
            (10, 10, 10),
            "Sample [1]'s fixed_image's shape should be 3D",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10),
            (10, 10, 10),
            "Sample [1]'s moving_label's shape should be 3D or 4D.",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            (10, 10),
            "Sample [1]'s fixed_label's shape should be 3D or 4D.",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10, 2),
            (10, 10, 10, 3),
            "Sample [1]'s moving image and fixed image "
            "have different numbers of labels.",
        ),
    ],
)
def test_validate_images_and_labels(
    self,
    moving_image_shape: Optional[Tuple],
    fixed_image_shape: Optional[Tuple],
    moving_label_shape: Optional[Tuple],
    fixed_label_shape: Optional[Tuple],
    err_msg: str,
):
    """
    Test error messages.

    validate_images_and_labels must raise ValueError whose message contains
    ``err_msg`` for each invalid combination of inputs.

    :param moving_image_shape: None (input omitted) or shape tuple.
    :param fixed_image_shape: None (input omitted) or shape tuple.
    :param moving_label_shape: None (input omitted) or shape tuple.
    :param fixed_label_shape: None (input omitted) or shape tuple.
    :param err_msg: expected substring of the error message.
    """

    def to_array(shape: Optional[Tuple]) -> Optional[np.ndarray]:
        """Build a random array of ``shape``, or pass None through."""
        # `is None` rather than truthiness, so an empty shape is not mistaken
        # for a missing input
        return None if shape is None else get_arr(shape=shape)

    # build inputs outside pytest.raises so only the call under test is guarded
    moving_image = to_array(moving_image_shape)
    fixed_image = to_array(fixed_image_shape)
    moving_label = to_array(moving_label_shape)
    fixed_label = to_array(fixed_label_shape)

    loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
    with pytest.raises(ValueError) as err_info:
        loader.validate_images_and_labels(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            image_indices=[1],
        )
    assert err_msg in str(err_info.value)
||
439 | |||
440 | @pytest.mark.parametrize("option", [0, 1, 2, 3]) |
||
441 | def test_validate_images_and_labels_range(self, option: int): |
||
442 | """ |
||
443 | Test error messages related to input range. |
||
444 | |||
445 | :param option: control which image to modify |
||
446 | """ |
||
447 | option_to_name = { |
||
448 | 0: "moving_image", |
||
449 | 1: "fixed_image", |
||
450 | 2: "moving_label", |
||
451 | 3: "fixed_label", |
||
452 | } |
||
453 | input = { |
||
454 | "moving_image": get_arr(), |
||
455 | "fixed_image": get_arr(), |
||
456 | "moving_label": get_arr(), |
||
457 | "fixed_label": get_arr(), |
||
458 | } |
||
459 | name = option_to_name[option] |
||
460 | input[name] += 1 |
||
594 |