This blog post is a quick self-reflection on the essential steps to train a neural network, boiled down to the absolute basics.
Step 1: Load Data
import torch
import torchvision
import torchvision.transforms as transforms

# ToTensor() turns each image into a float tensor with values in [0, 1];
# Normalize(mean=0.5, std=0.5) per RGB channel then maps [0, 1] onto [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

# NOTE: num_workers=2 loads batches in worker subprocesses. On platforms that
# start workers with "spawn" (Windows, macOS), run this script under an
# `if __name__ == '__main__':` guard.

# Training split (downloaded into ./data if not already present), reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

# Test split: a fixed order is fine for evaluation, so no shuffling.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Human-readable names for the 10 CIFAR-10 labels, in label-index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Step 2: Design a Neural Network
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small LeNet-style CNN that maps a batch of 3x32x32 images to 10 class scores.

    Two stages of 5x5 convolution + ReLU + 2x2 max-pooling are followed by three
    fully connected layers. ``forward`` returns raw, unnormalized scores (logits)
    of shape ``(batch, 10)``. No softmax is applied, because ``nn.CrossEntropyLoss``
    expects logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)     # 3 input channels -> 6 feature maps, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)      # halves height and width; shared by both stages
        self.conv2 = nn.Conv2d(6, 16, 5)    # 6 -> 16 feature maps, 5x5 kernel
        # in_features = 16 * 5 * 5 because a 32x32 input shrinks to
        # 28 (conv1) -> 14 (pool) -> 10 (conv2) -> 5 (pool), with 16 channels.
        # A different input resolution would need a different value here.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)        # one output score per class

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input x of shape (batch, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dim, flatten channels x height x width
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
# (Optional) Move the Network to GPU
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Move the model's parameters to that device. The training and evaluation loops
# move every batch to `device`, so the model must live there too; otherwise the
# forward pass fails with a device-mismatch error on GPU machines.
# nn.Module.to() modifies the module in place; do it before building the
# optimizer from net.parameters().
net.to(device)
# Step 3: Choose a Loss Function and Optimizer
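We need:
- a loss function, to measure the error between the network's outputs and the true labels;
- an optimizer, to update the weights so that this error goes down.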
from torch import optim

# Loss: cross-entropy between the raw class scores produced by `net`
# and the class-index labels.
criterion = nn.CrossEntropyLoss()

# Optimizer: stochastic gradient descent with momentum over all of net's parameters.
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
)
# Step 4: Train a Neural Network
log_interval = 2000  # report the average loss once every this many mini-batches

for epoch in range(10):
    running_loss = 0.0
    for batch_idx, batch in enumerate(trainloader):
        # ⭐️: move the mini-batch to `device`; the model's parameters must be there too
        inputs = batch[0].to(device)
        labels = batch[1].to(device)

        # ⭐️: clear the previous step's gradients, since backward() accumulates into .grad
        optimizer.zero_grad()

        # ⭐️⭐️IMPORTANT: forward -> loss -> backward -> optimizer step
        outputs = net(inputs)              # forward pass: raw class scores
        loss = criterion(outputs, labels)
        loss.backward()                    # gradients of the loss w.r.t. every parameter
        optimizer.step()                   # adjust the weights using those gradients

        running_loss += loss.item()
        if (batch_idx + 1) % log_interval == 0:
            print(f'[{epoch + 1}, {batch_idx + 1:5d}] loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0

print('Finished Training')
# Step 5: Save and Load a Network
Save the model:
# Persist only the learned parameters (the state_dict), not the whole Python object.
net_state = net.state_dict()
torch.save(net_state, 'network_10epoch.pth')
# Load the model:
# Rebuild the same architecture, place it on the target device, then copy the
# saved parameters into it.
saved_net = Net()
saved_net.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# from a GPU run still loads on a CPU-only machine (and vice versa).
# weights_only=True restricts unpickling to tensors, primitive types and dicts.
saved_net.load_state_dict(
    torch.load('network_10epoch.pth', map_location=device, weights_only=True)
)
# Switch to evaluation mode for inference. This has no effect on this Net (no
# dropout or batch-norm layers), but it is required for models that have them.
saved_net.eval()
# Remark
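Saving/loading the model with
`state_dict()` is considered best practice: the model is saved as a simple dictionary mapping each parameter (and buffer) name to its tensor. Without `state_dict`, the legacy method `torch.save(net)` essentially serializes the entire model object (structure and parameters) to disk with pickle. Because pickle stores only a reference to the model's class, loading such a file later requires the same class definition at the same import path, which makes it brittle when the code changes. See more at Saving and loading torch.nn.Modules.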
Step 6: Evaluate/Use the Network
correct = total = 0

# ⭐️When we only USE the network, switch off gradient tracking with torch.no_grad():
# we are not asking it to learn right now, just to tell us what it HAS learned.
# This also saves the memory autograd would otherwise spend on recording the graph.
with torch.no_grad():
    for batch in testloader:
        images = batch[0].to(device)
        labels = batch[1].to(device)
        # ⭐️: forward pass only, i.e. apply what the network has learned
        outputs = saved_net(images)
        # Post-process the raw class scores (logits) into one predicted class
        # per image: the index of the largest score along dim 1.
        predicted = torch.max(outputs, dim=1).indices
        total += labels.shape[0]
        correct += int((predicted == labels).sum())

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')