@@ 529-561 (lines=33) @@ | ||
526 | out_activation=out_activation, |
|
527 | ) |
|
528 | ||
def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
    """
    Run the encoder / bottom / decoder forward pass with the already-built layers.

    Each encoder level applies its conv block and then its pooling block; the
    conv output (before pooling) is kept as that level's skip tensor. The
    deepest encoding goes through ``self._bottom_block``. The decoder then
    walks back up from level ``self._depth - 1`` to ``min(self._extract_levels)``
    (inclusive). At each level it up-samples, merges with the matching skip
    tensor via a layer returned by ``self.build_skip_block()``, and applies
    that level's decoder conv block. ``self._output_block`` receives a list of
    the decoded tensors ordered from the shallowest decoded level to the
    bottom-block output (last element).

    :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
    :param training: None or bool, forwarded to the encoder, bottom and
        decoder blocks.
    :param mask: None or tf.Tensor; not used by this method.
    :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
    """
    depth = self._depth

    # encoder / down-sampling: remember each level's pre-pooling features
    skip_tensors = []
    features = inputs
    for level in range(depth):
        conv_out = self._encode_convs[level](inputs=features, training=training)
        features = self._encode_pools[level](inputs=conv_out, training=training)
        skip_tensors.append(conv_out)

    # bottom
    features = self._bottom_block(inputs=features, training=training)  # type: ignore

    # decoder / up-sampling, deepest level first; results are gathered
    # deepest-first here and reversed before being handed to the output block
    shallowest = min(self._extract_levels)
    collected = [features]
    for level in reversed(range(shallowest, depth)):
        upsampled = self._decode_deconvs[level](inputs=features, training=training)
        # NOTE(review): a new merge layer is requested on every level of every
        # forward pass; confirm build_skip_block() is weightless, or consider
        # creating it once in __init__/build.
        merge = self.build_skip_block()
        merged = merge([upsampled, skip_tensors[level]])
        features = self._decode_convs[level](inputs=merged, training=training)
        collected.append(features)

    # shallowest decoded level first, bottom-block output last
    return self._output_block(collected[::-1])  # type: ignore
|
562 | ||
563 | def get_config(self) -> dict: |
|
564 | """Return the config dictionary for recreating this class.""" |
@@ 156-188 (lines=33) @@ | ||
153 | self.depth_divisor = depth_divisor |
|
154 | self.activation_fn = tf.nn.swish |
|
155 | ||
def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
    """
    Run the encoder / EfficientNet bottom / decoder forward pass.

    Each encoder level applies its conv block and then its pooling block; the
    conv output (before pooling) is kept as that level's skip tensor. The
    deepest encoding is passed to ``self.build_efficient_net``. The decoder
    then walks back up from level ``self._depth - 1`` to
    ``min(self._extract_levels)`` (inclusive). At each level it up-samples,
    merges with the matching skip tensor via a layer returned by
    ``self.build_skip_block()``, and applies that level's decoder conv block.
    ``self._output_block`` receives a list of the decoded tensors ordered from
    the shallowest decoded level to the bottom output (last element).

    :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
    :param training: None or bool, forwarded to the encoder, bottom and
        decoder blocks.
    :param mask: None or tf.Tensor; not used by this method.
    :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
    """
    depth = self._depth

    # encoder / down-sampling: remember each level's pre-pooling features
    skip_tensors = []
    features = inputs
    for level in range(depth):
        conv_out = self._encode_convs[level](inputs=features, training=training)
        features = self._encode_pools[level](inputs=conv_out, training=training)
        skip_tensors.append(conv_out)

    # bottom
    # NOTE(review): build_efficient_net runs on every forward pass; confirm it
    # reuses layers created once rather than building new weights per call.
    features = self.build_efficient_net(inputs=features, training=training)  # type: ignore

    # decoder / up-sampling, deepest level first; results are gathered
    # deepest-first here and reversed before being handed to the output block
    shallowest = min(self._extract_levels)
    collected = [features]
    for level in reversed(range(shallowest, depth)):
        upsampled = self._decode_deconvs[level](inputs=features, training=training)
        # a fresh merge layer is requested for every level on every call
        merge = self.build_skip_block()
        merged = merge([upsampled, skip_tensors[level]])
        features = self._decode_convs[level](inputs=merged, training=training)
        collected.append(features)

    # shallowest decoded level first, bottom output last
    return self._output_block(collected[::-1])  # type: ignore
|
189 | ||
190 | ||
191 | def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor: |