
How to Train and Use a Neural Network? The TL;DR

December 26, 2025
3 min read

This blog post is a quick self-reflection on the essential steps to train a neural network, boiled down to the absolute basics.

Step 1: Load Data

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
batch_size = 4
 
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
 
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
 
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
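
A quick sanity check: pull one batch from the loader and confirm the shapes. CIFAR-10 images come out of ToTensor as 3×32×32 tensors, so with batch_size = 4 each batch is 4×3×32×32.

images, labels = next(iter(trainloader))
print(images.shape)  # torch.Size([4, 3, 32, 32]): batch x channels x height x width
print(labels.shape)  # torch.Size([4]): one class index per image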

Step 2: Design a Neural Network

import torch.nn as nn
import torch.nn.functional as F
 
 
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
 
net = Net()
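
If the 16 * 5 * 5 in fc1 looks like magic: each 5×5 convolution (no padding) shrinks 32×32 to 28×28, pooling halves that to 14×14, the second convolution gives 10×10, and pooling again gives 5×5 over 16 channels. A dummy forward pass is an easy way to double-check:

dummy = torch.randn(1, 3, 32, 32)  # one fake CIFAR-10-sized image
print(net(dummy).shape)            # torch.Size([1, 10]): one raw score per class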

(Optional) Move the Network to GPU

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net.to(device)  # move all parameters and buffers to the GPU (in place; a no-op on CPU)

Step 3: Choose a Loss Function and Optimizer

We need two things:

  • a loss function: measures the error
  • an optimizer: updates the weights

import torch.optim as optim
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
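
One detail worth knowing: nn.CrossEntropyLoss applies log-softmax internally, which is why forward returns the raw scores from fc3 instead of probabilities. A minimal sketch on fake data:

fake_logits = torch.randn(4, 10)            # what net(inputs) returns for a batch of 4
fake_labels = torch.tensor([3, 0, 9, 1])    # class indices, not one-hot vectors
print(criterion(fake_logits, fake_labels))  # a single scalar tensor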

Step 4: Train a Neural Network

for epoch in range(10):
 
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # ⭐️: move the data from CPU to GPU as well
        inputs, labels = data[0].to(device), data[1].to(device)
 
        # ⭐️: zero the parameter gradients before learning
        optimizer.zero_grad()
 
        # ⭐️⭐️IMPORTANT: forward + backward + optimize
        outputs = net(inputs)  # forward pass
        loss = criterion(outputs, labels)
        loss.backward()  # backward propagation
        optimizer.step()  # the optimizer adjusts the weights
 
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
 
print('Finished Training')

Step 5: Save and Load a Network

Save the model:

torch.save(net.state_dict(), 'network_10epoch.pth')

Load the model:

saved_net = Net()
saved_net.to(device)
saved_net.load_state_dict(torch.load('network_10epoch.pth', weights_only=True))
saved_net.eval()  # switch to inference mode (matters for layers like dropout/batchnorm)

Remark

Saving/loading the model with state_dict() is considered best practice. It treats the model as a simple dictionary of parameters. Without state_dict, the legacy method torch.save(net) serializes the entire model object, structure and parameters, to disk via pickle, which breaks whenever the class definition changes or moves.

See more at Saving and loading torch.nn.Modules.
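
For intuition, the state_dict really is just an ordered mapping from parameter names to tensors; printing it for the Net above gives:

for name, param in saved_net.state_dict().items():
    print(name, tuple(param.shape))
# conv1.weight (6, 3, 5, 5)
# conv1.bias (6,)
# ... and so on through fc3.bias (10,)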

Step 6: Evaluate/Use the Network

correct = 0
total = 0
 
# ⭐️: when USING the network, we should be in "no_grad" mode.
#   It simply says: hey, I am not asking you to learn right now,
#   just tell me what you HAVE learned.
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        
        # ⭐️: this essentially says, "OK, apply what I have learned."
        outputs = saved_net(images)
        
        # post-process the raw scores into a fixed category: take the highest-scoring class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
 
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
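
To use the network on a single image rather than a whole test set, the same no_grad pattern applies. A minimal sketch, assuming the image went through the same transform as above (here, the first test image):

with torch.no_grad():
    image, _ = testset[0]                              # one preprocessed test image
    output = saved_net(image.unsqueeze(0).to(device))  # add a batch dimension of 1
    predicted = torch.argmax(output, dim=1).item()
    print(classes[predicted])                          # e.g. 'cat'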