| Metric | Value |
| --- | --- |
| Conditions | 7 |
| Total Lines | 57 |
| Code Lines | 38 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Moreover, a small method is usually much easier to name well.
For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method; use the comment as a starting point when choosing the new method's name.
Commonly applied refactorings include:
If many parameters or temporary variables are present:
1 | import os |
||
def pytorch_cnn(params):
    """Train a small classifier and return its validation accuracy.

    Despite the name, the network is a fully connected multilayer
    perceptron. It consists of two Linear -> ReLU -> Dropout(0.2) stages
    followed by a Linear + LogSoftmax head; it is not a convolutional
    network.

    Args:
        params: Mapping that holds the two hidden-layer widths under the
            keys ``"linear.0"`` and ``"linear.1"``.

    Returns:
        float: Fraction of the evaluated validation samples whose
        arg-max prediction equals the target.

    Raises:
        ValueError: If the validation loop evaluated no samples.

    Relies on module-level names defined elsewhere in this file:
    ``CLASSES``, ``DEVICE``, ``EPOCHS``, ``BATCHSIZE``,
    ``N_TRAIN_EXAMPLES``, ``N_VALID_EXAMPLES``, ``train_loader`` and
    ``valid_loader``.
    """
    linear0 = params["linear.0"]
    linear1 = params["linear.1"]

    model = _pytorch_cnn_build_model(linear0, linear1)
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    for _ in range(EPOCHS):
        _pytorch_cnn_train_epoch(model, optimizer)

    return _pytorch_cnn_accuracy(model)


def _pytorch_cnn_build_model(linear0, linear1):
    """Build the MLP used by ``pytorch_cnn`` and place it on ``DEVICE``."""
    # Each sample is flattened to 28 * 28 features before the first layer
    # (presumably 28x28 single-channel images -- confirm against the loaders).
    in_features = 28 * 28
    model = nn.Sequential(
        nn.Linear(in_features, linear0),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(linear0, linear1),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(linear1, CLASSES),
        nn.LogSoftmax(dim=1),
    )
    # The batches are sent to DEVICE, so the parameters must live there too;
    # otherwise the forward pass fails whenever DEVICE is not the CPU.
    return model.to(DEVICE)


def _pytorch_cnn_to_device(data, target):
    """Flatten every sample to one row of features and move the batch to DEVICE."""
    return data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)


def _pytorch_cnn_train_epoch(model, optimizer):
    """Run one epoch of NLL-loss training on a capped prefix of ``train_loader``."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Limit the training data per epoch to keep epochs fast.
        if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
            break
        data, target = _pytorch_cnn_to_device(data, target)

        optimizer.zero_grad()
        output = model(data)
        # The model ends in LogSoftmax, so NLL loss yields cross-entropy.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def _pytorch_cnn_accuracy(model):
    """Return the accuracy of *model* on a capped prefix of ``valid_loader``."""
    model.eval()
    correct = 0
    evaluated = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            # Limit the validation data in the same way as the training data.
            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                break
            data, target = _pytorch_cnn_to_device(data, target)
            output = model(data)
            # The index of the max log-probability is the predicted class.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            # Count the samples actually evaluated; the batch-based cap above
            # can overshoot N_VALID_EXAMPLES, so min(len(dataset), N) can be
            # the wrong denominator.
            evaluated += target.size(0)

    if evaluated == 0:
        raise ValueError("no validation samples were evaluated")
    return correct / evaluated
||
|
|||
98 | |||
109 |