本节将介绍在pytorch中非常重要的类:nn.Module
。在实现自己设计的网络时,必须要继承这个类,示例写法如下
import torch
import torch.nn as nn
import torch.nn.functional as F
# 先定义自己的类
class MyNN(nn.Module):
def __init__(self, inp, outp):
# 初始化自己定义的类
super(MyNN, self).__init__()
self.w = nn.Parameter(torch.randn(outp, inp))
self.b = nn.Parameter(torch.randn(outp))
# 定义前向传播
def forward(self, x):
x = x @ self.w.t() + self.b
return x
那么nn.Module
这个类有哪些功能?
nn.Module
提供了很多已经编写好的功能,如Linear
、ReLU
、Sigmoid
、Conv2d
、ConvTransposed2d
、Dropout
...self.net = nn.Sequential(
# .Sequential()相当于设置了一个容器(Container)
# 将需要进行forward的函数写在其中
nn.Conv2d(1, 32, 5, 1, 1),
nn.MaxPool2d(2, 2),
nn.ReLU(True),
nn.BatchNorm2d(32),
nn.Conv2d(32, 64, 3, 1, 1),
nn.ReLU(True),
nn.BatchNorm2d(64),
nn.Conv2d(64, 64, 3, 1, 1),
nn.MaxPool2d(2, 2),
nn.ReLU(True),
nn.BatchNorm2d(64),
nn.Conv2d(64, 128, 3, 1, 1),
nn.ReLU(True),
nn.BatchNorm2d(128)
)
或者需要将自己设计的层连接在一起的情况
class Faltten(nn.Module):
def __init__(self):
super(Faltten, self).__init__()
def forward(self, input):
return input.view(inputt.size(0), -1)
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(1, 16, stride=1, padding=1),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(1*14*14, 10)
)
def forward(self, x):
return self.net(x)
nn.Module
可以对网络中的参数进行有效的管理net = nn.Sequential(
nn.Linear(in_features=4, out_features=2),
nn.Linear(in_features=2, out_features=2)
)
# 隐藏层的编号是从0开始的
list(net.parameters())[0] # [0]是layer0的w
list(net.parameters())[3].shape # [3]是layer1的b
dict(net.named_parameters()).items() # 返回所有层的参数
optimizer = optim.SGD(net.parameters(), lr=1e-3)
输出
torch.Size([2, 4])
torch.Size([2])
dict_items([('0.weight', Parameter containing:
tensor([[ 0.0195, 0.4698, -0.4913, -0.3336],
[ 0.1422, 0.2908, -0.2469, 0.0583]], requires_grad=True)), ('0.bias', Parameter containing:
tensor([-0.4704, -0.1133], requires_grad=True)), ('1.weight', Parameter containing:
tensor([[-0.6511, 0.2442],
[ 0.5658, 0.4419]], requires_grad=True)), ('1.bias', Parameter containing:
tensor([ 0.0114, -0.5664], requires_grad=True))])
.device()
函数device = torch.device('cuda')
net = Net()
net.to(device)
torch.save(net.state_dict(), 'ckpt.mdl')
net.load_state_dict(torch.load('ckpt.mdl'))
# train
net.train()
# test
net.eval()
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有