| Metric | Value |
| --- | --- |
| Conditions | 7 |
| Total Lines | 57 |
| Code Lines | 38 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Conversely, when a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that you should extract the commented part into a new method, using the comment as a starting point for its name.
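As a minimal sketch of that idea (the order-handling example and every name in it are invented for illustration, not taken from the code discussed below):

```python
# Before: a comment is needed to explain what the block does.
def process_order(order):
    # Validate that the order has items and a shipping address.
    if not order.items:
        raise ValueError("order has no items")
    if order.address is None:
        raise ValueError("order has no shipping address")
    ...  # rest of the processing


# After: the comment has become the name of an extracted method.
def validate_order(order):
    if not order.items:
        raise ValueError("order has no items")
    if order.address is None:
        raise ValueError("order has no shipping address")


def process_order(order):
    validate_order(order)
    ...  # rest of the processing
```

The extracted body is unchanged; only its location and name differ, which is what makes this refactoring cheap and safe to apply.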
Commonly applied refactorings include extracting smaller, well-named methods out of the long one, in particular when many parameters or temporary variables are present. The following PyTorch training function is a typical candidate; a refactored sketch follows the listing:
```python
import os

# The excerpt below uses these modules; the corresponding imports are part
# of the elided section of the original file.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ... (rest of the elided setup: DEVICE, EPOCHS, BATCHSIZE, CLASSES,
# N_TRAIN_EXAMPLES, N_VALID_EXAMPLES, train_loader and valid_loader
# are defined here)


def pytorch_cnn(params):
    linear0 = params["linear.0"]
    linear1 = params["linear.1"]

    layers = []

    in_features = 28 * 28

    layers.append(nn.Linear(in_features, linear0))
    layers.append(nn.ReLU())
    layers.append(nn.Dropout(0.2))

    layers.append(nn.Linear(linear0, linear1))
    layers.append(nn.ReLU())
    layers.append(nn.Dropout(0.2))

    layers.append(nn.Linear(linear1, CLASSES))
    layers.append(nn.LogSoftmax(dim=1))

    model = nn.Sequential(*layers)

    # model = create_model(params).to(DEVICE)
    optimizer = getattr(optim, "Adam")(model.parameters(), lr=0.01)

    # Training of the model.
    for epoch in range(EPOCHS):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Limiting training data for faster epochs.
            if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
                break

            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)

            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

        # Validation of the model.
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(valid_loader):
                # Limiting validation data.
                if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                    break
                data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
                output = model(data)
                # Get the index of the max log-probability.
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)

    return accuracy
```
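Here is a sketch of how the extraction could look when applied to the method above. The helper names `build_model`, `train_one_epoch`, and `validate` are our own; `DEVICE`, `EPOCHS`, `BATCHSIZE`, `CLASSES`, the example limits, and the data loaders are still assumed to come from the surrounding script, and the needless `getattr(optim, "Adam")` indirection is inlined along the way:

```python
def build_model(params):
    """Assemble the feed-forward network from the layer widths in `params`."""
    linear0 = params["linear.0"]
    linear1 = params["linear.1"]
    return nn.Sequential(
        nn.Linear(28 * 28, linear0),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(linear0, linear1),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(linear1, CLASSES),
        nn.LogSoftmax(dim=1),
    )


def train_one_epoch(model, optimizer, train_loader):
    """Run one pass over a capped number of training batches."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Limiting training data for faster epochs.
        if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
            break
        data = data.view(data.size(0), -1).to(DEVICE)
        target = target.to(DEVICE)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()


def validate(model, valid_loader):
    """Return the accuracy over a capped number of validation batches."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            # Limiting validation data.
            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                break
            data = data.view(data.size(0), -1).to(DEVICE)
            target = target.to(DEVICE)
            # Get the index of the max log-probability.
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)


def pytorch_cnn(params):
    model = build_model(params)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(EPOCHS):
        train_one_epoch(model, optimizer, train_loader)
    return validate(model, valid_loader)
```

Each helper is now short enough to name accurately, and `pytorch_cnn` itself reads as a summary of the whole procedure: build, train, validate.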