要在 PyTorch 中显示被错误分类的图像，可以按以下步骤实现：
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Convert images to tensors and map each channel from [0, 1] to [-1, 1]
# via (x - 0.5) / 0.5. Any code that displays these images must undo this
# with x / 2 + 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR-10 test split; download=True fetches it into ./data if it is missing.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# shuffle=False keeps a stable order, so repeated runs inspect the same samples.
# NOTE(review): num_workers > 0 starts worker processes; under the "spawn"
# start method (Windows, macOS) the code that iterates this loader must run
# under an `if __name__ == '__main__':` guard.
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# Class names in CIFAR-10 label-index order: label i -> classes[i].
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Load the pre-trained model. The result is used directly as a model
# (net.eval() / net(images)), so model.pth must hold a whole pickled
# nn.Module saved with torch.save(model, ...), not just a state_dict.
# weights_only=False is required for that on PyTorch >= 2.6, where the
# default became True.
# SECURITY: this unpickles arbitrary objects -- only load checkpoints you trust.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine and
# keeps the model on the same device as the CPU batches from testloader.
net = torch.load('model.pth', map_location='cpu', weights_only=False)
def _collect_misclassified(model, testloader, max_images=None):
    """Run *model* over *testloader* and gather the samples it predicts wrongly.

    Args:
        model: ``torch.nn.Module`` classifier whose output is per-class scores;
            the predicted class is the argmax over dim 1.
        testloader: iterable of ``(images, labels)`` batches.
        max_images: stop after collecting this many samples; ``None`` means
            no limit.

    Returns:
        ``(images, true_labels, predicted_labels)``: a list of per-sample CPU
        image tensors, and two lists of Python ``int`` class indices.
    """
    model.eval()
    images_out, true_out, pred_out = [], [], []
    if max_images is not None and max_images <= 0:
        return images_out, true_out, pred_out

    # Feed batches to whichever device holds the model's weights
    # (CPU if the model has no parameters at all).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for images, labels in testloader:
            predicted = model(images.to(device)).argmax(dim=1).cpu()
            labels = labels.cpu()
            # nonzero(as_tuple=True)[0] is a flat 1-D index tensor. Plain
            # .nonzero() yields (k, 1) rows, which made images[idx] 4-D and
            # the labels 1-element tensors.
            wrong = (predicted != labels).nonzero(as_tuple=True)[0]
            for i in wrong.tolist():
                images_out.append(images[i].cpu())
                true_out.append(int(labels[i]))
                pred_out.append(int(predicted[i]))
                if max_images is not None and len(images_out) >= max_images:
                    return images_out, true_out, pred_out
    return images_out, true_out, pred_out


def _plot_misclassified(images, true_labels, pred_labels, classes):
    """Draw *images* (at least one) on a near-square grid with class titles."""
    num_images = len(images)
    rows = int(np.sqrt(num_images))
    cols = int(np.ceil(num_images / rows))
    # squeeze=False always returns a 2-D array of Axes, so .flat also works
    # for a 1x1 grid (a bare Axes object has no .flat).
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    fig.suptitle('Misclassified Images', fontsize=20)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= num_images:
            continue  # spare cell at the end of the grid
        # Undo Normalize(mean=0.5, std=0.5): x / 2 + 0.5 maps [-1, 1] back to
        # [0, 1]. The clamp absorbs float rounding so imshow does not warn.
        image = (images[i] / 2 + 0.5).clamp(0, 1)
        # CHW tensor -> HWC array, the layout imshow expects.
        ax.imshow(image.numpy().transpose((1, 2, 0)))
        ax.set_title(f'Predicted: {classes[pred_labels[i]]}\nActual: {classes[true_labels[i]]}')
    plt.tight_layout()
    plt.show()


def show_misclassified_images(model, testloader, classes, max_images=None):
    """Plot the samples in *testloader* that *model* classifies incorrectly.

    Each subplot is titled with the predicted and the actual class name.

    Args:
        model: ``torch.nn.Module`` classifier; it is switched to eval mode.
        testloader: iterable of ``(images, labels)`` batches. Images are
            assumed normalized with mean = std = 0.5 per channel, which is
            exactly what the plot undoes.
        classes: sequence mapping class index -> display name.
        max_images: optional cap on how many misclassified images to show.
            The default ``None`` shows all of them, as before.

    Returns:
        None. Prints a message instead of plotting when nothing is misclassified.
    """
    images, true_labels, pred_labels = _collect_misclassified(
        model, testloader, max_images)
    if not images:
        # Nothing to draw; the grid maths would otherwise divide by zero.
        print('No misclassified images found.')
        return
    _plot_misclassified(images, true_labels, pred_labels, classes)
if __name__ == '__main__':
    # The guard stops DataLoader worker processes (num_workers > 0) from
    # re-running this call when they re-import the script under the "spawn"
    # start method (the default on Windows and macOS).
    show_misclassified_images(net, testloader, classes)
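这样，你就可以在 PyTorch 中显示被错误分类的图像了。该函数会遍历测试集中的样本，找出模型预测错误的图像并逐一显示，同时在每张图上标注预测类别和实际类别。注意：`model.pth` 必须是用 `torch.save(model, ...)` 保存的完整模型（而不只是 `state_dict`）；反归一化 `image / 2 + 0.5` 只对应上面 `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` 的设置，若更换了归一化参数，需要相应修改。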
领取专属 10元无门槛券
手把手带您无忧上云