pytorch_example.pytorch_cnn()   B
last analyzed

Complexity

Conditions 7

Size

Total Lines 57
Code Lines 38

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 38
nop 1
dl 0
loc 57
rs 7.568
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a cohesive fragment of the method body into its own, well-named function.

1
import os
2
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
import torch.optim as optim
7
import torch.utils.data
8
from torchvision import datasets
9
from torchvision import transforms
10
11
from hyperactive import Hyperactive
12
13
14
# Derived from the Optuna example:
# https://github.com/optuna/optuna/blob/master/examples/pytorch_simple.py

DEVICE = torch.device("cpu")
BATCHSIZE = 256
CLASSES = 10  # number of output classes (MNIST digits)
DIR = os.getcwd()  # MNIST files are read from / downloaded into the current working directory
EPOCHS = 10
LOG_INTERVAL = 10  # NOTE(review): not referenced by any code visible in this file
# Per-epoch caps; data loading stops once batch_idx * BATCHSIZE reaches the cap.
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10


# Get the MNIST dataset.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DIR, train=True, download=True, transform=transforms.ToTensor()),
    batch_size=BATCHSIZE,
    shuffle=True,
)
# Validation is truncated to N_VALID_EXAMPLES. Shuffling here would score every
# hyperparameter candidate on a different random subset, making their
# accuracies less comparable, so the order is kept fixed. download=True is a
# no-op when the files already exist, and it keeps this loader independent of
# the train loader having downloaded the data first.
valid_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DIR, train=False, download=True, transform=transforms.ToTensor()),
    batch_size=BATCHSIZE,
    shuffle=False,
)
39
40
41
def _build_model(linear0, linear1):
    """Build the classifier: 784 -> linear0 -> linear1 -> CLASSES, placed on DEVICE."""
    model = nn.Sequential(
        nn.Linear(28 * 28, linear0),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(linear0, linear1),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(linear1, CLASSES),
        # Log-probabilities pair with F.nll_loss in training.
        nn.LogSoftmax(dim=1),
    )
    # The batches are moved to DEVICE, so the parameters must live there too.
    return model.to(DEVICE)


def _prepare_batch(data, target):
    """Flatten each image of the batch into a row vector and move both tensors to DEVICE."""
    return data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)


def _train_one_epoch(model, optimizer):
    """Run one training epoch over at most ceil(N_TRAIN_EXAMPLES / BATCHSIZE) batches."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Limit training data for faster epochs.
        if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
            break
        data, target = _prepare_batch(data, target)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()


def _evaluate(model):
    """Return the fraction of correctly classified samples in the capped validation set.

    The denominator is the number of samples actually evaluated. Whole batches
    are consumed, so this can exceed N_VALID_EXAMPLES when the cap is not a
    multiple of BATCHSIZE. Returns 0.0 if no sample was evaluated.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            # Limit validation data.
            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                break
            data, target = _prepare_batch(data, target)
            # The predicted class is the index of the max log-probability.
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()
            seen += target.size(0)
    return correct / seen if seen else 0.0


def pytorch_cnn(params):
    """Hyperactive objective: train a small classifier on MNIST and return its accuracy.

    Despite the name, the network is fully connected (Linear/ReLU/Dropout
    layers only), not convolutional.

    Args:
        params: mapping with integer widths for the two hidden layers under
            the keys "linear.0" and "linear.1".

    Returns:
        Validation accuracy in [0, 1] after EPOCHS epochs of Adam (lr=0.01).
        Validation runs once, after training. With EPOCHS == 0 this is the
        accuracy of the untrained model rather than an undefined variable.
    """
    model = _build_model(params["linear.0"], params["linear.1"])
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    for _ in range(EPOCHS):
        _train_one_epoch(model, optimizer)

    # Only the final score is returned, so evaluating after every epoch
    # would be wasted work.
    return _evaluate(model)
0 ignored issues
show
introduced by
The variable `accuracy` is not defined if the `for` loop on line 66 is never entered (for example, when `EPOCHS` is 0). Are you sure this can never happen?
Loading history...
98
99
100
# Candidate hidden-layer widths (10, 20, ..., 190) for each of the two hidden
# layers; the keys match the ones pytorch_cnn reads from its params argument.
# Each key gets its own list object.
search_space = {
    layer_key: list(range(10, 200, 10)) for layer_key in ("linear.0", "linear.1")
}
104
105
106
# Register the objective function together with its search space, then start
# the optimization; n_iter=5 is the iteration budget passed to add_search.
hyper = Hyperactive()
hyper.add_search(
    pytorch_cnn,
    search_space,
    n_iter=5,
)
hyper.run()
109