In this post, you will learn how to load a pre-trained ResNet model and make predictions with it using the PyTorch library. Here is the arXiv paper on ResNet.
Before getting into loading and predicting with ResNet (Residual Neural Network) in PyTorch, it helps to know how to load the other pre-trained models that are available, such as AlexNet, ResNet, DenseNet, GoogLeNet and VGG. The PyTorch torchvision package lets you load these models; it provides popular datasets, model architectures, and common image transformations for computer vision. The following command lists what is available:
from torchvision import models
# List every name exposed by torchvision.models
dir(models)
The output lists all the model architectures and pre-trained models available for loading and prediction.
You may notice that the list contains a number of Python classes such as AlexNet and ResNet (names starting with capital letters), along with a set of convenience functions for each class that create the model with different parameters, including the number of layers. The following convenience functions load ResNet models with different numbers of layers: resnet18, resnet34, resnet50, resnet101 and resnet152.
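The number in each function name is the depth of the network. As a rough worked example (a sketch based on the stage configurations in the original ResNet paper, not code taken from torchvision), the depth can be recomputed from how many residual blocks each of the four stages has: ResNet-18 and ResNet-34 use "basic" blocks with two convolutions, while ResNet-50, ResNet-101 and ResNet-152 use "bottleneck" blocks with three. The `RESNET_CONFIGS` table and `resnet_depth` helper are hypothetical names used only for illustration.

```python
# Hypothetical illustration: residual blocks per stage (4 stages) and the
# number of weighted layers per block (2 = basic block, 3 = bottleneck block),
# following the stage configurations described in the original ResNet paper.
RESNET_CONFIGS = {
    "resnet18": ([2, 2, 2, 2], 2),
    "resnet34": ([3, 4, 6, 3], 2),
    "resnet50": ([3, 4, 6, 3], 3),
    "resnet101": ([3, 4, 23, 3], 3),
    "resnet152": ([3, 8, 36, 3], 3),
}


def resnet_depth(blocks_per_stage, layers_per_block):
    """Count weighted layers: the initial 7x7 convolution, every convolution
    inside the residual blocks, and the final fully connected layer.
    Shortcut projection convolutions are conventionally not counted."""
    return 1 + sum(blocks_per_stage) * layers_per_block + 1


for name, (blocks, per_block) in RESNET_CONFIGS.items():
    print(name, resnet_depth(blocks, per_block))
```

Running it prints 18, 34, 50, 101 and 152 for the five names, which also shows that ResNet-101 differs from ResNet-50 only in its third stage (23 bottleneck blocks instead of 6).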
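In this post, you will learn how to use ResNet with 101 layers. Here are the four steps to load the pre-trained model and make predictions with it: load the pre-trained ResNet-101 network, load the image, preprocess the image, and run the network on it to get predictions.
Here are the details of the above pipeline steps: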
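#
# Download the pre-trained ResNet-101 model
# (torchvision >= 0.13 uses the weights argument; older versions use pretrained=True)
#
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
#
# Print the model architecture
#
resnet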
Printing the model shows its layers, ending with a fully connected layer (fc) that maps 2,048 features to 1,000 ImageNet class scores.
#
# Load the image
#
from PIL import Image
img_cat = Image.open("/Users/apple/Downloads/cat.png").convert('RGB')
Now it is time to preprocess the image and make predictions with the ResNet network. The same code works for the other popular pre-trained ResNet models, such as ResNet-18, ResNet-34, ResNet-50 and ResNet-152.
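To make the preprocessing steps used in the pipeline more concrete, here is a minimal pure-Python sketch, assuming a 640x480 (width x height) input photo, of what Resize(256), CenterCrop(224), ToTensor and Normalize do to the image size and to a single pixel. The helper names are hypothetical; the sketch follows torchvision's documented behavior for these transforms but does not call torchvision itself.

```python
# Pure-Python sketch of the preprocessing steps (hypothetical helpers;
# it mirrors the documented behavior of the torchvision transforms).

MEAN = [0.485, 0.456, 0.406]  # ImageNet per-channel means (R, G, B)
STD = [0.229, 0.224, 0.225]   # ImageNet per-channel standard deviations


def resize_shorter_side(width, height, size=256):
    """Resize(int): scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """CenterCrop: (left, top, right, bottom) of the central crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """ToTensor scales 0..255 to 0..1; Normalize then applies (x - mean) / std per channel."""
    return [(value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD)]


w, h = resize_shorter_side(640, 480)  # assumed 640x480 input photo
print((w, h))                         # (341, 256)
print(center_crop_box(w, h))          # (58, 16, 282, 240)
print([round(v, 3) for v in normalize_pixel((255, 128, 0))])  # [2.249, 0.205, -1.804]
```

Whatever the original aspect ratio, the network therefore receives a 224x224 image whose channel values are roughly centered on zero.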
Here is what the pipeline for predicting the image class with the ResNet model looks like:
import torch
from torchvision import transforms
#
# Create a preprocessing pipeline
#
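preprocess = transforms.Compose([
    transforms.Resize(256),        # resize so the shorter side is 256 pixels
    transforms.CenterCrop(224),    # keep the central 224x224 region
    transforms.ToTensor(),         # PIL image (0-255) -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(          # standardize each channel with ImageNet statistics
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])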
#
# Pass the image through the preprocessing pipeline
#
img_cat_preprocessed = preprocess(img_cat)
#
# Add a batch dimension: the network expects input of shape (N, C, H, W),
# so the (3, 224, 224) tensor becomes (1, 3, 224, 224)
#
batch_img_cat_tensor = torch.unsqueeze(img_cat_preprocessed, 0)
#
# Put ResNet in evaluation mode before doing predictions, so that its
# batch normalization layers use their stored running statistics
#
resnet.eval()
#
# Get the predictions for the image as scores indicating how well it matches
# each of the 1,000 ImageNet classes. The variable out has shape (1, 1000).
# torch.no_grad() skips gradient tracking, which is not needed for inference.
#
with torch.no_grad():
    out = resnet(batch_img_cat_tensor)
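#
# Load the file containing the 1,000 ImageNet class labels, one per line, in
# class-index order (torchvision >= 0.13 also provides these names as
# models.ResNet101_Weights.IMAGENET1K_V1.meta["categories"])
#
with open('/Users/apple/Downloads/imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]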
#
# Find the index of the maximum score in the out tensor.
# torch.max(out, 1) returns both the maximum values and their indices along dimension 1
#
_, index = torch.max(out, 1)
#
# Convert the scores to percentages: torch.nn.functional.softmax maps them to
# probabilities in the range [0, 1] that sum to 1, which are then multiplied by 100
#
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
#
# Print the name along with score of the object identified by the model
#
print(labels[index[0]], percentage[index[0]].item())
#
# Print the top 5 labels with their scores; torch.sort orders the scores from highest to lowest
#
_, indices = torch.sort(out, descending=True)
print([(labels[idx], percentage[idx].item()) for idx in indices[0][:5]])
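The last two steps, softmax and sorting, can be sketched in plain Python as below, assuming made-up scores for a hypothetical four-class problem rather than real ResNet output. `softmax_percent` and `top_k` are illustrative helpers, not PyTorch functions.

```python
import math


def softmax_percent(scores):
    """Turn raw scores (logits) into percentages that sum to 100.
    Subtracting the maximum first keeps exp() from overflowing
    without changing the result."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [100.0 * e / total for e in exps]


def top_k(scores, labels, k=5):
    """Return (label, percentage) pairs for the k highest scores,
    similar to sorting the scores in descending order and slicing."""
    percent = softmax_percent(scores)
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(labels[i], round(percent[i], 2)) for i in order[:k]]


# Hypothetical scores for a made-up 4-class problem (not real ResNet output)
labels = ["cat", "dog", "fox", "car"]
scores = [9.0, 8.0, 6.0, 1.0]
print(top_k(scores, labels, k=3))
```

In PyTorch itself, torch.topk(out, 5) returns the five largest scores and their indices directly, without sorting all 1,000 scores.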
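To summarize, here is what you learned about loading a pre-trained ResNet model with PyTorch and making predictions: torchvision.models provides convenience functions such as resnet101 for loading pre-trained models; an input image is resized, center-cropped, converted to a tensor, normalized with the ImageNet mean and standard deviation, and given a batch dimension; and after the model is put in evaluation mode, softmax turns its 1,000 output scores into percentages that can be sorted to find the top predictions.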