| Metric | Value |
| --- | --- |
| Conditions | 7 |
| Total Lines | 125 |
| Code Lines | 70 |
| Lines | 13 |
| Ratio | 10.4 % |
| Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
If many parameters/temporary variables are present:
Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.
There are several approaches to avoid long parameter lists:
# coding=utf-8
def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, to define which split of dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if True, allocate GPU memory incrementally as needed
        instead of grabbing the whole GPU memory up front.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """
    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # FIX: was `"false" if gpu_allow_growth else "true"`, which inverted the flag —
    # gpu_allow_growth=True must set TF_FORCE_GPU_ALLOW_GROWTH="true" so TF
    # allocates memory incrementally rather than claiming the whole GPU.
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # FIX: was `if num_workers <= 0`, which contradicted both the docstring
    # ("<=0 means not limited") and the log message below ("Setting to 0 or
    # negative values will remove the limitation") — limit CPU usage only for
    # positive num_workers.
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config; build_config may rewrite log_dir / ckpt_path based on exp_name
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data: evaluation split selected by `mode`, no shuffling/augmentation/repeat
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        # MirroredStrategy splits each batch across devices, so the global
        # batch size must be divisible by the device count
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback; restoring happens inside the
        # builder, the returned callbacks are not needed here
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()