Completed
Push: master (295812...1d3564), by unknown, created 12:22

train_models_on_samples()   A

Complexity

Conditions 2

Size

Total Lines 51

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 12
CRAP Score 2

Importance

Changes 2
Bugs 0 Features 1
Metric  Value
cc      2
c       2
b       0
f       1
dl      0
loc     51
ccs     12
cts     12
cp      1
crap    2
rs      9.4109
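
The CRAP score reported above can be reproduced from the cyclomatic complexity (cc) and the coverage ratio (cp). The following is a minimal sketch, assuming the commonly published CRAP formula, comp^2 * (1 - coverage)^3 + comp, and assuming that cp is coverage expressed as a fraction between 0 and 1. The function name crap_score is illustrative and not part of the report tooling.

```python
def crap_score(complexity, coverage):
    """CRAP = comp^2 * (1 - cov)^3 + comp, with coverage as a fraction in [0, 1]."""
    if not 0.0 <= coverage <= 1.0:
        raise ValueError("coverage must be a fraction between 0 and 1")
    return complexity ** 2 * (1.0 - coverage) ** 3 + complexity


# Values from the metric list: cc = 2, cp = 1 (fully covered)
print(crap_score(2, 1.0))  # 2.0
```

With full coverage the penalty term vanishes, so the score equals the complexity itself; with no coverage the same function would score 2^2 + 2 = 6.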

How to fix: Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:
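
One such refactoring is Extract Method. The sketch below applies it to the commented subset-selection step of train_models_on_samples, turning the comment into a function name and docstring. The helper name take_subset and the toy data are hypothetical; they do not appear in the reviewed code.

```python
import numpy as np


def take_subset(X, y, subsize_set):
    """Return the first subsize_set samples of X and y.

    Slicing past the end of a numpy array is safe: if subsize_set exceeds
    the number of samples, all samples are returned.
    """
    return X[:subsize_set], y[:subsize_set]


# Hypothetical data: 30 samples, 10 timesteps, 3 channels, 4 classes
X = np.zeros((30, 10, 3))
y = np.zeros((30, 4))

X_sub, y_sub = take_subset(X, y, 100)
print(X_sub.shape, y_sub.shape)  # (30, 10, 3) (30, 4)
```

The extracted helper documents the slicing behaviour in one place, so the calling function no longer needs an inline comment to explain it.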

'''
 Summary:
 Function generate_models from modelgen.py generates and compiles models
 Function train_models_on_samples trains those models
 Function plotTrainingProcess plots the training process
 Function find_best_architecture is a wrapper function that combines
 these steps
 Example function calls in 'EvaluateDifferentModels.ipynb'
'''
import numpy as np
from matplotlib import pyplot as plt
from . import modelgen
from sklearn import neighbors, metrics
import warnings


def train_models_on_samples(X_train, y_train, X_val, y_val, models,
                            nr_epochs=5, subsize_set=100, verbose=True):
[Coding Style, Naming] The name X_train does not conform to the argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
This check looks for invalid names for a range of different identifiers. You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements. If your project includes a Pylint configuration file, the settings contained in that file take precedence. To find out more about Pylint, please refer to their site.
[Coding Style, Naming] The name X_val does not conform to the argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
    '''
    Given a list of compiled models, this function trains
    them all on a subset of the train data. If the given size of the subset is
    larger than the size of the data, the complete data set is used.
    Parameters
    ----------
    X_train : numpy array of shape (num_samples, num_timesteps, num_channels)
        The input dataset for training
    y_train : numpy array of shape (num_samples, num_classes)
        The output classes for the train data, in binary format
    X_val : numpy array of shape (num_samples_val, num_timesteps, num_channels)
        The input dataset for validation
    y_val : numpy array of shape (num_samples_val, num_classes)
        The output classes for the validation data, in binary format
    models : list of model, params, modeltypes
        List of keras models to train
    nr_epochs : int, optional
        nr of epochs to use for training one model
    subsize_set : int, optional
        number of samples to use from the training set for training these models
[Coding Style] This line is too long as per the coding-style (80/79).
This check looks for lines that are too long. You can specify the maximum line length.
    verbose : bool, optional
        flag for displaying verbose output

    Returns
    ----------
    histories : list of Keras History objects
        train histories for all models
    val_accuracies : list of floats
        validation accuracies of the models
    val_losses : list of floats
        validation losses of the models
    '''
    # if subsize_set is larger than len(X_train), slicing returns all samples
    X_train_sub = X_train[:subsize_set, :, :]
[Coding Style, Naming] The name X_train_sub does not conform to the variable naming conventions ([a-z_][a-z0-9_]{1,30}$).
    y_train_sub = y_train[:subsize_set, :]

    histories = []
    val_accuracies = []
    val_losses = []
    for model, params, model_types in models:
        history = model.fit(X_train_sub, y_train_sub,
                            nb_epoch=nr_epochs, batch_size=20, # see comment on subsize_set
[Coding Style] This line is too long as per the coding-style (91/79).
                            validation_data=(X_val, y_val),
                            verbose=verbose)
        histories.append(history)
        val_accuracies.append(history.history['val_acc'][-1])
        val_losses.append(history.history['val_loss'][-1])

    return histories, val_accuracies, val_losses


def plotTrainingProcess(history, name='Model', ax=None):
[Coding Style, Naming] The name plotTrainingProcess does not conform to the function naming conventions ([a-z_][a-z0-9_]{2,30}$).
    '''
    This function plots the loss and accuracy on the train and validation set,
    for each epoch in the history of one model.

    Parameters
    ----------
    history : keras History object for one model
        The history object of the training process corresponding to one model
    name : str, optional
        Title of the plot
    ax : matplotlib Axes, optional
        Axes to draw on; if None, a new figure is created
    Returns
    ----------

    '''
    if ax is None:
        fig, ax = plt.subplots()
    ax2 = ax.twinx()
    LN = len(history.history['val_loss'])
[Coding Style, Naming] The name LN does not conform to the variable naming conventions ([a-z_][a-z0-9_]{1,30}$).
    val_loss, = ax.plot(range(LN), history.history['val_loss'], 'g--',
                        label='validation loss')
    train_loss, = ax.plot(range(LN), history.history['loss'], 'g-',
                          label='train loss')
    val_acc, = ax2.plot(range(LN), history.history['val_acc'], 'b--',
                        label='validation accuracy')
    train_acc, = ax2.plot(range(LN), history.history['acc'], 'b-',
                          label='train accuracy')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss', color='g')
    ax2.set_ylabel('accuracy', color='b')
    plt.legend(handles=[val_loss, train_loss, val_acc, train_acc],
               loc=2, bbox_to_anchor=(1.1, 1))
    plt.title(name)


def find_best_architecture(X_train, y_train, X_val, y_val, verbose=True,
                           number_of_models=5, nr_epochs=5, **kwargs):
[Coding Style, Naming] The name X_train does not conform to the argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
[Coding Style, Naming] The name X_val does not conform to the argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
    '''
    Tries out a number of models on a subsample of the data,
    and outputs the best found architecture and hyperparameters.

    Parameters
    ----------
    X_train : numpy array of shape (num_samples, num_timesteps, num_channels)
        The input dataset for training
    y_train : numpy array of shape (num_samples, num_classes)
        The output classes for the train data, in binary format
    X_val : numpy array of shape (num_samples_val, num_timesteps, num_channels)
        The input dataset for validation
    y_val : numpy array of shape (num_samples_val, num_classes)
        The output classes for the validation data, in binary format
    verbose : bool, optional
        flag for displaying verbose output
    number_of_models : int, optional
        number of models to generate and train
    nr_epochs : int, optional
        nr of epochs to use for training one model
    **kwargs: key-value parameters
        parameters for generating the models

    Returns
    ----------
    best_model : Keras model
        Best performing model, already trained on a small sample data set.
    best_params : dict
        Dictionary containing the hyperparameters for the best model
    best_model_type : str
        Type of the best model
    knn_acc : float
        accuracy for kNN prediction on validation set
    '''
    models = modelgen.generate_models(X_train.shape, y_train.shape[1],
                                      number_of_models=number_of_models,
                                      **kwargs)
    subsize_set = 100
    histories, val_accuracies, val_losses = train_models_on_samples(X_train,
                                                                    y_train,
                                                                    X_val,
                                                                    y_val,
                                                                    models,
                                                                    nr_epochs,
                                                                    subsize_set=subsize_set,
[Coding Style] This line is too long as per the coding-style (92/79).
                                                                    verbose=verbose)
[Coding Style] This line is too long as per the coding-style (84/79).
    best_model_index = np.argmax(val_accuracies)
    best_model, best_params, best_model_type = models[best_model_index]
    knn_acc = kNN_accuracy(X_train[:subsize_set, :, :], y_train[:subsize_set, :], X_val, y_val)
[Coding Style] This line is too long as per the coding-style (95/79).
    if verbose:
        for i in range(len(models)): #<= now one plot per model, ultimately we
            # may want all models in one plot to allow for direct comparison
            name = str(models[i][1])
            plotTrainingProcess(histories[i], name)
        print('Best model: model ', best_model_index)
        print('Model type: ', best_model_type)
        print('Hyperparameters: ', best_params)
        print('Accuracy on validation set: ', val_accuracies[best_model_index])
        print('Accuracy of kNN on validation set', knn_acc)

    if val_accuracies[best_model_index] < knn_acc:
        warnings.warn('Best model not better than kNN: ' +
                      str(val_accuracies[best_model_index]) + ' vs  ' +
                      str(knn_acc)
                      )
    return best_model, best_params, best_model_type, knn_acc


def kNN_accuracy(X_train, y_train, X_val, y_val, k=1):
[Coding Style, Naming] The name kNN_accuracy does not conform to the function naming conventions ([a-z_][a-z0-9_]{2,30}$).
[Coding Style, Naming] The name X_train does not conform to the argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
[Coding Style, Naming] The name X_val does not conform to the argument naming conventions ([a-z_][a-z0-9_]{1,30}$).
    num_samples, num_timesteps, num_channels = X_train.shape
    clf = neighbors.KNeighborsClassifier(k)
    clf.fit(X_train.reshape(num_samples, num_timesteps*num_channels), y_train)
    num_samples, num_timesteps, num_channels = X_val.shape
    val_predict = clf.predict(X_val.reshape(num_samples, num_timesteps*num_channels))
[Coding Style] This line is too long as per the coding-style (85/79).
    return metrics.accuracy_score(val_predict, y_val)
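
The kNN baseline in kNN_accuracy flattens each sample of shape (num_timesteps, num_channels) into a single vector before fitting the classifier. The sketch below illustrates that reshape together with a 1-nearest-neighbour prediction in plain numpy. It is a simplification under stated assumptions: it uses integer class labels and Euclidean distance rather than scikit-learn's KNeighborsClassifier on binary class-indicator arrays, and the names flatten_samples and one_nn_predict as well as the toy data are hypothetical.

```python
import numpy as np


def flatten_samples(X):
    """Reshape (num_samples, num_timesteps, num_channels) to 2D."""
    num_samples = X.shape[0]
    return X.reshape(num_samples, -1)


def one_nn_predict(X_train, labels_train, X_val):
    """Predict the label of the nearest training sample (Euclidean)."""
    train_flat = flatten_samples(X_train)
    val_flat = flatten_samples(X_val)
    # Pairwise squared distances, shape (num_val, num_train)
    diffs = val_flat[:, np.newaxis, :] - train_flat[np.newaxis, :, :]
    dists = (diffs ** 2).sum(axis=2)
    return labels_train[dists.argmin(axis=1)]


# Hypothetical toy data: two training samples of 2 timesteps x 2 channels
X_train = np.array([[[0., 0.], [0., 0.]],
                    [[1., 1.], [1., 1.]]])
labels_train = np.array([0, 1])
X_val = np.array([[[0.9, 1.0], [1.0, 1.1]],
                  [[0.1, 0.0], [0.0, 0.2]]])
labels_val = np.array([1, 0])

predictions = one_nn_predict(X_train, labels_train, X_val)
accuracy = np.mean(predictions == labels_val)
print(predictions, accuracy)  # [1 0] 1.0
```

Flattening discards the distinction between timesteps and channels, which is acceptable for a cheap baseline whose only purpose is to give the trained models a score to beat.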