| Metric | Value |
| --- | --- |
| Conditions | 6 |
| Total Lines | 119 |
| Code Lines | 74 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.
Commonly applied refactorings include Extract Method; if many parameters or temporary variables get in the way of extracting, Replace Temp with Query, Introduce Parameter Object, or Preserve Whole Object can help. A minimal Extract Method sketch is shown below.
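As a minimal sketch of Extract Method (the names `report` and `compute_mean` are invented for illustration and are not taken from the code discussed later):

```python
# Before: a comment marks a block that wants to be its own method.
def report(samples):
    # compute the mean of the samples
    total = 0.0
    for s in samples:
        total += s
    mean = total / len(samples)
    print(f"mean={mean:.3f}")


# After: the commented block is extracted and the comment becomes the name.
def compute_mean(samples):
    return sum(samples) / len(samples)


def report(samples):
    print(f"mean={compute_mean(samples):.3f}")
```

The caller-facing behaviour is unchanged; the gain is that `report` now reads as a sequence of named steps.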
Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent over time, whenever you need more or different data.
There are several approaches to avoiding long parameter lists, such as grouping related parameters into a parameter object or passing a whole object that already holds the data instead of its individual fields. The function below illustrates the problem; a sketch of the parameter-object approach follows it.
```python
# coding=utf-8

# ... imports and module-level code are omitted in this excerpt ...


def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_cpus: int = -1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train.
    :param config_path: path to configuration set up.
    :param ckpt_path: where to store training checkpoints.
    :param num_cpus: number of cpus to be used, -1 means not limited.
    :param gpu_allow_growth: whether to allocate whole GPU memory for training.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_cpus > 0:
        # Maximum number of threads to use for OpenMP parallel regions.
        os.environ["OMP_NUM_THREADS"] = str(num_cpus)
        # Without setting below 2 environment variables, it didn't work for me. Thanks to @cjw85
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_cpus)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_cpus)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:
        strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=config["train"]["preprocess"]["batch_size"],
                config=config["train"],
                num_devices=num_devices,
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

    # compile
    model.compile(optimizer=optimizer)
    model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
```