import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

def create_data_loader_cifar10():
    transform = transforms.Compose([
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    batch_size = 256
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=10, pin_memory=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=10)
    return trainloader, testloader

def train(net, trainloader):
print("Start training...")
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
epochs = 1
num_of_batches = len(trainloader)
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
images, labels = inputs.cuda(), labels.cuda()
optimizer.zero_grad()
outputs = net(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'[Epoch {epoch + 1}/{epochs}] loss: {running_loss / num_of_batches:.3f}')
    print('Finished Training')

Single-GPU results: 69.03 seconds of total wall-clock time, 13.08 seconds for one training epoch, and 27% test accuracy.
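The timing and accuracy numbers above imply a `test()` helper and a timing wrapper that are not shown in the excerpt; a minimal sketch of how they might look (the helper and the `__main__` wiring are assumptions):

```python
# Sketch only: the original excerpt does not show how timing and test
# accuracy were measured; this is one plausible wiring.
def test(net, testloader):
    correct, total = 0, 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test accuracy: {100 * correct / total:.1f}%')

if __name__ == '__main__':
    start = time.time()
    trainloader, testloader = create_data_loader_cifar10()
    net = torchvision.models.resnet50(False).cuda()
    train(net, trainloader)
    test(net, testloader)
    print(f'Total elapsed time: {time.time() - start:.2f} seconds')
```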
DataParallel is single-process and multi-threaded, and it only works on a single machine with multiple GPUs. It keeps an identical replica of the model on every GPU for the forward pass and scatters the input data across the GPUs.
net = torchvision.models.resnet50(False)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)

When using DataParallel, the batch size should be divisible by the number of GPUs. Because the module is replicated onto every GPU on each forward pass, DataParallel incurs significant per-iteration overhead.
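Once the model is wrapped, the training loop is reused unchanged; a minimal sketch (the `.cuda()` placement mirrors the later DDP code and is otherwise an assumption):

```python
# The 256-sample batch is split evenly across the visible GPUs;
# the train() function above does not need to change.
net = net.cuda()
trainloader, testloader = create_data_loader_cifar10()
train(net, trainloader)
```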
import os
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

def init_distributed():
    # RANK, WORLD_SIZE and LOCAL_RANK are set by the launcher (e.g. torchrun)
    dist_url = "env://"
    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank
    )
    # Each process drives exactly one GPU
    torch.cuda.set_device(local_rank)
    # Wait until every process has reached this point
    dist.barrier()

net = torchvision.models.resnet50(False).cuda()
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
local_rank = int(os.environ['LOCAL_RANK'])
net = nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])
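DDP runs one process per GPU. Assuming the entry point is a script called `main.py` (the filename is not given in the excerpt), a typical launch is `torchrun --nproc_per_node=4 main.py`; `init_distributed()` must then be called at the start of the program, before the model is wrapped or the data loaders are built, since torchrun provides the RANK, WORLD_SIZE and LOCAL_RANK environment variables it reads.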
def create_data_loader_cifar10():
    transform = transforms.Compose([
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    batch_size = 256
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    # Each process sees a distinct shard of the training set
    train_sampler = DistributedSampler(dataset=trainset, shuffle=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              sampler=train_sampler, num_workers=10, pin_memory=True)
    # The test set and loader are built exactly as in the single-GPU version
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=10)
    return trainloader, testloader

Call trainloader.sampler.set_epoch(epoch) at the beginning of every epoch to ensure correct shuffling across epochs.
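A minimal sketch of where that call sits (the loop body is elided):

```python
for epoch in range(epochs):
    # Re-seed the sampler so each epoch draws a different shuffle on all ranks
    trainloader.sampler.set_epoch(epoch)
    for inputs, labels in trainloader:
        ...  # forward / backward / optimizer step exactly as in train()
```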
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

# get_rank() is referenced below; it is not shown in the original excerpt,
# so the standard helper is reproduced here.
def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def save_on_master(*args, **kwargs):
    # Only rank 0 writes checkpoints, so N processes do not write N copies
    if is_main_process():
        torch.save(*args, **kwargs)

def is_main_process():
    return get_rank() == 0
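As a usage sketch (the checkpoint filename is made up for illustration), every rank can call the helper unconditionally:

```python
# Every process calls this, but only rank 0 actually writes the file.
save_on_master(net.state_dict(), 'resnet50_cifar10_ddp.pth')
```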
Mixed precision combines FP16 and FP32 for different operations during training, reducing training time while keeping accuracy on par with full FP32 training.

fp16_scaler = torch.cuda.amp.GradScaler(enabled=True)

for epoch in range(epochs):
    trainloader.sampler.set_epoch(epoch)
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        images, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        # Run the forward pass and loss in mixed precision
        with torch.cuda.amp.autocast():
            outputs = net(images)
            loss = criterion(outputs, labels)
        # Scale the loss to avoid FP16 gradient underflow, then unscale and step
        fp16_scaler.scale(loss).backward()
        fp16_scaler.step(optimizer)
        fp16_scaler.update()
| Training setup | Time (seconds) |
|---|---|
| Single GPU (baseline) | 13.2 |
| DataParallel, 4 GPUs | 19.1 |
| DistributedDataParallel, 2 GPUs | 9.8 |
| DistributedDataParallel, 4 GPUs | 6.1 |
| DistributedDataParallel, 4 GPUs + mixed precision | 6.5 |
In an ideal parallel world, N workers would give an N-fold speedup. In practice, DistributedDataParallel with 4 GPUs achieves roughly a 2x speedup here. Mixed-precision training usually provides a substantial additional speedup, but the gains are limited on the A100 and other Ampere-based GPU architectures.
Note that DistributedDataParallel uses an effective batch size of 4 x 256 = 1024, so the model takes far fewer optimizer steps per epoch (roughly 50,000 / 1024 ≈ 49 versus 50,000 / 256 ≈ 196 on a single GPU), which explains the lower validation accuracy (14% versus 27% for the baseline).
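One common mitigation, not applied in the runs above, is the linear learning-rate scaling rule: grow the learning rate with the number of processes to compensate for the larger effective batch. A sketch, assuming the process group is already initialized:

```python
# Linear scaling rule (assumption, not part of the original experiments):
# scale the base learning rate by the number of DDP processes.
base_lr = 0.001
world_size = dist.get_world_size()
optimizer = optim.SGD(net.parameters(), lr=base_lr * world_size, momentum=0.9)
```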