| @@ 390-452 (lines=63) @@ | ||
| 387 | out_channels=out_channels, |
|
| 388 | ) |
|
| 389 | ||
def build_encode_layers(
    self,
    image_size: Tuple,
    num_channels: Tuple,
    depth: int,
    encode_kernel_sizes: Union[int, List[int]],
    strides: int,
    padding: str,
) -> List[Tuple]:
    """
    Build layers for encoding.

    For every level d in [0, depth) an encoding conv block and a
    down-sampling block are built and stored in ``self._encode_convs`` and
    ``self._encode_pools``; the bottom block (d = depth) is stored in
    ``self._bottom_block``.

    :param image_size: (dim1, dim2, dim3).
    :param num_channels: number of channels for each layer,
        starting from the top layer. Indexed for d = 0 .. depth, so it must
        provide at least depth + 1 entries.
    :param depth: network starts with d = 0, and the bottom has d = depth.
    :param encode_kernel_sizes: kernel size(s) of the encoding conv blocks
        and of the bottom block. An int is broadcast to all depth + 1
        levels; a sequence must have exactly depth + 1 entries.
    :param strides: kernel size and strides of the down-sampling blocks.
    :param padding: padding mode for all conv layers.
    :return: list of depth + 1 tensor shapes starting from d = 0; entry
        d + 1 is ``image_size`` after d + 1 down-samplings. Only the
        down-sampling blocks are accounted for (the encoding conv blocks are
        presumably shape-preserving for the padding in use — TODO confirm).
    :raises ValueError: if ``encode_kernel_sizes`` is a sequence whose
        length is not depth + 1.
    """
    if isinstance(encode_kernel_sizes, int):
        encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if len(encode_kernel_sizes) != depth + 1:
        raise ValueError(
            f"encode_kernel_sizes must have depth + 1 = {depth + 1} "
            f"elements, got {len(encode_kernel_sizes)}."
        )

    # encoding / down-sampling
    self._encode_convs = []
    self._encode_pools = []
    tensor_shape = image_size
    tensor_shapes = [tensor_shape]
    for d in range(depth):
        encode_conv = self.build_encode_conv_block(
            filters=num_channels[d],
            kernel_size=encode_kernel_sizes[d],
            padding=padding,
        )
        # down-sampling uses ``strides`` as both its kernel size and stride
        encode_pool = self.build_down_sampling_block(
            filters=num_channels[d],
            kernel_size=strides,
            strides=strides,
            padding=padding,
        )
        # track the shape produced by the down-sampling block above
        tensor_shape = tuple(
            conv_utils.conv_output_length(
                input_length=x,
                filter_size=strides,
                padding=padding,
                stride=strides,
                dilation=1,
            )
            for x in tensor_shape
        )
        self._encode_convs.append(encode_conv)
        self._encode_pools.append(encode_pool)
        tensor_shapes.append(tensor_shape)

    # bottom layer
    self._bottom_block = self.build_bottom_block(
        filters=num_channels[depth],
        kernel_size=encode_kernel_sizes[depth],
        padding=padding,
    )
    return tensor_shapes
|
| 453 | ||
| 454 | def build_decode_layers( |
|
| 455 | self, |
|
| @@ 384-446 (lines=63) @@ | ||
| 381 | out_activation=out_activation, |
|
| 382 | ) |
|
| 383 | ||
def build_encode_layers(
    self,
    image_size: Tuple,
    num_channels: Tuple,
    depth: int,
    encode_kernel_sizes: Union[int, List[int]],
    strides: int,
    padding: str,
) -> List[Tuple]:
    """
    Build layers for encoding.

    For every level d in [0, depth) an encoding conv block and a
    down-sampling block are built and stored in ``self._encode_convs`` and
    ``self._encode_pools``; the bottom block (d = depth) is stored in
    ``self._bottom_block``.

    :param image_size: (dim1, dim2, dim3).
    :param num_channels: number of channels for each layer,
        starting from the top layer. Indexed for d = 0 .. depth, so it must
        provide at least depth + 1 entries.
    :param depth: network starts with d = 0, and the bottom has d = depth.
    :param encode_kernel_sizes: kernel size(s) of the encoding conv blocks
        and of the bottom block. An int is broadcast to all depth + 1
        levels; a sequence must have exactly depth + 1 entries.
    :param strides: kernel size and strides of the down-sampling blocks.
    :param padding: padding mode for all conv layers.
    :return: list of depth + 1 tensor shapes starting from d = 0; entry
        d + 1 is ``image_size`` after d + 1 down-samplings. Only the
        down-sampling blocks are accounted for (the encoding conv blocks are
        presumably shape-preserving for the padding in use — TODO confirm).
    :raises ValueError: if ``encode_kernel_sizes`` is a sequence whose
        length is not depth + 1.
    """
    if isinstance(encode_kernel_sizes, int):
        encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if len(encode_kernel_sizes) != depth + 1:
        raise ValueError(
            f"encode_kernel_sizes must have depth + 1 = {depth + 1} "
            f"elements, got {len(encode_kernel_sizes)}."
        )

    # encoding / down-sampling
    self._encode_convs = []
    self._encode_pools = []
    tensor_shape = image_size
    tensor_shapes = [tensor_shape]
    for d in range(depth):
        encode_conv = self.build_encode_conv_block(
            filters=num_channels[d],
            kernel_size=encode_kernel_sizes[d],
            padding=padding,
        )
        # down-sampling uses ``strides`` as both its kernel size and stride
        encode_pool = self.build_down_sampling_block(
            filters=num_channels[d],
            kernel_size=strides,
            strides=strides,
            padding=padding,
        )
        # track the shape produced by the down-sampling block above
        tensor_shape = tuple(
            conv_utils.conv_output_length(
                input_length=x,
                filter_size=strides,
                padding=padding,
                stride=strides,
                dilation=1,
            )
            for x in tensor_shape
        )
        self._encode_convs.append(encode_conv)
        self._encode_pools.append(encode_pool)
        tensor_shapes.append(tensor_shape)

    # bottom layer
    self._bottom_block = self.build_bottom_block(
        filters=num_channels[depth],
        kernel_size=encode_kernel_sizes[depth],
        padding=padding,
    )
    return tensor_shapes
|
| 447 | ||
| 448 | def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: |
|
| 449 | """ |
|