我使用Pytorch和FashionMNIST数据集,我想显示10个类中的每个类的8个图像样本。但是,我不知道如何将训练测试分成train_labels,因为我需要循环每个类的标签(类)并打印8个。知道我怎么能做到这一点吗?
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
# transforms.Lambda(lambda x: x.repeat(3,1,1)),
transforms.Normalize((0.5, ), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True)
print('Training set size:', len(trainset))
print('Test set size:',len(testset))
发布于 2021-01-01 07:18:35
如果我正确理解您的意思,您希望按标签对数据集进行分组,然后显示它们。
您可以从构造一个字典开始,通过标签存储示例:
examples = {i: [] for i in range(len(classes))}
然后遍历该列集并使用标签的索引追加到列表中:
for x, i in trainset:
examples[i].append(x)
然而,这将贯穿整个过程。如果您希望提前停止并避免每堂课收集超过8次,可以通过添加条件来做到这一点:
n_examples = 8
for x, i in trainset:
if all([len(ex) == n_examples for ex in examples.values()])
break
if len(examples[i]) < n_examples:
examples[i].append(x)
剩下的就是用torchvision.transforms.ToPILImage
显示
transforms.ToPILImage()(examples[3][0])
如果您想显示不止一个,可以使用两个连续的torch.cat
,一个在dim=1
(按行)上,然后在dim=2
(按列)上创建网格。
grid = torch.cat([torch.cat(examples[i], dim=1) for i in range(len(classes))], dim=2)
transforms.ToPILImage()(grid)
可能的结果:
https://stackoverflow.com/questions/65528954
复制相似问题