| @@ 529-561 (lines=33) @@ | ||
| 526 | out_activation=out_activation, |
|
| 527 | ) |
|
| 528 | ||
def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
    """
    Run the forward pass: encoder, bottom block, decoder, then output block.

    :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
    :param training: None or bool, forwarded to every encoder/bottom/decoder
        sub-layer call.
    :param mask: None or tf.Tensor; not referenced by this method.
    :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
    """
    # Encoder: for every level, keep the pre-pooling features so the
    # decoder can merge them back in through the skip connection.
    skip_features = []
    x = inputs
    for level in range(self._depth):
        features = self._encode_convs[level](inputs=x, training=training)
        x = self._encode_pools[level](inputs=features, training=training)
        skip_features.append(features)

    # Bottom of the "U".
    x = self._bottom_block(inputs=x, training=training)  # type: ignore

    # Decoder: walk back up from level depth-1 to the shallowest extracted
    # level. level_outputs ends up ordered shallowest-first, with the bottom
    # block output as its last element.
    level_outputs = [x]
    lowest_level = min(self._extract_levels)
    for level in reversed(range(lowest_level, self._depth)):
        x = self._decode_deconvs[level](inputs=x, training=training)
        # NOTE(review): build_skip_block() is invoked on every decoder level
        # of every forward pass; if it constructs a new Keras layer each time,
        # consider creating it once in __init__/build instead — confirm.
        x = self.build_skip_block()([x, skip_features[level]])
        x = self._decode_convs[level](inputs=x, training=training)
        level_outputs.insert(0, x)

    return self._output_block(level_outputs)  # type: ignore
|
| 562 | ||
| 563 | def get_config(self) -> dict: |
|
| 564 | """Return the config dictionary for recreating this class.""" |
|
| @@ 156-188 (lines=33) @@ | ||
| 153 | self.depth_divisor = depth_divisor |
|
| 154 | self.activation_fn = tf.nn.swish |
|
| 155 | ||
def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
    """
    Run the forward pass: encoder, EfficientNet bottom, decoder, output block.

    :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
    :param training: None or bool, forwarded to every encoder/bottom/decoder
        sub-layer call.
    :param mask: None or tf.Tensor; not referenced by this method.
    :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
    """
    # Encoder: for every level, keep the pre-pooling features so the
    # decoder can merge them back in through the skip connection.
    skip_features = []
    x = inputs
    for level in range(self._depth):
        features = self._encode_convs[level](inputs=x, training=training)
        x = self._encode_pools[level](inputs=features, training=training)
        skip_features.append(features)

    # Bottom of the "U" is produced by build_efficient_net.
    # NOTE(review): this helper is invoked from call() on every forward pass;
    # if it instantiates layers, they are rebuilt per call — confirm.
    x = self.build_efficient_net(inputs=x, training=training)  # type: ignore

    # Decoder: walk back up from level depth-1 to the shallowest extracted
    # level. level_outputs ends up ordered shallowest-first, with the bottom
    # output as its last element.
    level_outputs = [x]
    lowest_level = min(self._extract_levels)
    for level in reversed(range(lowest_level, self._depth)):
        x = self._decode_deconvs[level](inputs=x, training=training)
        # NOTE(review): build_skip_block() is invoked on every decoder level
        # of every forward pass; if it constructs a new Keras layer each time,
        # consider creating it once in __init__/build instead — confirm.
        x = self.build_skip_block()([x, skip_features[level]])
        x = self._decode_convs[level](inputs=x, training=training)
        level_outputs.insert(0, x)

    return self._output_block(level_outputs)  # type: ignore
|
| 189 | ||
| 190 | ||
| 191 | def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor: |
|