| Conditions | 12 |
| Total Lines | 243 |
| Code Lines | 172 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 | ||
Small methods make your code easier to understand, particularly when combined with a good name. Moreover, a small method is usually much easier to name well.
For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method; the comment then makes a good starting point for naming that new method.
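Commonly applied refactorings include Extract Method.
If many parameters/temporary variables are present, consider Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.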
Complex classes like test.unit.test_interface.test_generator_data_loader() often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster to apply.
| 1 | # coding=utf-8 |
||
| 217 | |||
def get_arr(shape: Tuple = (2, 3, 4), seed: Optional[int] = None) -> np.ndarray:
    """
    Return a random float32 array with values drawn uniformly from [0, 1).

    A private ``np.random.RandomState`` is used instead of reseeding the
    global NumPy RNG, so calling this helper does not alter the random state
    observed by other code or tests. For an integer seed the returned values
    are identical to ``np.random.seed(seed); np.random.random(size=shape)``.

    :param shape: shape of array.
    :param seed: random seed; None seeds the generator from system entropy.
    :return: random array of dtype float32.
    """
    rng = np.random.RandomState(seed)
    return rng.random_sample(size=shape).astype(np.float32)
||
| 228 | |||
| 229 | |||
| 230 | class TestGeneratorDataLoader: |
||
@pytest.mark.parametrize("labeled", [True, False])
def test_get_labeled_dataset(self, labeled: bool):
    """
    Test get_dataset with data loader.

    The loader's data_generator is replaced by a toy generator that yields
    the same sample a fixed number of times. Every element of the resulting
    dataset must match that sample, and exactly that many elements must be
    produced.

    :param labeled: labeled data or not.
    """
    num_samples = 3
    sample = {
        "moving_image": get_arr(),
        "fixed_image": get_arr(),
        "indices": [1],
    }
    if labeled:
        sample = {
            "moving_label": get_arr(),
            "fixed_label": get_arr(),
            **sample,
        }

    def mock_gen():
        """Toy data generator."""
        for _ in range(num_samples):
            yield sample

    loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
    # NOTE(review): __setattr__ is presumably used instead of plain assignment
    # to avoid a static type checker complaining about assigning to a method.
    loader.__setattr__("data_generator", mock_gen)
    dataset = loader.get_dataset()

    num_got = 0
    for got in dataset.as_numpy_iterator():
        num_got += 1
        for key, value in sample.items():
            assert is_equal_np(got[key], value), f"mismatch for key {key}"
    # without this check an empty dataset would make the test pass vacuously
    assert num_got == num_samples
||
| 260 | |||
@pytest.mark.parametrize("labeled", [True, False])
def test_data_generator(self, labeled: bool):
    """
    Test data_generator().

    Every internal loader is replaced by a toy loader that returns a seeded
    random array, so the expected first sample can be rebuilt from the same
    seeds. Images are expected to be normalized; labels are compared raw.

    :param labeled: labeled data or not.
    """

    class MockDataLoader:
        """Toy data loader returning the same seeded array for any index."""

        def __init__(self, seed: int):
            """
            Init.

            :param seed: random seed for numpy.
            """
            self.seed = seed

        def get_data(self, index: int) -> np.ndarray:
            """
            Return the dummy array regardless of the index.

            :param index: only checked to be an int, otherwise unused.
            :return: dummy array.
            """
            assert isinstance(index, int)
            return get_arr(seed=self.seed)

    def mock_sample_index_generator():
        """Toy sample index generator."""
        # presumably (moving index, fixed index, image indices) -- TODO confirm
        return [[1, 1, [1]]]

    loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
    loader.__setattr__("sample_index_generator", mock_sample_index_generator)
    loader.loader_moving_image = MockDataLoader(seed=0)
    loader.loader_fixed_image = MockDataLoader(seed=1)
    if labeled:
        loader.loader_moving_label = MockDataLoader(seed=2)
        loader.loader_fixed_label = MockDataLoader(seed=3)

    # check data loader output (only the first generated sample)
    got = next(loader.data_generator())

    expected = {
        "moving_image": normalize_array(get_arr(seed=0)),
        "fixed_image": normalize_array(get_arr(seed=1)),
        # 0 or -1 is the label index
        "indices": np.array([1, 0] if labeled else [1, -1], dtype=np.float32),
    }
    if labeled:
        expected = {
            "moving_label": get_arr(seed=2),
            "fixed_label": get_arr(seed=3),
            **expected,
        }
    for key, value in expected.items():
        assert is_equal_np(got[key], value), f"mismatch for key {key}"
||
| 319 | |||
def test_sample_index_generator(self):
    """
    Test that the base GeneratorDataLoader does not implement
    sample_index_generator: calling it must raise NotImplementedError.
    """
    data_loader = GeneratorDataLoader(
        labeled=True,
        num_indices=1,
        sample_label="all",
    )
    with pytest.raises(NotImplementedError):
        data_loader.sample_index_generator()
||
| 324 | |||
@pytest.mark.parametrize(
    (
        "moving_image_shape",
        "fixed_image_shape",
        "moving_label_shape",
        "fixed_label_shape",
        "err_msg",
    ),
    [
        (
            None,
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            "moving image and fixed image must not be None",
        ),
        (
            (10, 10, 10),
            None,
            (10, 10, 10),
            (10, 10, 10),
            "moving image and fixed image must not be None",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            None,
            (10, 10, 10),
            "moving label and fixed label must be both None or non-None",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            None,
            "moving label and fixed label must be both None or non-None",
        ),
        (
            (10, 10),
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            "Sample [1]'s moving_image's shape should be 3D",
        ),
        (
            (10, 10, 10),
            (10, 10),
            (10, 10, 10),
            (10, 10, 10),
            "Sample [1]'s fixed_image's shape should be 3D",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10),
            (10, 10, 10),
            "Sample [1]'s moving_label's shape should be 3D or 4D.",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10),
            (10, 10),
            "Sample [1]'s fixed_label's shape should be 3D or 4D.",
        ),
        (
            (10, 10, 10),
            (10, 10, 10),
            (10, 10, 10, 2),
            (10, 10, 10, 3),
            "Sample [1]'s moving image and fixed image "
            "have different numbers of labels.",
        ),
    ],
)
def test_validate_images_and_labels(
    self,
    moving_image_shape: Optional[Tuple],
    fixed_image_shape: Optional[Tuple],
    moving_label_shape: Optional[Tuple],
    fixed_label_shape: Optional[Tuple],
    err_msg: str,
):
    """
    Test error messages of validate_images_and_labels.

    Each case builds random arrays from the given shapes (None means the
    input is passed as None) and checks that a ValueError whose message
    contains err_msg is raised.

    :param moving_image_shape: None or tuple.
    :param fixed_image_shape: None or tuple.
    :param moving_label_shape: None or tuple.
    :param fixed_label_shape: None or tuple.
    :param err_msg: expected substring of the error message.
    """

    def make_arr(shape: Optional[Tuple]) -> Optional[np.ndarray]:
        """Return a random array of the given shape, or None if shape is None."""
        # compare against None explicitly: an empty shape () is a valid (0-D)
        # array shape and must not be mistaken for "no input"
        return None if shape is None else get_arr(shape=shape)

    # build inputs outside pytest.raises so only the validation call is guarded
    moving_image = make_arr(moving_image_shape)
    fixed_image = make_arr(fixed_image_shape)
    moving_label = make_arr(moving_label_shape)
    fixed_label = make_arr(fixed_label_shape)

    loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
    with pytest.raises(ValueError) as err_info:
        loader.validate_images_and_labels(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            image_indices=[1],
        )
    assert err_msg in str(err_info.value)
||
| 439 | |||
| 440 | @pytest.mark.parametrize("option", [0, 1, 2, 3]) |
||
| 441 | def test_validate_images_and_labels_range(self, option: int): |
||
| 442 | """ |
||
| 443 | Test error messages related to input range. |
||
| 444 | |||
| 445 | :param option: control which image to modify |
||
| 446 | """ |
||
| 447 | option_to_name = { |
||
| 448 | 0: "moving_image", |
||
| 449 | 1: "fixed_image", |
||
| 450 | 2: "moving_label", |
||
| 451 | 3: "fixed_label", |
||
| 452 | } |
||
| 453 | input = { |
||
| 454 | "moving_image": get_arr(), |
||
| 455 | "fixed_image": get_arr(), |
||
| 456 | "moving_label": get_arr(), |
||
| 457 | "fixed_label": get_arr(), |
||
| 458 | } |
||
| 459 | name = option_to_name[option] |
||
| 460 | input[name] += 1 |
||
| 594 |