本节介绍在pytorch中十分重要的“类”:nn.Module。
在实现自己设计的层结构功能时,必须要使用自己继承的类。
类的书写如下
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyLinear(nn.Module):
# 先定义自己的类
def __init__(self, inp, outp):
super(MyLinear, self).__init__()
# 初始化自己定义的类
self.w = nn.Parameter(torch.randn(outp, inp))
self.b = nn.Parameter(torch.randn(outp))
def forward(self, x):
# 定义前向
x = x @ self.w.t() + self.b
return x
那么nn.Module类到底是什么?
(1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。
(2)还可以进行嵌套,便于书写树形结构
(3)nn.Module提供了很多已经编写好的功能,如Linear、ReLU、Sigmoid、Conv2d、ConvTransposed2d、Dropout等。
最主要的功能是书写代码方便
self.net = nn.Sequential(
# .Sequential()相当于设定了一个容器,
# 将需要进行forward的函数代入其中,
# 但不用每一个步骤都写上,
# 直接放在容器中,后面再定义一个forward代码即可
nn.Conv2d(1, 32, 5, 1, 1),
nn.MaxPool2d(2, 2),
...
)
使用nn.Module的第三个好处是可以对网络中的参数进行有效的管理
通过.parameters()即可很方便的对参数进行查看
net = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 2))
print(list(net.parameters()))[0].shape
# 输出查看第0层的参数
也可用.named_parameters()来输出网络结构编好名字的参数
print(list(net.named_parameters()))[0].shape
后续再加上.item(),来对各种属性进行查看
print(list(net.named_parameters()))[0].item()
另外nn.Module还可以自己定义类的顺序。
也可以很方便的将所有的运算都转入到GPU上去。使用.device函数,
device = torch.device('cuda')
net = Net()
net.to(device)
还可以很方便的进行save和load,以防止突然发生的断点和系统崩溃的现象
net.load_state_dict(torch.load('ckpt.mdl'))
torch.save(net.state_dict(), 'ckpt.mdl')
nn.Modele还可以很方便的切换状态
# 切换到train状态
net.train()
# 切换到test
本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!