| Conditions | 19 |
| Total Lines | 116 |
| Lines | 0 |
| Ratio | 0 % |
| Tests | 23 |
| CRAP Score | 38.5596 |
| Changes | 6 |
| Bugs | 1 |
| Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
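Commonly applied refactorings include Extract Method.
If many parameters/temporary variables are present, consider Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.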
Complex functions like train_models_on_samples() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for variables, parameters, fields, or methods that share the same prefixes or suffixes.
Once you have determined the elements that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | """ |
||
def train_models_on_samples(X_train, y_train, X_val, y_val, models,
                            nr_epochs=5, subset_size=100, verbose=True,
                            outputfile=None, model_path=None,
                            early_stopping=False, batch_size=20,
                            metric='accuracy', use_noodles=None):
    """
    Given a list of compiled models, this function trains
    them all on a subset of the train data. If the given size of the subset is
    larger than the size of the data, the complete data set is used.

    Parameters
    ----------
    X_train : numpy array of shape (num_samples, num_timesteps, num_channels)
        The input dataset for training
    y_train : numpy array of shape (num_samples, num_classes)
        The output classes for the train data, in binary format
    X_val : numpy array of shape (num_samples_val, num_timesteps, num_channels)
        The input dataset for validation
    y_val : numpy array of shape (num_samples_val, num_classes)
        The output classes for the validation data, in binary format
    models : list of (model, params, modeltypes) tuples
        Compiled keras models to train; ``params`` and ``modeltypes`` are
        stored alongside each training history when `outputfile` is given
    nr_epochs : int, optional
        nr of epochs to use for training one model
    subset_size : int, optional
        The number of samples used from the complete train set
    verbose : bool, optional
        flag for displaying verbose output
    outputfile : str, optional
        Filename to store the model training results
    model_path : str, optional
        Directory to store the models as HDF5 files (``model_<i>.h5``)
    early_stopping : bool, optional
        Stop when validation loss does not decrease
    batch_size : int, optional
        nr of samples per batch
    metric : str, optional
        metric to store in the history object; every model must have been
        compiled with it
    use_noodles : callable, optional
        If None (default), models are trained one after the other and each
        result is stored as soon as it is available. Otherwise noodles must
        be importable, and this callable is invoked with a noodles workflow
        gathering all training jobs. Its return value is matched to `models`
        by index, and results are stored after all training has finished.

    Returns
    ----------
    histories : list
        The ``.history`` of each training result, one entry per model
    val_metrics : list
        ``history['val_' + metric_name]`` of each model (validation metric)
    val_losses : list
        ``history['val_loss']`` of each model (validation loss)

    Raises
    ------
    ValueError
        If a model was not compiled with `metric`.
    """
    # Slicing past the end is safe: if subset_size exceeds the number of
    # samples, the complete training set is used.
    X_train_sub = X_train[:subset_size, :, :]
    y_train_sub = y_train[:subset_size, :]

    metric_name = get_metric_name(metric)

    def start_training(model):
        """Validate `model`'s metric, then train it (or schedule its training)."""
        _check_compiled_metric(model, metric_name)
        args = (model, X_train_sub, y_train_sub)
        kwargs = _make_training_kwargs(nr_epochs, batch_size, X_val, y_val,
                                       verbose, early_stopping)
        if use_noodles is None:
            return train_model(*args, **kwargs)
        return noodles.schedule_hint(call_by_ref=['model']) \
            (train_model)(*args, **kwargs)

    if use_noodles is None:
        # Without noodles: train sequentially and save every result as soon
        # as it is ready.
        trained_models = []
        for i, entry in enumerate(models):
            trained_model = start_training(entry[0])
            _store_training_result(models, i, trained_model,
                                   outputfile, model_path)
            trained_models.append(trained_model)
    else:
        assert has_noodles, "Noodles is not installed, or could not be imported."

        # With noodles: first run the whole workflow ...
        training_wf = noodles.gather_all(
            [start_training(entry[0]) for entry in models])
        trained_models = use_noodles(training_wf)

        # ... then save everything. Each result is handled exactly like in the
        # sequential path; the accumulation below requires every result to
        # expose ``.history``, so results are not unpacked as 2-tuples.
        for i, trained_model in enumerate(trained_models):
            _store_training_result(models, i, trained_model,
                                   outputfile, model_path)

    # accumulate results
    histories = [tm.history for tm in trained_models]
    val_metrics = [history['val_' + metric_name] for history in histories]
    val_losses = [history['val_loss'] for history in histories]
    return histories, val_metrics, val_losses


def _check_compiled_metric(model, metric_name):
    """Raise ValueError unless `model` was compiled with `metric_name`."""
    model_metrics = [get_metric_name(name) for name in model.metrics]
    if metric_name not in model_metrics:
        raise ValueError(
            'Invalid metric. The model was not compiled with {} as metric'.format(metric_name))


def _make_training_kwargs(nr_epochs, batch_size, X_val, y_val, verbose,
                          early_stopping):
    """Build the keyword arguments passed to `train_model` for one model."""
    # A fresh callback list (and EarlyStopping instance) is built per model.
    callbacks = []
    if early_stopping:
        callbacks.append(EarlyStopping(monitor='val_loss', patience=0,
                                       verbose=verbose, mode='auto'))
    return {'epochs': nr_epochs,
            'batch_size': batch_size,
            'validation_data': (X_val, y_val),
            'verbose': verbose,
            'callbacks': callbacks}


def _store_training_result(models, i, trained_model, outputfile, model_path):
    """Write the i-th training result's history and/or model, if requested."""
    if outputfile is not None:
        store_train_hist_as_json(models[i][1], models[i][2],
                                 trained_model.history, outputfile)
    if model_path is not None:
        # NOTE(review): assumes the object returned by `train_model` has a
        # ``save`` method. If it is a Keras ``History`` (what ``model.fit``
        # returns), the trained model is reachable via ``.model`` instead —
        # confirm against `train_model`.
        trained_model.save(
            os.path.join(model_path, 'model_{}.h5'.format(i)))
| 357 |
This check looks for lines that are too long. You can specify the maximum line length.