| Conditions | 19 |
| Total Lines | 116 |
| Lines | 0 |
| Ratio | 0 % |
| Tests | 0 |
| CRAP Score | 380 |
| Changes | 6 |
| Bugs | 1 |
| Features | 0 |
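The CRAP score above is consistent with the C.R.A.P. (Change Risk Anti-Patterns) formula of Savoia and Martin: comp(m)² · (1 − cov(m))³ + comp(m), where comp is the cyclomatic complexity (the 19 conditions in the table) and cov the test coverage as a fraction. With the 0 % coverage reported here, this reduces to 19² + 19 = 361 + 19 = 380, exactly the value shown. A minimal sketch of the calculation:

def crap_score(complexity, coverage):
    # C.R.A.P. (Savoia/Martin): comp^2 * (1 - cov)^3 + comp,
    # with coverage given as a fraction in [0, 1].
    return complexity ** 2 * (1 - coverage) ** 3 + complexity

assert crap_score(19, 0.0) == 380  # the values from the table above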
Small methods make your code easier to understand, especially when combined with a good name; and when a method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment as a starting point for its name.
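A hypothetical, minimal sketch of that comment-to-name refactoring (the names are illustrative, not taken from the code below):

# Before: a comment labels a block inside a longer method.
def build_report(orders):
    # keep only completed orders
    completed = [o for o in orders if o['status'] == 'completed']
    return len(completed)

# After: the commented block becomes its own method, and the comment
# becomes the starting point for its name.
def keep_only_completed(orders):
    return [o for o in orders if o['status'] == 'completed']

def build_report(orders):
    return len(keep_only_completed(orders))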
Commonly applied refactorings include:

- Extract Method
- Replace Method with Method Object, if many parameters/temporary variables are present (as sketched below)
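As a hypothetical sketch of Replace Method with Method Object, the many parameters of a function like the one flagged below can become fields of a dedicated object (class and method names here are mine, not the project's):

class ModelTrainer:
    # Parameters that travelled together as arguments become fields.
    def __init__(self, X_train, y_train, X_val, y_val, nr_epochs=5,
                 subset_size=100, batch_size=20, verbose=True):
        self.X_train = X_train[:subset_size]
        self.y_train = y_train[:subset_size]
        self.X_val = X_val
        self.y_val = y_val
        self.nr_epochs = nr_epochs
        self.batch_size = batch_size
        self.verbose = verbose

    def train_one(self, model):
        # Each small method no longer needs a long parameter list.
        return model.fit(self.X_train, self.y_train,
                         epochs=self.nr_epochs, batch_size=self.batch_size,
                         validation_data=(self.X_val, self.y_val),
                         verbose=self.verbose)

    def train_all(self, models):
        return [self.train_one(model) for model, _params, _type in models]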
Complex methods like train_models_on_samples() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields, variables, or helper calls that share the same prefixes or suffixes.
Once you have determined the parts that belong together, you can apply the Extract Class refactoring to move them into a class of their own. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster. A sketch follows below.
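In the listing below, for instance, val_metrics and val_losses share the val_ prefix and always travel together; a hypothetical Extract Class could group them into one result object (again, the names are illustrative):

class ValidationResults:
    # Groups the val_-prefixed values that always travel together.
    def __init__(self):
        self.metrics = []   # per-model, per-epoch validation metric values
        self.losses = []    # per-model, per-epoch validation losses

    def add(self, history, metric_name):
        self.metrics.append(history['val_' + metric_name])
        self.losses.append(history['val_loss'])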
1 | """ |
||
53 | def train_models_on_samples(X_train, y_train, X_val, y_val, models, |
||
54 | nr_epochs=5, subset_size=100, verbose=True, outputfile=None, |
||
|
|||
55 | model_path=None, early_stopping=False, |
||
56 | batch_size=20, metric='accuracy', use_noodles=None): |
||
57 | """ |
||
58 | Given a list of compiled models, this function trains |
||
59 | them all on a subset of the train data. If the given size of the subset is |
||
60 | smaller then the size of the data, the complete data set is used. |
||
61 | |||
62 | Parameters |
||
63 | ---------- |
||
64 | X_train : numpy array of shape (num_samples, num_timesteps, num_channels) |
||
65 | The input dataset for training |
||
66 | y_train : numpy array of shape (num_samples, num_classes) |
||
67 | The output classes for the train data, in binary format |
||
68 | X_val : numpy array of shape (num_samples_val, num_timesteps, num_channels) |
||
69 | The input dataset for validation |
||
70 | y_val : numpy array of shape (num_samples_val, num_classes) |
||
71 | The output classes for the validation data, in binary format |
||
72 | models : list of model, params, modeltypes |
||
73 | List of keras models to train |
||
74 | nr_epochs : int, optional |
||
75 | nr of epochs to use for training one model |
||
76 | subset_size : |
||
77 | The number of samples used from the complete train set |
||
78 | verbose : bool, optional |
||
79 | flag for displaying verbose output |
||
80 | outputfile: str, optional |
||
81 | Filename to store the model training results |
||
82 | model_path : str, optional |
||
83 | Directory to store the models as HDF5 files |
||
84 | early_stopping: bool |
||
85 | Stop when validation loss does not decrease |
||
86 | batch_size : int |
||
87 | nr of samples per batch |
||
88 | metric : str |
||
89 | metric to store in the history object |
||
90 | |||
91 | Returns |
||
92 | ---------- |
||
93 | histories : list of Keras History objects |
||
94 | train histories for all models |
||
95 | val_metrics : list of floats |
||
96 | validation accuraracies of the models |
||
97 | val_losses : list of floats |
||
98 | validation losses of the models |
||
99 | """ |
||
100 | # if subset_size is smaller then X_train, this will work fine |
||
101 | X_train_sub = X_train[:subset_size, :, :] |
||
102 | y_train_sub = y_train[:subset_size, :] |
||
103 | |||
104 | metric_name = get_metric_name(metric) |
||
105 | |||
106 | val_metrics = [] |
||
107 | val_losses = [] |
||
108 | |||
109 | def make_history(model, i=None): |
||
110 | model_metrics = [get_metric_name(name) for name in model.metrics] |
||
111 | if metric_name not in model_metrics: |
||
112 | raise ValueError( |
||
113 | 'Invalid metric. The model was not compiled with {} as metric'.format(metric_name)) |
||
114 | if early_stopping: |
||
115 | callbacks = [ |
||
116 | EarlyStopping(monitor='val_loss', patience=0, verbose=verbose, mode='auto')] |
||
117 | else: |
||
118 | callbacks = [] |
||
119 | |||
120 | args = (model, X_train_sub, y_train_sub) |
||
121 | kwargs = {'epochs': nr_epochs, |
||
122 | 'batch_size': batch_size, |
||
123 | 'validation_data': (X_val, y_val), |
||
124 | 'verbose': verbose, |
||
125 | 'callbacks': callbacks} |
||
126 | |||
127 | if use_noodles is None: |
||
128 | # if not using noodles, save every nugget when it comes |
||
129 | trained_model = train_model(*args, **kwargs) |
||
130 | if outputfile is not None: |
||
131 | store_train_hist_as_json(models[i][1], models[i][2], |
||
132 | trained_model.history, outputfile) |
||
133 | if model_path is not None: |
||
134 | trained_model.save( |
||
135 | os.path.join(model_path, 'model_{}.h5'.format(i))) |
||
136 | return trained_model |
||
137 | |||
138 | else: |
||
139 | assert has_noodles, "Noodles is not installed, or could not be imported." |
||
140 | return noodles.schedule_hint(call_by_ref=['model']) \ |
||
141 | (train_model)(*args, **kwargs) |
||
142 | |||
143 | if use_noodles is None: |
||
144 | trained_models = [ |
||
145 | make_history(model[0], i) |
||
146 | for i, model in enumerate(models)] |
||
147 | |||
148 | else: |
||
149 | assert has_noodles, "Noodles is not installed, or could not be imported." |
||
150 | |||
151 | # in case of noodles, first run everything |
||
152 | training_wf = noodles.gather_all([make_history(model[0]) for model in models]) |
||
153 | trained_models = use_noodles(training_wf) |
||
154 | |||
155 | # then save everything |
||
156 | for i, (history, model) in enumerate(trained_models): |
||
157 | if outputfile is not None: |
||
158 | store_train_hist_as_json(models[i][1], models[i][2], |
||
159 | history, outputfile) |
||
160 | if model_path is not None: |
||
161 | model.save(os.path.join(model_path, 'model_{}.h5'.format(i))) |
||
162 | |||
163 | # accumulate results |
||
164 | val_metrics = [tm.history['val_' + metric_name] |
||
165 | for tm in trained_models] |
||
166 | val_losses = [tm.history['val_loss'] |
||
167 | for tm in trained_models] |
||
168 | return [tm.history for tm in trained_models], val_metrics, val_losses |
||
169 | |||
357 |
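For orientation, a call to the function above might look like the following sketch; the input variables are assumed to be shaped as in the docstring, and models to be a list of (model, params, modeltype) tuples as the docstring describes:

import numpy as np

histories, val_metrics, val_losses = train_models_on_samples(
    X_train, y_train, X_val, y_val, models,
    nr_epochs=5, subset_size=100, early_stopping=True)

# val_metrics holds one per-epoch list per model; compare the final
# epochs to pick the best-performing architecture.
best_index = int(np.argmax([per_epoch[-1] for per_epoch in val_metrics]))
best_model = models[best_index][0]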
This check looks for lines that are too long. You can specify the maximum line length.
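What such a check does can be sketched in a few lines of Python (the 79-character default is an assumption; as noted above, the maximum is configurable):

MAX_LINE_LENGTH = 79  # assumed default; the check lets you override this

def find_long_lines(path, max_len=MAX_LINE_LENGTH):
    # Yields (line number, length) for every line over the limit.
    with open(path) as source:
        for lineno, line in enumerate(source, start=1):
            length = len(line.rstrip('\n'))
            if length > max_len:
                yield lineno, length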