@@ 390-452 (lines=63) @@ | ||
387 | out_channels=out_channels, |
|
388 | ) |
|
389 | ||
def build_encode_layers(
    self,
    image_size: Tuple,
    num_channels: Tuple,
    depth: int,
    encode_kernel_sizes: Union[int, List[int]],
    strides: int,
    padding: str,
) -> List[Tuple]:
    """
    Build layers for encoding.

    For each level d in [0, depth) one encoding conv block and one
    down-sampling block are built, followed by a bottom block at level
    d = depth. The blocks are stored on ``self._encode_convs``,
    ``self._encode_pools`` and ``self._bottom_block`` respectively.

    :param image_size: (dim1, dim2, dim3).
    :param num_channels: number of channels for each layer,
        starting from the top layer; needs at least depth + 1 entries,
        since num_channels[depth] is used for the bottom block.
    :param depth: network starts with d = 0, and the bottom has d = depth.
    :param encode_kernel_sizes: kernel size(s) of the encoding conv blocks
        (levels 0 .. depth - 1) and of the bottom block (level depth);
        an int is used for every level, a sequence must have depth + 1 entries.
    :param strides: kernel size and stride of the down-sampling blocks,
        also used to compute the down-sampled tensor shapes.
    :param padding: padding mode for all conv layers
    :return: list of depth + 1 tensor shapes starting from d = 0; the first
        entry is image_size as given, later entries are tuples holding the
        shape after each down-sampling step.
    :raises AssertionError: if encode_kernel_sizes is a sequence whose
        length is not depth + 1.
    """
    if isinstance(encode_kernel_sizes, int):
        encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
    if len(encode_kernel_sizes) != depth + 1:
        # Explicit raise rather than a bare `assert`, so the check is not
        # stripped under `python -O`; AssertionError is kept for backward
        # compatibility with callers expecting the original exception type.
        raise AssertionError(
            f"encode_kernel_sizes must have depth + 1 = {depth + 1} entries, "
            f"got {len(encode_kernel_sizes)}: {encode_kernel_sizes}"
        )

    # encoding / down-sampling
    self._encode_convs = []
    self._encode_pools = []
    tensor_shape = image_size
    tensor_shapes = [tensor_shape]
    for d in range(depth):
        encode_conv = self.build_encode_conv_block(
            filters=num_channels[d],
            kernel_size=encode_kernel_sizes[d],
            padding=padding,
        )
        encode_pool = self.build_down_sampling_block(
            filters=num_channels[d],
            kernel_size=strides,
            strides=strides,
            padding=padding,
        )
        # shape after the down-sampling block, whose kernel size and stride
        # are both `strides` (see build_down_sampling_block call above)
        tensor_shape = tuple(
            conv_utils.conv_output_length(
                input_length=x,
                filter_size=strides,
                padding=padding,
                stride=strides,
                dilation=1,
            )
            for x in tensor_shape
        )
        self._encode_convs.append(encode_conv)
        self._encode_pools.append(encode_pool)
        tensor_shapes.append(tensor_shape)

    # bottom layer
    self._bottom_block = self.build_bottom_block(
        filters=num_channels[depth],
        kernel_size=encode_kernel_sizes[depth],
        padding=padding,
    )
    return tensor_shapes
|
453 | ||
454 | def build_decode_layers( |
|
455 | self, |
@@ 384-446 (lines=63) @@ | ||
381 | out_activation=out_activation, |
|
382 | ) |
|
383 | ||
def build_encode_layers(
    self,
    image_size: Tuple,
    num_channels: Tuple,
    depth: int,
    encode_kernel_sizes: Union[int, List[int]],
    strides: int,
    padding: str,
) -> List[Tuple]:
    """
    Build layers for encoding.

    For each level d in [0, depth) one encoding conv block and one
    down-sampling block are built, followed by a bottom block at level
    d = depth. The blocks are stored on ``self._encode_convs``,
    ``self._encode_pools`` and ``self._bottom_block`` respectively.

    :param image_size: (dim1, dim2, dim3).
    :param num_channels: number of channels for each layer,
        starting from the top layer; needs at least depth + 1 entries,
        since num_channels[depth] is used for the bottom block.
    :param depth: network starts with d = 0, and the bottom has d = depth.
    :param encode_kernel_sizes: kernel size(s) of the encoding conv blocks
        (levels 0 .. depth - 1) and of the bottom block (level depth);
        an int is used for every level, a sequence must have depth + 1 entries.
    :param strides: kernel size and stride of the down-sampling blocks,
        also used to compute the down-sampled tensor shapes.
    :param padding: padding mode for all conv layers
    :return: list of depth + 1 tensor shapes starting from d = 0; the first
        entry is image_size as given, later entries are tuples holding the
        shape after each down-sampling step.
    :raises AssertionError: if encode_kernel_sizes is a sequence whose
        length is not depth + 1.
    """
    if isinstance(encode_kernel_sizes, int):
        encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
    if len(encode_kernel_sizes) != depth + 1:
        # Explicit raise rather than a bare `assert`, so the check is not
        # stripped under `python -O`; AssertionError is kept for backward
        # compatibility with callers expecting the original exception type.
        raise AssertionError(
            f"encode_kernel_sizes must have depth + 1 = {depth + 1} entries, "
            f"got {len(encode_kernel_sizes)}: {encode_kernel_sizes}"
        )

    # encoding / down-sampling
    self._encode_convs = []
    self._encode_pools = []
    tensor_shape = image_size
    tensor_shapes = [tensor_shape]
    for d in range(depth):
        encode_conv = self.build_encode_conv_block(
            filters=num_channels[d],
            kernel_size=encode_kernel_sizes[d],
            padding=padding,
        )
        encode_pool = self.build_down_sampling_block(
            filters=num_channels[d],
            kernel_size=strides,
            strides=strides,
            padding=padding,
        )
        # shape after the down-sampling block, whose kernel size and stride
        # are both `strides` (see build_down_sampling_block call above)
        tensor_shape = tuple(
            conv_utils.conv_output_length(
                input_length=x,
                filter_size=strides,
                padding=padding,
                stride=strides,
                dilation=1,
            )
            for x in tensor_shape
        )
        self._encode_convs.append(encode_conv)
        self._encode_pools.append(encode_pool)
        tensor_shapes.append(tensor_shape)

    # bottom layer
    self._bottom_block = self.build_bottom_block(
        filters=num_channels[depth],
        kernel_size=encode_kernel_sizes[depth],
        padding=padding,
    )
    return tensor_shapes
|
447 | ||
448 | def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: |
|
449 | """ |