Small methods make your code easier to understand, especially when combined with a good name. Moreover, if a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.
Commonly applied refactorings include Extract Method; a minimal sketch of it follows.
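As a minimal, hypothetical sketch (the function names below are invented for illustration and do not come from the code under review), the commented part of a long method becomes a separate, well-named method:

```python
# Before: the comment marks a concept buried inside a longer method.
def report(prices):
    # drop prices that exceed three times the average
    avg = sum(prices) / len(prices)
    prices = [p for p in prices if p <= 3 * avg]
    return f"{len(prices)} items, total {sum(prices):.2f}"


# After: the commented lines are extracted; the comment turns into the name.
def drop_outliers(prices):
    """Remove prices that exceed three times the average."""
    avg = sum(prices) / len(prices)
    return [p for p in prices if p <= 3 * avg]


def report(prices):  # noqa: F811 (redefined only to show the "after" version)
    prices = drop_outliers(prices)
    return f"{len(prices)} items, total {sum(prices):.2f}"
```

The behaviour is unchanged; the difference is that the intent now has a name, and `report` reads as a short sequence of steps.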
If many parameters or temporary variables are present:

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent whenever they need more, or different, data.

There are several approaches to avoiding long parameter lists, such as grouping related arguments into a single parameter object (a sketch of that refactoring follows the listing below). The `train` function shown here, which takes eight parameters, is the code under review:
```python
# coding=utf-8
# ... (module imports and helper definitions between the file header and this
#      function are omitted in the excerpt) ...


def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train.
    :param config_path: path(s) to the configuration file(s).
    :param ckpt_path: where to store training checkpoints.
    :param num_workers: number of CPU cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to allocate GPU memory on demand
        instead of reserving it all up front.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    batch_size = config["train"]["preprocess"]["batch_size"]
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["train"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
```
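One way to shrink the signature above, sketched here under the assumption that the body of the function stays as it is, is the Introduce Parameter Object refactoring: the eight arguments are grouped into a small options object, so callers construct one value and the function receives one parameter. The `TrainOptions` class and the simplified `train` signature below are hypothetical, not part of the project's API:

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class TrainOptions:
    """Hypothetical parameter object grouping the arguments of train()."""

    gpu: str
    config_path: Union[str, List[str]]
    ckpt_path: str
    num_workers: int = 1
    gpu_allow_growth: bool = True
    exp_name: str = ""
    log_dir: str = "logs"
    max_epochs: int = -1


def train(options: TrainOptions) -> None:
    """Train a model; every setting is read from a single options object."""
    # The body would use options.gpu, options.config_path, etc. instead of
    # eight separate parameters; it is elided here.
    ...


# Usage: related settings are created, defaulted and validated in one place.
opts = TrainOptions(gpu="0", config_path="config.yaml", ckpt_path="")
train(opts)
```

Adding a new setting later then means adding one field with a default value to `TrainOptions`, instead of threading yet another positional parameter through every caller.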