Completed
Push on master (1d3564...feb62e) by unknown, created 13:32

train_models_on_samples()   A

Complexity

Conditions 2

Size

Total Lines 51

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 0
CRAP Score 6

Importance

Changes 2
Bugs 0 Features 1
Metric Value
cc 2
c 2
b 0
f 1
dl 0
loc 51
ccs 0
cts 12
cp 0
crap 6
rs 9.4109
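The CRAP score above matches the standard CRAP (Change Risk Anti-Patterns) formula, CRAP = cc² × (1 − coverage)³ + cc. With cc = 2 and 0 % coverage, that gives 2² × 1³ + 2 = 6. This page does not state that the tool uses exactly this formula, so treat the following as a sketch. `crap_score` is a hypothetical helper name:

```python
def crap_score(complexity: int, coverage: float) -> float:
    """CRAP score for one method.

    complexity: cyclomatic complexity (the "cc" metric).
    coverage: fraction of the method covered by tests, from 0.0 to 1.0.
    """
    if not 0.0 <= coverage <= 1.0:
        raise ValueError("coverage must be a fraction between 0 and 1")
    # Uncovered code is penalised by the square of the complexity times the
    # cube of the uncovered fraction; full coverage leaves just the complexity.
    return complexity ** 2 * (1.0 - coverage) ** 3 + complexity


print(crap_score(2, 0.0))  # 6.0
print(crap_score(2, 1.0))  # 2.0
```

With no tests, the score is cc² + cc. Tests that fully cover the method bring it down to cc, so adding tests is the cheapest way to lower it here.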

How to fix: Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign the commented part should be extracted into a new method. The comment then makes a good starting point when choosing a name for this new method.
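To illustrate this advice, the commented subset-slicing step and the per-model training step of train_models_on_samples in the listing below could be pulled out into helpers named after what they do. The helper names `take_training_subset` and `train_one_model` are hypothetical. The sketch assumes only what the listing assumes: each model has a Keras-style `fit` method that returns an object whose `history` dict holds `'val_acc'` and `'val_loss'` lists. It passes `epochs=`, the Keras 2 argument name.

```python
def take_training_subset(X, y, subset_size):
    """Return the first subset_size samples of X and y.

    Slicing never fails: if subset_size exceeds the number of samples,
    the complete data set is returned.
    """
    return X[:subset_size], y[:subset_size]


def train_one_model(model, X, y, validation_data, nr_epochs, verbose):
    """Fit one model; return its history and final validation metrics."""
    history = model.fit(X, y, epochs=nr_epochs, batch_size=20,
                        validation_data=validation_data, verbose=verbose)
    return (history,
            history.history['val_acc'][-1],
            history.history['val_loss'][-1])


def train_models_on_samples(X_train, y_train, X_val, y_val, models,
                            nr_epochs=5, subsize_set=100, verbose=True):
    """Same inputs and outputs as the listed version, as a short loop."""
    X_sub, y_sub = take_training_subset(X_train, y_train, subsize_set)
    histories, val_accuracies, val_losses = [], [], []
    for model, _params, _model_type in models:
        history, val_acc, val_loss = train_one_model(
            model, X_sub, y_sub, (X_val, y_val), nr_epochs, verbose)
        histories.append(history)
        val_accuracies.append(val_acc)
        val_losses.append(val_loss)
    return histories, val_accuracies, val_losses
```

Each helper can now be read, named and tested on its own. The docstring of `take_training_subset` replaces the inline comment it was derived from.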

Commonly applied refactorings include:

'''
Summary:
Function generate_models from modelgen.py generates and compiles models.
Function train_models_on_samples trains those models.
Function plotTrainingProcess plots the training process.
Function find_best_architecture is a wrapper function that combines
these steps.
Function kNN_accuracy computes a k-nearest-neighbours baseline accuracy.
Example function calls are in 'EvaluateDifferentModels.ipynb'.
'''
import warnings

import numpy as np
from matplotlib import pyplot as plt
from sklearn import neighbors, metrics

from . import modelgen


# Coding Style Naming: The names X_train and X_val do not conform to the
# argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
def train_models_on_samples(X_train, y_train, X_val, y_val, models,
                            nr_epochs=5, subsize_set=100, verbose=True):
    '''
    Given a list of compiled models, this function trains
    them all on a subset of the train data. If the given size of the subset is
    larger than the size of the data, the complete data set is used.

    Parameters
    ----------
    X_train : numpy array of shape (num_samples, num_timesteps, num_channels)
        The input dataset for training
    y_train : numpy array of shape (num_samples, num_classes)
        The output classes for the train data, in binary format
    X_val : numpy array of shape (num_samples_val, num_timesteps, num_channels)
        The input dataset for validation
    y_val : numpy array of shape (num_samples_val, num_classes)
        The output classes for the validation data, in binary format
    models : list of (model, params, model_type) tuples
        List of compiled Keras models with their hyperparameters and types
    nr_epochs : int, optional
        Number of epochs to use for training each model
    subsize_set : int, optional
        Number of training samples to use for training these models
    verbose : bool, optional
        Flag for displaying verbose output

    Returns
    -------
    histories : list of Keras History objects
        Training histories for all models
    val_accuracies : list of floats
        Validation accuracies of the models after the last epoch
    val_losses : list of floats
        Validation losses of the models after the last epoch
    '''
    # Slicing is safe when subsize_set is larger than X_train:
    # the complete data set is then used.
    # Coding Style Naming: The name X_train_sub does not conform to the
    # variable naming conventions ([a-z_][a-z0-9_]{1,30}$).
    X_train_sub = X_train[:subsize_set, :, :]
    y_train_sub = y_train[:subsize_set, :]

    histories = []
    val_accuracies = []
    val_losses = []
    for model, _params, _model_type in models:
        # batch_size=20: see comment on subsize_set
        history = model.fit(X_train_sub, y_train_sub,
                            epochs=nr_epochs, batch_size=20,
                            validation_data=(X_val, y_val),
                            verbose=verbose)
        histories.append(history)
        val_accuracies.append(history.history['val_acc'][-1])
        val_losses.append(history.history['val_loss'][-1])

    return histories, val_accuracies, val_losses


# Coding Style Naming: The name plotTrainingProcess does not conform to the
# function naming conventions ([a-z_][a-z0-9_]{2,30}$).
def plotTrainingProcess(history, name='Model', ax=None):
    '''
    This function plots the loss and accuracy on the train and validation set
    for each epoch in the history of one model.

    Parameters
    ----------
    history : keras History object for one model
        The history object of the training process corresponding to one model
    name : str, optional
        Title of the plot
    ax : matplotlib Axes, optional
        Axes to draw on; if None, a new figure and axes are created

    Returns
    -------
    None
    '''
    if ax is None:
        _, ax = plt.subplots()
    # Accuracy is drawn on a second y-axis that shares the epoch axis
    ax2 = ax.twinx()
    # Coding Style Naming: The name LN does not conform to the variable
    # naming conventions ([a-z_][a-z0-9_]{1,30}$).
    LN = len(history.history['val_loss'])
    val_loss, = ax.plot(range(LN), history.history['val_loss'], 'g--',
                        label='validation loss')
    train_loss, = ax.plot(range(LN), history.history['loss'], 'g-',
                          label='train loss')
    val_acc, = ax2.plot(range(LN), history.history['val_acc'], 'b--',
                        label='validation accuracy')
    train_acc, = ax2.plot(range(LN), history.history['acc'], 'b-',
                          label='train accuracy')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss', color='g')
    ax2.set_ylabel('accuracy', color='b')
    # Draw on the given axes rather than on pyplot's "current" axes
    ax2.legend(handles=[val_loss, train_loss, val_acc, train_acc],
               loc=2, bbox_to_anchor=(1.1, 1))
    ax.set_title(name)


# Coding Style Naming: The names X_train and X_val do not conform to the
# argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
def find_best_architecture(X_train, y_train, X_val, y_val, verbose=True,
                           number_of_models=5, nr_epochs=5, **kwargs):
    '''
    Tries out a number of models on a subsample of the data,
    and outputs the best found architecture and hyperparameters.

    Parameters
    ----------
    X_train : numpy array of shape (num_samples, num_timesteps, num_channels)
        The input dataset for training
    y_train : numpy array of shape (num_samples, num_classes)
        The output classes for the train data, in binary format
    X_val : numpy array of shape (num_samples_val, num_timesteps, num_channels)
        The input dataset for validation
    y_val : numpy array of shape (num_samples_val, num_classes)
        The output classes for the validation data, in binary format
    verbose : bool, optional
        Flag for displaying verbose output
    number_of_models : int, optional
        Number of models to generate and compare
    nr_epochs : int, optional
        Number of epochs to use for training each model
    **kwargs : key-value parameters
        Parameters for generating the models

    Returns
    -------
    best_model : Keras model
        Best performing model, already trained on a small sample data set.
    best_params : dict
        Dictionary containing the hyperparameters for the best model
    best_model_type : str
        Type of the best model
    knn_acc : float
        Accuracy of kNN prediction on the validation set
    '''
    models = modelgen.generate_models(X_train.shape, y_train.shape[1],
                                      number_of_models=number_of_models,
                                      **kwargs)
    subsize_set = 100
    histories, val_accuracies, val_losses = train_models_on_samples(
        X_train, y_train, X_val, y_val, models, nr_epochs,
        subsize_set=subsize_set, verbose=verbose)
    best_model_index = np.argmax(val_accuracies)
    best_model, best_params, best_model_type = models[best_model_index]
    knn_acc = kNN_accuracy(X_train[:subsize_set, :, :],
                           y_train[:subsize_set, :], X_val, y_val)
    if verbose:
        # For now one plot per model; ultimately we may want all models
        # in one plot to allow for direct comparison
        for history, (_, params, _) in zip(histories, models):
            plotTrainingProcess(history, str(params))
        print('Best model: model ', best_model_index)
        print('Model type: ', best_model_type)
        print('Hyperparameters: ', best_params)
        print('Accuracy on validation set: ', val_accuracies[best_model_index])
        print('Accuracy of kNN on validation set', knn_acc)

    if val_accuracies[best_model_index] < knn_acc:
        warnings.warn('Best model not better than kNN: ' +
                      str(val_accuracies[best_model_index]) + ' vs ' +
                      str(knn_acc))
    return best_model, best_params, best_model_type, knn_acc


# Coding Style Naming: The name kNN_accuracy does not conform to the function
# naming conventions ([a-z_][a-z0-9_]{2,30}$). The names X_train and X_val do
# not conform to the argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
def kNN_accuracy(X_train, y_train, X_val, y_val, k=1):
    '''
    Accuracy on the validation set of a k-nearest-neighbours classifier,
    used as a baseline. Each sample is flattened from
    (num_timesteps, num_channels) into a single feature vector.
    '''
    num_samples, num_timesteps, num_channels = X_train.shape
    clf = neighbors.KNeighborsClassifier(k)
    clf.fit(X_train.reshape(num_samples, num_timesteps*num_channels), y_train)
    num_samples, num_timesteps, num_channels = X_val.shape
    val_predict = clf.predict(
        X_val.reshape(num_samples, num_timesteps*num_channels))
    # accuracy_score expects (y_true, y_pred)
    return metrics.accuracy_score(y_val, val_predict)
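The selection step at the end of find_best_architecture is an argmax over the final validation accuracies, followed by a comparison with the kNN baseline. The following is a stripped-down sketch of just that step. `select_best` is a hypothetical name, and stub tuples stand in for Keras models. The accuracies in the usage lines are made-up illustration values, not results:

```python
import warnings


def select_best(models, val_accuracies, knn_acc):
    """Pick the (model, params, model_type) entry with the highest final
    validation accuracy, and warn if it does not beat the kNN baseline."""
    # max() returns the first maximal index, so ties resolve like np.argmax
    best_index = max(range(len(val_accuracies)),
                     key=val_accuracies.__getitem__)
    best_acc = val_accuracies[best_index]
    if best_acc < knn_acc:
        warnings.warn('Best model not better than kNN: '
                      f'{best_acc} vs {knn_acc}')
    return best_index, models[best_index]


# Usage with placeholder candidates (strings instead of compiled models)
candidates = [('model_0', {'filters': 8}, 'CNN'),
              ('model_1', {'filters': 16}, 'LSTM')]
idx, (model, params, model_type) = select_best(candidates, [0.62, 0.71],
                                               knn_acc=0.55)
print(idx, model_type)  # 1 LSTM
```

The kNN comparison is a sanity check rather than a selection criterion. A deep model that cannot beat 1-nearest-neighbour on the same 100-sample subset is worth a second look before scaling up training.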