前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch-nn.Module

Pytorch-nn.Module

作者头像
用户6719124
发布2019-12-05 11:18:03
6300
发布2019-12-05 11:18:03
举报
文章被收录于专栏:python pytorch AI机器学习实践

本节介绍在pytorch中十分重要的“类”:nn.Module。

在实现自己设计的层结构功能时,必须要使用自己继承的类。

类的书写如下

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLinear(nn.Module):
    # 先定义自己的类
    def __init__(self, inp, outp):
        super(MyLinear, self).__init__()
        # 初始化自己定义的类

        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))

    def forward(self, x):
        # 定义前向
        x = x @ self.w.t() + self.b
        return x

那么nn.Module类到底是什么?

(1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。

(2)还可以进行嵌套,便于书写树形结构

(3)nn.Module提供了很多已经编写好的功能,如Linear、ReLU、Sigmoid、Conv2d、ConvTransposed2d、Dropout等。

最主要的功能是书写代码方便

代码语言:javascript
复制
self.net = nn.Sequential(
    # .Sequential()相当于设定了一个容器,
    # 将需要进行forward的函数代入其中,
    # 但不用每一个步骤都写上,
    # 直接放在容器中,后面再定义一个forward代码即可
    nn.Conv2d(1, 32, 5, 1, 1),
    nn.MaxPool2d(2, 2),
    ...
    
)

使用nn.Module的第三个好处是可以对网络中的参数进行有效的管理

通过.parameters()即可很方便的对参数进行查看

代码语言:javascript
复制
net = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 2))
print(list(net.parameters()))[0].shape
# 输出查看第0层的参数

也可用.named_parameters()来输出网络结构编好名字的参数

代码语言:javascript
复制
print(list(net.named_parameters()))[0].shape

后续再加上.item(),来对各种属性进行查看

代码语言:javascript
复制
print(list(net.named_parameters()))[0].item()

另外nn.Module还可以自己定义类的顺序。

也可以很方便的将所有的运算都转入到GPU上去。使用.device函数,

代码语言:javascript
复制
device = torch.device('cuda')
net = Net()
net.to(device)

还可以很方便的进行save和load,以防止突然发生的断点和系统崩溃的现象

代码语言:javascript
复制
net.load_state_dict(torch.load('ckpt.mdl'))
torch.save(net.state_dict(), 'ckpt.mdl')

nn.Modele还可以很方便的切换状态

代码语言:javascript
复制
# 切换到train状态
net.train()
# 切换到test
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
容器服务
腾讯云容器服务(Tencent Kubernetes Engine, TKE)基于原生 kubernetes 提供以容器为核心的、高度可扩展的高性能容器管理服务,覆盖 Serverless、边缘计算、分布式云等多种业务部署场景,业内首创单个集群兼容多种计算节点的容器资源管理模式。同时产品作为云原生 Finops 领先布道者,主导开源项目Crane,全面助力客户实现资源优化、成本控制。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档