Conditions | 12 |
Total Lines | 243 |
Code Lines | 172 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
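Commonly applied refactorings include: Extract Method.
If many parameters/temporary variables are present: Replace Method with Method Object, Replace Temp with Query, Introduce Parameter Object, or Preserve Whole Object.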
Complex classes like test.unit.test_interface.test_generator_data_loader() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # coding=utf-8 |
||
217 | |||
def get_arr(shape: Tuple = (2, 3, 4), seed: Optional[int] = None) -> np.ndarray:
    """
    Return a random float32 array with values in [0, 1).

    A local ``np.random.Generator`` is used instead of ``np.random.seed`` so
    that calling this helper never resets the global NumPy random state that
    other tests may rely on. The same seed still yields the same array.

    :param shape: shape of array.
    :param seed: random seed, None to draw fresh entropy from the OS.
    :return: random array of dtype float32.
    """
    rng = np.random.default_rng(seed)
    # Sample directly in float32: casting float64 samples that are close to 1
    # down to float32 can round them up to exactly 1.0.
    return rng.random(size=shape, dtype=np.float32)
||
228 | |||
229 | |||
230 | class TestGeneratorDataLoader: |
||
@pytest.mark.parametrize("labeled", [True, False])
def test_get_labeled_dataset(self, labeled: bool):
    """
    Test get_dataset with data loader.

    The loader's data_generator is replaced by a toy generator yielding the
    same sample three times; every element produced by the dataset must match
    that sample, and at least one element must be produced.

    :param labeled: labeled data or not.
    """
    sample = {
        "moving_image": get_arr(),
        "fixed_image": get_arr(),
        "indices": [1],
    }
    if labeled:
        sample = {
            "moving_label": get_arr(),
            "fixed_label": get_arr(),
            **sample,
        }

    def mock_gen():
        """Toy data generator."""
        for _ in range(3):
            yield sample

    loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
    # NOTE(review): __setattr__ instead of plain assignment, presumably to keep
    # type checkers from rejecting assignment to a method — confirm.
    loader.__setattr__("data_generator", mock_gen)
    dataset = loader.get_dataset()
    num_elements = 0
    for got in dataset.as_numpy_iterator():
        num_elements += 1
        assert all(is_equal_np(got[key], sample[key]) for key in sample)
    # Without this guard an empty dataset would make the test pass vacuously.
    assert num_elements > 0
||
256 | |||
@pytest.mark.parametrize("labeled", [True, False])
def test_data_generator(self, labeled: bool):
    """
    Check the first sample yielded by data_generator().

    The sample index generator and the per-modality loaders are swapped for
    deterministic stand-ins, so the expected sample can be rebuilt from the
    same seeds.

    :param labeled: labeled data or not.
    """

    class MockDataLoader:
        """Stand-in loader that returns a fixed-seed array for any index."""

        def __init__(self, seed: int):
            """
            Store the seed used to generate the dummy array.

            :param seed: random seed passed to get_arr.
            """
            self.seed = seed

        def get_data(self, index: int) -> np.ndarray:
            """
            Return the same dummy array regardless of the index.

            :param index: only checked to be an int.
            :return: dummy array.
            """
            assert isinstance(index, int)
            return get_arr(seed=self.seed)

    def fixed_indices():
        """Stand-in sample index generator with a single entry."""
        return [[1, 1, [1]]]

    loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
    loader.__setattr__("sample_index_generator", fixed_indices)
    loader.loader_moving_image = MockDataLoader(seed=0)
    loader.loader_fixed_image = MockDataLoader(seed=1)
    if labeled:
        loader.loader_moving_label = MockDataLoader(seed=2)
        loader.loader_fixed_label = MockDataLoader(seed=3)

    # check data loader output
    got = next(loader.data_generator())

    # 0 or -1 is the label index
    label_index = 0 if labeled else -1
    expected = {}
    if labeled:
        expected["moving_label"] = get_arr(seed=2)
        expected["fixed_label"] = get_arr(seed=3)
    expected["moving_image"] = normalize_array(get_arr(seed=0))
    expected["fixed_image"] = normalize_array(get_arr(seed=1))
    expected["indices"] = np.array([1, label_index], dtype=np.float32)

    for key, value in expected.items():
        assert is_equal_np(got[key], value)
||
315 | |||
def test_sample_index_generator(self):
    """Check that the base sample_index_generator raises NotImplementedError."""
    data_loader = GeneratorDataLoader(
        labeled=True,
        num_indices=1,
        sample_label="all",
    )
    with pytest.raises(NotImplementedError):
        data_loader.sample_index_generator()
||
320 | |||
@pytest.mark.parametrize(
    (
        "moving_image_shape",
        "fixed_image_shape",
        "moving_label_shape",
        "fixed_label_shape",
        "err_msg",
    ),
    [
        (None, (10, 10, 10), (10, 10, 10), (10, 10, 10),
         "moving image and fixed image must not be None"),
        ((10, 10, 10), None, (10, 10, 10), (10, 10, 10),
         "moving image and fixed image must not be None"),
        ((10, 10, 10), (10, 10, 10), None, (10, 10, 10),
         "moving label and fixed label must be both None or non-None"),
        ((10, 10, 10), (10, 10, 10), (10, 10, 10), None,
         "moving label and fixed label must be both None or non-None"),
        ((10, 10), (10, 10, 10), (10, 10, 10), (10, 10, 10),
         "Sample [1]'s moving_image's shape should be 3D"),
        ((10, 10, 10), (10, 10), (10, 10, 10), (10, 10, 10),
         "Sample [1]'s fixed_image's shape should be 3D"),
        ((10, 10, 10), (10, 10, 10), (10, 10), (10, 10, 10),
         "Sample [1]'s moving_label's shape should be 3D or 4D."),
        ((10, 10, 10), (10, 10, 10), (10, 10, 10), (10, 10),
         "Sample [1]'s fixed_label's shape should be 3D or 4D."),
        ((10, 10, 10), (10, 10, 10), (10, 10, 10, 2), (10, 10, 10, 3),
         "Sample [1]'s moving image and fixed image have different numbers of labels."),
    ],
)
def test_validate_images_and_labels(
    self,
    moving_image_shape: Optional[Tuple],
    fixed_image_shape: Optional[Tuple],
    moving_label_shape: Optional[Tuple],
    fixed_label_shape: Optional[Tuple],
    err_msg: str,
):
    """
    Check that invalid inputs make validate_images_and_labels raise ValueError.

    Each shape is turned into a random array, or left as None when the shape
    is falsy (None); the raised message must contain ``err_msg``.

    :param moving_image_shape: shape of the moving image, or None.
    :param fixed_image_shape: shape of the fixed image, or None.
    :param moving_label_shape: shape of the moving label, or None.
    :param fixed_label_shape: shape of the fixed label, or None.
    :param err_msg: expected substring of the error message.
    """
    shapes = (moving_image_shape, fixed_image_shape, moving_label_shape, fixed_label_shape)
    moving_image, fixed_image, moving_label, fixed_label = [
        get_arr(shape=shape) if shape else None for shape in shapes
    ]
    loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
    with pytest.raises(ValueError) as err_info:
        loader.validate_images_and_labels(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            image_indices=[1],
        )
    assert err_msg in str(err_info.value)
||
434 | |||
435 | @pytest.mark.parametrize("option", [0, 1, 2, 3]) |
||
436 | def test_validate_images_and_labels_range(self, option: int): |
||
437 | """ |
||
438 | Test error messages related to input range. |
||
439 | |||
440 | :param option: control which image to modify |
||
441 | """ |
||
442 | option_to_name = { |
||
443 | 0: "moving_image", |
||
444 | 1: "fixed_image", |
||
445 | 2: "moving_label", |
||
446 | 3: "fixed_label", |
||
447 | } |
||
448 | input = { |
||
449 | "moving_image": get_arr(), |
||
450 | "fixed_image": get_arr(), |
||
451 | "moving_label": get_arr(), |
||
452 | "fixed_label": get_arr(), |
||
453 | } |
||
454 | name = option_to_name[option] |
||
455 | input[name] += 1 |
||
456 | err_msg = f"Sample [1]'s {name}'s values are not between [0, 1]" |
||
457 | |||
458 | loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all") |
||
459 | with pytest.raises(ValueError) as err_info: |
||
460 | loader.validate_images_and_labels( |
||
589 |