Using PyTorch for image classification in 2023: a basic tutorial
Introduction to PyTorch and Image Classification
Diving into the world of deep learning, especially for tasks such as image classification, can be daunting. Thankfully, PyTorch, developed by researchers at Facebook’s AI Research lab, offers a seamless pathway to building powerful models with an approach that is both intuitive and flexible.
PyTorch is an open-source machine learning library for Python, used for applications such as natural language processing and computer vision. One of the things I particularly appreciate about PyTorch is how it handles dynamic computation graphs, which is quite a breath of fresh air if you’ve dealt with static graphs before. This enables a more interactive and iterative approach when defining models and makes debugging a whole lot easier.
When I first set out to learn image classification using PyTorch, I quickly realized the broad community support the library enjoys. There are countless resources, including tutorials, forums, and research papers available. A particularly valuable asset is the official PyTorch documentation and tutorials (https://pytorch.org/tutorials/), which are up-to-date and incredibly detailed, walking you through from basics to more advanced topics.
Image classification itself is the task of training a computer to assign a label from a fixed set of categories to an image. I’ll walk you through a simple example using PyTorch to classify images of handwritten digits from the MNIST dataset, a good starting point due to its simplicity.
To get a taste of PyTorch in action, let’s begin with importing the necessary libraries:
import torch
import torchvision
import torchvision.transforms as transforms
You’ll notice torchvision being imported here; it’s a package in the PyTorch library that’s filled with popular datasets, model architectures, and common image transformations for computer vision.
Before diving into building models and training them, which the later sections will cover in-depth, let’s quickly go through how to load and visualize data, which is essential for understanding the kind of inputs we’re dealing with.
# Transform the data to torch tensors and normalize it
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))]
)

# Download the MNIST training dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# Get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)  # dataiter.next() was removed in recent PyTorch versions

# Show images and labels
print('Labels: ', labels)
print('Image Batch Shape: ', images.size())
In the code above, you’ll notice that we’re instantiating a DataLoader with a batch size of 4. This means that when we iterate over this loader, we’ll get batches of 4 images along with their corresponding labels. DataLoader is a fantastic feature in PyTorch that abstracts away a lot of the cumbersome work involved in iterating through datasets.
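Since we promised to visualize the data, here’s a minimal sketch that stitches one batch into a grid and displays it; it assumes matplotlib is installed and undoes the (0.5,), (0.5,) normalization before plotting:

# Display one batch as an image grid (assumes matplotlib is installed)
import matplotlib.pyplot as plt
import numpy as np

grid = torchvision.utils.make_grid(images)
grid = grid / 2 + 0.5  # undo Normalize((0.5,), (0.5,))
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))  # channels-last for matplotlib
plt.show()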
At this stage, you’ve got a sneak peek into PyTorch’s abilities for handling data - it’s concise and heavily optimized for GPU computation (though you can absolutely run it on a CPU).
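For example, sending work to a GPU when one is present takes a single line; a minimal sketch:

# Pick the GPU if available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
images = images.to(device)  # tensors (and later, models) move with .to(device)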
In upcoming sections of the tutorial, we’ll dive deeper into model building, including how to define neural networks in PyTorch, the training process, and eventually evaluating our models to see how well they can classify new images they’ve never seen before.
Alright, enough talk. I hope this introduction has set the stage for your journey with PyTorch and image classification. The combination of Python’s simplicity and PyTorch’s power creates a compelling duo for anyone looking to break into the field of deep learning. Let’s get ready to build and train our models next.
Setting Up the PyTorch Environment
Before we dive into the nuts and bolts of image classification with PyTorch, it’s essential to ensure that you’ve got your PyTorch environment up and running. Setting up an environment where you can fearlessly tinker and break things without consequence is crucial. Trust me when I say it saves a ton of headaches down the line.
I would typically begin by installing Anaconda, an open-source distribution that simplifies package management and deployment. You can download it from the Anaconda website. Once installed, I like to create a new virtual environment specifically for PyTorch; this helps in avoiding package conflicts.
conda create -n pytorch_env python=3.8
Activate the new environment before moving forward.
conda activate pytorch_env
The next step is installing PyTorch. PyTorch provides a handy package selector on their website. For instance, if you’re working without GPU support, the command would be something like:
conda install pytorch torchvision torchaudio cpuonly -c pytorch
With GPU support (and assuming you’ve got a compatible CUDA driver installed), it’ll look something like this at the time of writing; always check the selector for the exact command for your system:
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
After installation, verify that PyTorch is correctly installed by running:
import torch
print(torch.__version__)
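If you installed the GPU build, it’s also worth confirming that PyTorch can actually see your card:

# Should print True when the GPU build and your CUDA driver cooperate
print(torch.cuda.is_available())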
This command should output the version of PyTorch you’ve just installed. You should now be able to import PyTorch and its submodules, like torchvision, which provides datasets, models, and image transformations for computer vision, just what we need for image classification.
Now on to the Python packages that make life easier: matplotlib is great for plotting, and numpy is indispensable for any number crunching:
conda install matplotlib numpy
Throughout my journey, I found Jupyter notebooks an excellent tool for experimenting with PyTorch, as they let me execute code blocks sequentially and see outputs in real time. Install it with:
conda install jupyter
Once that’s done, launch a new notebook:
jupyter notebook
Prepare your imports within the notebook:
# These are your typically needed imports for working with PyTorch and visualizing data.
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
Using the notebook, write small snippets and test them immediately. For instance, you could load and visualize an image from a dataset to see if everything is working:
# It’s always good to test your setup with a small sample.
sample_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
img = sample_dataset[0][0]  # With no transform set, CIFAR10 returns a PIL image directly

plt.imshow(img)
plt.show()
At this stage, you’ve successfully waded through the configuration swamp, and you’re all set to take on more complex tasks like understanding your dataset and building your classification model.
Remember, the PyTorch community is incredibly active and supportive. If you hit snags, a quick search could pull up a GitHub issue discussion, Stack Overflow thread, or tutorial that gets you back on track. Dive into the documentation and existing codebases, and you’ll pick up best practices along the way.
Take this setup as your machine learning playground – nobody becomes a PyTorch wizard overnight, but with this environment you’ve got just the wand you need to start casting spells.
Understanding and Preparing the Dataset
Before I delve into the exciting world of building image classification models with PyTorch, I think it’s crucial to talk about the first step that doesn’t get as much limelight: preparing the dataset. If the dataset isn’t up to scratch, trust me, even the fanciest algorithm won’t be able to save the day.
First off, let’s understand what goes into a good dataset. For image classification, besides needing a ton of images, they should be labeled accurately. This could either mean having folders for each class containing the relevant images or having an annotation file mapping each image to its label.
Now, let’s grab a dataset. A solid choice for beginners is the CIFAR-10 dataset, a set with 60,000 32x32 color images across 10 classes. You can easily download it from the torchvision package:
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
These lines not only download the CIFAR-10 but also apply a transformation to normalize the image data. Normalization helps in speeding up the training by making sure that the input parameter scales don’t impact the learning process adversely.
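Concretely, Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) computes (x - mean) / std per channel, so the [0, 1] values produced by ToTensor land in [-1, 1]. A quick sanity check:

# ToTensor scales pixels to [0, 1]; (x - 0.5) / 0.5 then maps them to [-1, 1]
img, label = trainset[0]
print(img.min().item(), img.max().item())  # roughly -1.0 and 1.0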
With the dataset downloaded, it’s common to split it into a training set and a validation set. This allows me to train the model on a large chunk of the data and then validate its performance on unseen data. Here’s how you can manually split the dataset:
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

validation_ratio = 0.2
random_seed = 42

num_train = len(trainset)
indices = list(range(num_train))
split = int(np.floor(validation_ratio * num_train))

np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
And then, I create the actual loaders, which will pull the data from the dataset during training:
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=4, sampler=train_sampler, num_workers=2
)
validation_loader = torch.utils.data.DataLoader(
    trainset, batch_size=4, sampler=valid_sampler, num_workers=2
)
Another point to consider is data augmentation. It’s a way for me to artificially expand my dataset and introduce variability. This could prevent the model from overfitting and help it generalize better. Here’s how to add basic data augmentation using transforms:
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomRotation(10),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
Preparing the dataset might seem tedious, but it’s a fundamental step that can’t be overlooked. Once the data is prepped and loaded, you’ll be in the clear to move on to building and training your model, but more on that later.
Remember, if you need more detailed info about the CIFAR-10 dataset or PyTorch’s DataLoader, always consult the official PyTorch documentation. It’s your go-to for understanding all the intricacies of dataset operations in PyTorch.
Building the Image Classification Model
Once your environment is set up and your dataset is ready, it’s time to dive into the exciting part of machine learning: building the model. PyTorch provides a clean and modular way to create your image classification model. I’ll walk you through the steps to construct a simple convolutional neural network (CNN), which is exceptionally good at handling image data.
We’ll start by importing the necessary modules. PyTorch’s nn module offers the building blocks for our network. I’ll be using OrderedDict to keep the layers named and in order, but you don’t have to.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
Next, we define our CNN. Here’s a simple architecture I’ve had success with:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)),
            ('relu1', nn.ReLU()),
            ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),
            ('conv2', nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)),
            ('relu2', nn.ReLU()),
            ('pool2', nn.MaxPool2d(kernel_size=2, stride=2)),
        ]))
        # For 32x32 CIFAR-10 inputs, two 2x2 poolings leave 8x8 feature maps
        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(64 * 8 * 8, 1024)),
            ('relu3', nn.ReLU()),
            ('fc2', nn.Linear(1024, 10)),  # 10 is the number of classes
            # No softmax layer here: nn.CrossEntropyLoss below expects raw logits
        ]))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten the output
        x = self.classifier(x)
        return x
Note that the first layer in features expects three input channels because the images in our dataset are color (RGB). Adjust the input and output channel counts, and the in_features of the first linear layer, to fit your dataset.
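If you’re ever unsure what in_features the first linear layer needs, a handy trick is to push a dummy batch through the convolutional part and read off the flattened size; a quick sketch, assuming 32x32 CIFAR-10 inputs:

# Verify the flattened feature size with a dummy batch (assumes 32x32 RGB inputs)
dummy = torch.zeros(1, 3, 32, 32)
out = SimpleCNN().features(dummy)
print(out.view(1, -1).size(1))  # 4096, i.e. 64 * 8 * 8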
Initializing the model and moving it to the GPU (if available) is straightforward:
model = SimpleCNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
Loss function and optimizer are critical components in training your model. Cross-entropy loss works well for classification problems:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
Now, let’s sketch out the training loop:
for epoch in range(num_epochs):  # Loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # Print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
Remember, this is a very general representation of a training loop, and you’ll need to fill in details specific to your situation, like the number of epochs or how often you log. As for train_loader, it should be a PyTorch DataLoader that provides batches of images and labels to iterate over.
After your model is trained, you can save it to disk:
torch.save(model.state_dict(), 'simple_cnn.pth')
Similarly, to load a saved model:
model = SimpleCNN()
model.load_state_dict(torch.load('simple_cnn.pth'))
model.eval()  # Set the model to evaluation mode
Evaluation mode notifies all your layers that you’re doing inference rather than training. This is particularly important for layers that behave differently during training and inference, such as dropout layers.
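You can see the difference directly with a dropout layer, which zeroes activations at random in training mode but is a no-op in eval mode; a minimal illustration:

# Dropout behaves differently in train vs. eval mode
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 4)
drop.train()
print(drop(x))  # some entries zeroed, the survivors scaled by 1 / (1 - p)
drop.eval()
print(drop(x))  # identical to the input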
That was a whirlwind tour! But with this basic setup, you’ve built and trained a model to classify images using PyTorch. Keep experimenting with different architectures, and you’ll surely improve the performance. Happy coding!
Training and Evaluating the Model
Once our image classification model is constructed in PyTorch, the next vital steps are training and evaluating it. I’ll break down these processes into digestible stages so you can understand how to effectively train your model and assess its performance.
Training isn’t just about feeding data to the model; it’s about tweaking and tuning the model while it learns. The evaluation, on the other hand, is the true test of how well your model generalizes to unseen data. Let’s dive into the code.
# Assuming we have our DataLoader objects -> train_loader and val_loader
import torch.optim as optim
from torch import nn

# Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training function
def train_model(epoch_count):
    for epoch in range(epoch_count):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            # get the inputs; data is a list of [inputs, labels]
            # (if your model lives on a GPU, also move them there with .to(device))
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0

    print('Finished Training')

number_of_epochs = 5
train_model(number_of_epochs)
After several epochs, you should see the loss decreasing. Now let’s evaluate how our model performs on data it hasn’t seen during training—the validation set.
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in val_loader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = model(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the validation images: {100 * correct // total} %')
Achieving high accuracy is satisfying, but be careful not to overfit to the validation set. If the accuracy isn’t up to scratch, we may need to go back and adjust the training parameters.
Occasionally, you’ll see different behaviors in the loss trends. Maybe the loss stagnates, or your validation accuracy starts to decline after a certain number of epochs, a classic sign of overfitting. If that occurs, you might want to introduce techniques like dropout, data augmentation, or tweak learning rates and batch sizes.
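One low-effort lever to try before reworking the architecture is a learning-rate schedule. Here’s a sketch using StepLR; the step_size and gamma values are illustrative, not tuned:

# Decay the learning rate by 10x every 3 epochs (illustrative values)
from torch.optim import lr_scheduler

scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
for epoch in range(number_of_epochs):
    train_model(1)       # one epoch at the current learning rate
    scheduler.step()     # then decay on schedule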
And that sums up the training and evaluation of our image classification model using PyTorch. I’ve personally found the process quite thrilling—at times it’s a bit like trying to solve a puzzle. After training, don’t forget to save your model:
torch.save(model.state_dict(), 'model.pth')
print('Saved trained model')
This ensures that you can load the trained model at any time and either continue training or use it for inference without starting from scratch.
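For instance, reloading the weights and classifying one batch of validation images might look like this; a minimal sketch, with SimpleCNN referring to the class defined earlier:

# Reload the trained weights and run a single validation batch through the network
model = SimpleCNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()  # inference mode

with torch.no_grad():
    images, labels = next(iter(val_loader))
    predictions = model(images).argmax(dim=1)
    print('Predicted:', predictions)
    print('Actual:   ', labels)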
As a beginner, stepping through these steps systematically will solidify your understanding and give you the confidence to experiment with more complex models and datasets. Each time you train a model, think of it as adding a new piece to your machine learning toolkit.
Remember, this is just a starting point. The world of deep learning is vast and full of possibilities, and each model you build further enriches your journey. Keep experimenting, and never hesitate to dive into the documentation or source code to deepen your understanding.