| Conditions | 12 |
| Total Lines | 243 |
| Code Lines | 172 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
- Extract Method
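If many parameters/temporary variables are present:
- Replace Temp with Query
- Introduce Parameter Object
- Preserve Whole Object
- Replace Method with Method Object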
Complex classes and functions like `test.unit.test_interface.test_generator_data_loader()` often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods (or, inside a function, variables and commented sections) that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | # coding=utf-8 |
||
def test_generator_data_loader(caplog):
    """
    Test the functions in GeneratorDataLoader.

    Each feature under test is checked by its own private helper:

    - default loader properties and the unimplemented index generator,
    - get_dataset for labeled and unlabeled data,
    - data_generator with mocked loaders and a mocked index generator,
    - validate_images_and_labels errors and warnings,
    - sample_image_label for zero, one and multiple labels.

    :param caplog: used to check warning message.
    """
    dummy_array = np.random.random(size=(100, 100, 100)).astype(np.float32)

    _check_properties_and_not_implemented()
    _check_get_dataset(dummy_array)
    # the generator with mocked loaders is reused by the remaining checks
    generator = _check_data_generator(dummy_array)
    _check_validate_images_and_labels(generator, dummy_array, caplog)
    _check_sample_image_label(generator, dummy_array)


def _check_properties_and_not_implemented():
    """Check a fresh loader has no sub-loaders and sample_index_generator is abstract."""
    generator = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")

    # test properties: every sub-loader starts unset
    assert generator.loader_moving_image is None
    assert generator.loader_fixed_image is None
    assert generator.loader_moving_label is None
    assert generator.loader_fixed_label is None

    # not implemented properties / functions
    with pytest.raises(NotImplementedError):
        generator.sample_index_generator()


def _mock_generator_fn(sequence):
    """
    Build a zero-argument generator function yielding the items of sequence.

    Binding sequence explicitly avoids relying on a closure over a variable
    that is later rebound.

    :param sequence: items to yield, in order.
    :return: a callable that returns a fresh generator on each call.
    """

    def mock_generator():
        yield from sequence

    return mock_generator


def _check_get_dataset(dummy_array):
    """
    Check get_dataset output for labeled and unlabeled mocked data generators.

    :param dummy_array: 3D float32 array used for every image and label.
    """
    num_samples = 3
    labeled_sample = dict(
        moving_image=dummy_array,
        fixed_image=dummy_array,
        moving_label=dummy_array,
        fixed_label=dummy_array,
        indices=[1],
    )
    unlabeled_sample = dict(
        moving_image=dummy_array, fixed_image=dummy_array, indices=[1]
    )

    for labeled, sample in ((True, labeled_sample), (False, unlabeled_sample)):
        generator = GeneratorDataLoader(
            labeled=labeled, num_indices=1, sample_label="all"
        )
        # inputs, no error means passed
        generator.data_generator = _mock_generator_fn(
            [dict(sample) for _ in range(num_samples)]
        )
        dataset = generator.get_dataset()

        # check dataset output
        num_got = 0
        for got in dataset.as_numpy_iterator():
            assert all(is_equal_np(got[key], sample[key]) for key in sample)
            num_got += 1
        # guard against a vacuous pass when the dataset yields nothing
        assert num_got == num_samples


def _check_data_generator(dummy_array):
    """
    Check data_generator with mocked sub-loaders and a mocked index generator.

    :param dummy_array: 3D float32 array returned by every mocked loader.
    :return: the labeled GeneratorDataLoader with the mocks attached.
    """

    class MockDataLoader:
        """Loader stand-in whose get_data always returns dummy_array."""

        @staticmethod
        def get_data(index):
            return dummy_array

    def mock_sample_index_generator():
        return [[[1], [1], [1]]]

    generator = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
    generator.sample_index_generator = mock_sample_index_generator
    # the class itself (not an instance) is attached, so get_data must be
    # callable without self
    generator.loader_moving_image = MockDataLoader
    generator.loader_fixed_image = MockDataLoader
    generator.loader_moving_label = MockDataLoader
    generator.loader_fixed_label = MockDataLoader

    # check data generator output: images normalized, labels untouched
    got = next(generator.data_generator())
    expected = dict(
        moving_image=normalize_array(dummy_array),
        fixed_image=normalize_array(dummy_array),
        moving_label=dummy_array,
        fixed_label=dummy_array,
        indices=np.asarray([1] + [0], dtype=np.float32),
    )
    assert all(is_equal_np(got[key], expected[key]) for key in expected)
    return generator


def _check_validate_images_and_labels(generator, dummy_array, caplog):
    """
    Check validate_images_and_labels raises on invalid inputs and warns on shape mismatch.

    :param generator: GeneratorDataLoader instance under test.
    :param dummy_array: 3D float32 array with values in [0, 1).
    :param caplog: pytest fixture used to inspect logged warnings.
    """
    # (image/label kwargs, expected fragment of the ValueError message)
    error_cases = [
        (
            dict(
                fixed_image=None,
                moving_image=dummy_array,
                moving_label=None,
                fixed_label=None,
            ),
            "moving image and fixed image must not be None",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array,
                moving_label=dummy_array,
                fixed_label=None,
            ),
            "moving label and fixed label must be both None or non-None",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array + 1.0,
                moving_label=None,
                fixed_label=None,
            ),
            "Sample [1]'s moving_image's values are not between [0, 1]",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=np.random.random(size=(100, 100)),
                moving_label=None,
                fixed_label=None,
            ),
            "Sample [1]'s moving_image' shape should be 3D. ",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array,
                moving_label=np.random.random(size=(100, 100)),
                fixed_label=dummy_array,
            ),
            "Sample [1]'s moving_label' shape should be 3D or 4D. ",
        ),
        (
            dict(
                fixed_image=dummy_array,
                moving_image=dummy_array,
                moving_label=np.random.random(size=(100, 100, 100, 3)),
                fixed_label=np.random.random(size=(100, 100, 100, 4)),
            ),
            "Sample [1]'s moving image and fixed image have different numbers of labels.",
        ),
    ]
    for kwargs, message in error_cases:
        with pytest.raises(ValueError) as err_info:
            generator.validate_images_and_labels(**kwargs, image_indices=[1])
        assert message in str(err_info.value)

    # (label kwargs, expected fragment of the logged warning)
    warning_cases = [
        (
            dict(
                moving_label=np.random.random(size=(100, 100, 90)),
                fixed_label=dummy_array,
            ),
            "Sample [1]'s moving image and label have different shapes. ",
        ),
        (
            dict(
                moving_label=dummy_array,
                fixed_label=np.random.random(size=(100, 100, 90)),
            ),
            "Sample [1]'s fixed image and label have different shapes. ",
        ),
    ]
    for label_kwargs, message in warning_cases:
        caplog.clear()  # only inspect logs emitted by this call
        generator.validate_images_and_labels(
            fixed_image=dummy_array,
            moving_image=dummy_array,
            image_indices=[1],
            **label_kwargs,
        )
        assert message in caplog.text


def _check_sample_image_label(generator, dummy_array):
    """
    Check sample_image_label for unlabeled, single-label and multi-label inputs.

    :param generator: GeneratorDataLoader instance under test.
    :param dummy_array: 3D float32 array used as image and single label.
    """
    # for unlabeled input data, the label index in indices is -1
    got = next(
        generator.sample_image_label(
            fixed_image=dummy_array,
            moving_image=dummy_array,
            moving_label=None,
            fixed_label=None,
            image_indices=[1],
        )
    )
    expected = dict(
        moving_image=dummy_array,
        fixed_image=dummy_array,
        indices=np.asarray([1] + [-1], dtype=np.float32),
    )
    assert all(is_equal_np(got[key], expected[key]) for key in expected)

    # for data with one label, the label index is 0
    got = next(
        generator.sample_image_label(
            fixed_image=dummy_array,
            moving_image=dummy_array,
            moving_label=dummy_array,
            fixed_label=dummy_array,
            image_indices=[1],
        )
    )
    expected = dict(
        moving_image=dummy_array,
        fixed_image=dummy_array,
        moving_label=dummy_array,
        fixed_label=dummy_array,
        indices=np.asarray([1] + [0], dtype=np.float32),
    )
    assert all(is_equal_np(got[key], expected[key]) for key in expected)

    # for data with multiple labels, one sample per label channel (last axis)
    dummy_labels = np.random.random(size=(100, 100, 100, 3))
    samples = generator.sample_image_label(
        fixed_image=dummy_array,
        moving_image=dummy_array,
        moving_label=dummy_labels,
        fixed_label=dummy_labels,
        image_indices=[1],
    )
    for label_index in range(dummy_labels.shape[3]):
        got = next(samples)
        expected = dict(
            moving_image=dummy_array,
            fixed_image=dummy_array,
            moving_label=dummy_labels[..., label_index],
            fixed_label=dummy_labels[..., label_index],
            indices=np.asarray([1] + [label_index], dtype=np.float32),
        )
        assert all(is_equal_np(got[key], expected[key]) for key in expected)
||
| 460 | |||
| 515 |