torchvision.datasets.ImageFolder
是PyTorch中一个常用的数据加载工具,用于从文件夹结构加载图像数据集。它假设数据集的组织结构如下:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
其中每个子文件夹代表一个类别,包含该类别下的所有图像文件。
问题原因:ImageFolder要求特定的文件夹结构,如果不符合这种结构会导致加载失败。
解决方案:
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
# 正确的使用方式
transform = transforms.Compose([transforms.ToTensor()])
dataset = ImageFolder(root='path/to/your/dataset', transform=transform)
问题原因:ImageFolder默认支持的图像格式有限(如.jpg, .png等),如果包含不支持的格式会报错。
解决方案:
PIL
支持的格式from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
# 添加支持的扩展名
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
dataset = ImageFolder(
root='path/to/dataset',
transform=transforms.ToTensor(),
is_valid_file=lambda x: x.lower().endswith(IMG_EXTENSIONS)
)
问题原因:程序没有足够的权限访问图像文件。
解决方案:
问题原因:某些图像文件可能已损坏,导致加载失败。
解决方案:
verify_images
函数检查损坏文件from PIL import Image
import os
def verify_images(folder_path):
for root, _, files in os.walk(folder_path):
for file in files:
try:
img_path = os.path.join(root, file)
Image.open(img_path).verify()
except (IOError, SyntaxError) as e:
print(f'Bad file: {img_path}')
os.remove(img_path)
verify_images('path/to/your/dataset')
问题原因:某些系统对中文路径或特殊字符支持不好。
解决方案:
问题原因:加载大量高分辨率图像可能导致内存不足。
解决方案:
DataLoader
分批加载from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
dataset = ImageFolder('path/to/dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
import os
print(os.listdir('path/to/your/dataset'))
from PIL import Image
Image.open('path/to/single/image.jpg')
torch.utils.data.Dataset
实现自定义数据集类。from torch.utils.data import Dataset
from PIL import Image
import os
class CustomImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.classes = sorted(os.listdir(root_dir))
self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
self.samples = self._make_dataset()
def _make_dataset(self):
samples = []
for target_class in self.classes:
class_idx = self.class_to_idx[target_class]
target_dir = os.path.join(self.root_dir, target_class)
for root, _, fnames in os.walk(target_dir):
for fname in fnames:
path = os.path.join(root, fname)
samples.append((path, class_idx))
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
path, target = self.samples[idx]
img = Image.open(path).convert('RGB')
if self.transform:
img = self.transform(img)
return img, target
DataLoader
进行批量加载和并行处理通过以上方法和调试技巧,应该能够解决大多数ImageFolder加载图像的问题。
没有搜到相关的文章