转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~
目录
sampler: Sampler | Iterable = None
batch_sampler: Sampler | Iterable = None
worker_init_fn: Callable = None
generator: torch.Generator = None
persistent_workers: bool = False
数据加载器(DataLoader) 将数据集(dataset)与采样器(sampler)结合起来,并提供一个可迭代对象,用于遍历给定的数据集。torch.utils.data.DataLoader 支持 map-style(映射式) 和 iterable-style(可迭代式) 两种数据集,支持单进程或多进程加载,可以自定义加载顺序,并可选择是否自动分批(collation)以及是否将数据固定到内存(pinning)。

torch.utils.data.DataLoader的所有参数:
class DataLoader(
# 要加载的数据集。
dataset: Dataset,
# 每个批次加载多少样本(默认:1)。
batch_size: int | None = 1,
# 若为 True,则在每个 epoch 开始时打乱数据(默认:False)。
shuffle: bool | None = None,
# 定义从数据集中抽取样本的策略。可以是任何实现了 __len__ 的 Iterable。如果指定了 sampler,则不能再设置 shuffle。
sampler: Sampler | Iterable | None = None,
# 与 sampler 类似,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
batch_sampler: Sampler[List] | Iterable[List] | None = None,
# 用于数据加载的子进程数量。设为 0 时,数据将在主进程中加载(默认:0)。
num_workers: int = 0,
# 将一个样本列表合并成一个小批量(mini-batch)张量。通常在从 map-style 数据集进行批量加载时使用。
collate_fn: _collate_fn_t | None = None,
# 若为 True,则 DataLoader 会在返回前把张量复制到设备/CUDA 的 锁页内存(pinned memory)。如果你的数据元素是自定义类型,或者 collate_fn 返回的是自定义类型。
pin_memory: bool = False,
# 若为 True,当数据集大小不能被 batch_size 整除时,丢弃最后一个不完整的批次。若为 False,则保留最后一个较小的批次(默认:False)。
drop_last: bool = False,
# 若大于 0,表示从 worker 获取一个批次的超时时间(秒)。必须为非负数(默认:0)。
timeout: float = 0,
# 若不为 None,则在每个 worker 子进程启动时调用,输入参数为 worker id(范围 [0, num_workers - 1])。它会在设置随机种子之后、开始加载数据之前被调用(默认:None)。
worker_init_fn: _worker_init_fn_t | None = None,
# 若为 None,则使用操作系统的默认多进程上下文(multiprocessing context)(默认:None)。
multiprocessing_context: Any | None = None,
# 若不为 None,则该随机数生成器将被 RandomSampler 用于生成随机索引,并被多进程机制用于生成 worker 的基础随机种子(默认:None)。
generator: Any | None = None,
# 每个 worker 预先加载的批次数。2 表示总共有 2 * num_workers 个批次被预取。默认值取决于 num_workers 的设置。如果 num_workers=0,则默认是 None;否则默认是 2。
prefetch_factor: int | None = None,
# 若为 True,则在数据集被消耗完一次后,worker 进程不会关闭。这允许在多个 epoch 之间保持 worker 中的 Dataset 实例常驻(默认:False)。
persistent_workers: bool = False,
# 当 pin_memory=True 时,指定锁页内存绑定到的设备。
pin_memory_device: str = ""
)
数据来源对象,告诉 DataLoader 去哪里拿样本。pytorch提供的torch.utils.data.Dataset类是一个抽象基类,供用户继承,编写自己的dataset,实现对数据的读取。允许两种格式的dataset:
注意:
shuffle/sampler/batch_sampler(很多会被忽略或报错)。
num_workers>0)会在子进程里各自构造 dataset,所以 dataset 内的对象要可被子进程安全创建/序列化。
__init__,放到 __getitem__/__iter__ 里。
示例(map-style):
from torch.utils.data import Dataset
class ToySet(Dataset):
def __init__(self, xs, ys):
self.xs, self.ys = xs, ys
def __len__(self):
return len(self.xs)
def __getitem__(self, idx):
return self.xs[idx], self.ys[idx]
示例(iterable-style,含多 worker 分片):
from torch.utils.data import IterableDataset, get_worker_info
class StreamSet(IterableDataset):
def __iter__(self):
info = get_worker_info()
wid, wnum = (info.id, info.num_workers) if info else (0, 1)
for i, sample in enumerate(infinite_source()): # 你自己的流
if i % wnum == wid: # 简单分片,避免重复
yield sample
每个批次里包含的样本数。控制一次前向/反向要处理多少样本,影响吞吐、显存占用与稳定性。默认 1,表示 不做批划分,每次返回单条样本。
注意:
batch_sampler 互斥。指定了 batch_sampler 时不要再传 batch_size。
stack expects each tensor to be equal size),需要自定义 collate_fn 做 padding/对齐。
drop_last 决定。
示例:
from torch.utils.data import DataLoader
loader = DataLoader(ToySet(xs, ys), batch_size=32)
for xb, yb in loader:
...
每个 epoch 是否打乱样本顺序。可以避免固定顺序带来的偏差,提高泛化。
注意:
sampler、batch_sampler 互斥。如果你已经自定义 sampler(如 RandomSampler),shuffle 必须设为 False,否则会报错。
DistributedSampler(shuffle=True) 接管打乱与分片,不再单独传 shuffle=True。
示例:
loader = DataLoader(ToySet(xs, ys), batch_size=64, shuffle=True)
返回单条样本的索引。可以自定义“取样本”的策略,实现加权采样、不放回采样、分布式分片等。
注意:
sampler 后不要再传 shuffle。
DistributedSampler 时,记得 sampler.set_epoch(epoch),确保每轮乱序不同。
示例(类别不均衡的权重采样):
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
weights = torch.tensor([1, 5, 1, 2, ...], dtype=torch.double) # 每样本权重
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(ToySet(xs, ys), batch_size=32, sampler=sampler)
直接返回每个 batch的索引列表。可以把“打乱+分批”的逻辑完全自定义(如按长度分桶、同类打包)。
注意:
batch_size、shuffle、sampler、drop_last 互斥(这些都交给它控制)。
示例(用内置 BatchSampler 包一层):
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
base = RandomSampler(range(len(xs))) # 先有一个样本级采样器
batches = BatchSampler(base, batch_size=16, drop_last=False) # 再打成批
loader = DataLoader(ToySet(xs, ys), batch_sampler=batches)
用于数据加载的子进程个数;0 表示在主进程中加载(无并行)。I/O 或解码较重时,可以提高吞吐(读取与训练并行)。
注意:
spawn:传入的对象(dataset、collate_fn 等)必须可被 pickle;避免 lambda、内嵌函数。
prefetch_factor、persistent_workers。
示例:
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4)
把“样本列表”合并成“一个 batch”的函数。可以实现对齐变长序列(padding)、拼接字典/列表、返回自定义 Batch 对象等。
注意:
default_collate 已能处理 Tensor/数值/np/dict/list 等“等长”数据。
collate_fn 完成对齐。
__getitem__。
示例(序列 padding):
import torch
from torch.nn.utils.rnn import pad_sequence
def pad_collate(batch):
xs, ys = zip(*batch)
xs = [torch.as_tensor(x) for x in xs]
ys = torch.as_tensor(ys)
return pad_sequence(xs, batch_first=True), ys
loader = DataLoader(ToySet(xs, ys), batch_size=32, collate_fn=pad_collate)
pin_memory: bool = False把从 CPU 读取的数据直接拷贝到 页锁定(pinned)内存,当随后调用 tensor.cuda(non_blocking=True) 时,GPU 能通过 DMA 直接读取,显著提升传输带宽。
注意:
collate_fn 返回的是自定义对象,需要实现 .pin_memory() 或在 collate_fn 内手动对每个张量 pin_memory()。
示例(配合非阻塞拷贝):
device = "cuda"
loader = DataLoader(ToySet(xs, ys), batch_size=64, pin_memory=True)
for xb, yb in loader:
xb = xb.to(device, non_blocking=True)
yb = yb.to(device, non_blocking=True)
当数据量不是 batch_size 的整数倍时,是否丢弃最后那个“不满”的批。可以在需要固定 batch 大小的场景(如 BatchNorm、固定步长的梯度累积)保持一致性。
注意:训练可能设 True;验证/测试一般 False。
示例:
loader = DataLoader(ToySet(xs, ys), batch_size=64, drop_last=True)
从 workers 收集一个 batch 的超时秒数(非负)。可以排查卡死/死锁(到点抛错)。
注意:
num_workers>0 才有意义;设太小会误报。示例:
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4, timeout=30)
每个 worker 进程启动后、正式取数据前会被调用一次的函数,入参是 worker_id。可以设置 numpy/random 等外部库的种子、全局变量初始化、初始化三方资源。
注意:
spawn 启动方式(Win/macOS 常见)下,必须是可 pickle 的顶层函数,不能用 lambda。
generator 一起用。
示例(可复现设置):
import random, numpy as np, torch
def seed_worker(worker_id):
seed = torch.initial_seed() % 2**32
random.seed(seed); np.random.seed(seed)
g = torch.Generator().manual_seed(42)
loader = DataLoader(
ToySet(xs, ys), shuffle=True, num_workers=4,
worker_init_fn=seed_worker, generator=g
)
指定 Python 多进程的“启动方式”:"fork" | "spawn" | "forkserver" 或对应 context。当某些库在 fork 下不安全(线程、显存句柄)时,改用 spawn 更稳。
注意:
spawn;Linux 默认 fork。
spawn 更通用但启动更慢。
示例:
ctx = torch.multiprocessing.get_context("spawn")
loader = DataLoader(ToySet(xs, ys), num_workers=4, multiprocessing_context=ctx)
DataLoader 用到的随机数发生器。可以控制乱序/随机采样,同时为 workers 生成基础种子。
注意:
generator.manual_seed(...) + 在 worker_init_fn 同步 numpy/random 的种子。
示例:
g = torch.Generator().manual_seed(2024)
loader = DataLoader(ToySet(xs, ys), shuffle=True, generator=g)
每个 worker 会 预取 prefetch_factor * batch_size 条样本放进内部队列。可以让读取/解码与训练形成流水线,减少“等数据”。 注意:
num_workers==0 → None(无预取);num_workers>0 → 2。
示例:
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4, prefetch_factor=3)
True 时,跨 epoch 不销毁 worker 进程(复用已创建的 Dataset 实例)。长时间训练更高效,避免反复建/毁进程(特别是 num_workers 较大时)。
注意:
num_workers>0 生效。
示例:
loader = DataLoader(ToySet(xs, ys), num_workers=4, persistent_workers=True)
当 pin_memory=True 时,指定 pinned memory 绑定到哪个设备(比如 "cuda" 或 "cuda:0")。
更细粒度地控制固定内存的目标设备;多数场景留默认即可。 注意:
pin_memory=True;只有在多设备或特定优化需求时才设置它。
示例:
loader = DataLoader(
ToySet(xs, ys),
batch_size=64,
pin_memory=True,
pin_memory_device="cuda" # 或 "cuda:0"
)
spawn 启动方式:如果使用 spawn 启动多进程,worker_init_fn 不能是不可序列化对象(例如 lambda 函数)。详情见 multiprocessing 最佳实践。
len(dataloader) 的计算:
len(dataloader) 基于所用的 sampler 的长度。
IterableDataset 时,len(dataloader) 返回的估计值为 len(dataset) / batch_size,并根据 drop_last 进行四舍五入,与是否多进程无关。
dataset 能正确处理多进程加载以避免数据重复。
drop_last=True 时,可能会丢弃超过一个批次的数据。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。