PyTorch-Lightning 是一个基于 PyTorch 的轻量级高级封装库,旨在简化深度学习模型训练过程,并提供更多功能和组件。下面是使用 PyTorch-Lightning 进行简单预测的示例:
pip install torch
pip install pytorch-lightning
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
class SimpleModel(pl.LightningModule):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.softmax(self.fc2(x), dim=1)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
# 数据处理和加载
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
train_dataset = MNIST('data/', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 训练和预测
model = SimpleModel()
trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, train_dataloader)
# 预测
test_dataset = MNIST('data/', train=False, download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32)
predictions = []
model.eval()
for batch in test_dataloader:
x, _ = batch
with torch.no_grad():
y_hat = model(x)
predictions.extend(torch.argmax(y_hat, dim=1).tolist())
print(predictions)
在上面的示例中,我们首先定义了一个简单的模型 SimpleModel
,继承自 pl.LightningModule
。模型包含两个全连接层,用于 MNIST 手写数字分类。
在 training_step
方法中,我们定义了训练的逻辑,计算模型的输出与真实标签的交叉熵损失,并将损失记录到训练日志中。
configure_optimizers
方法定义了优化器的设置,这里使用 Adam 优化器。
然后,我们创建了数据处理和加载的管道,加载了 MNIST 数据集并进行相应的预处理。
接着,创建了模型和训练器对象,并使用 fit
方法进行训练。
最后,我们加载测试数据集,并使用训练好的模型进行预测。预测过程中,将每个样本的预测结果存储在 predictions
列表中。
这是一个简单的使用 PyTorch-Lightning 进行预测的示例。通过使用 PyTorch-Lightning,我们可以更简单地定义和训练深度学习模型,提高开发效率。关于 PyTorch-Lightning 更详细的信息和功能,请参考腾讯云的 PyTorch-Lightning 产品介绍。
领取专属 10元无门槛券
手把手带您无忧上云