Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >CIFAR10数据集实战-LeNet5神经网络(下)

CIFAR10数据集实战-LeNet5神经网络(下)

作者头像
用户6719124
发布于 2020-01-02 06:21:28
发布于 2020-01-02 06:21:28
65300
代码可运行
举报
运行总次数:0
代码可运行

下面开始加入test部分

先写入test部分代码

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
for x, label in cifar_test:
    x, label = x.to(device), label.to(device)

    logits = model(x)
    pred = logits.armax(dim=1)
    # 用argmax选出可能性最大的值的索引

为进行比对

定义正确率

写入对比

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
total_correct += torch.eq(pred, label).float().sum().item()
# torch.eq函数用于对比,同时要转为numpy数据
total_num += x.size(0)

再定义正确率并输出

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
acc = total_correct / total_num
print('acc:', acc)

可以加入模式切换

Model.train()和model.eval()

最终main.py文件为

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
from torchvision import datasets
# 引入pytorch、datasets工具包
from torchvision import transforms
# 引入数据变换工具包
from torch.utils.data import DataLoader
# 多线程数据读取
from LeNet5 import LeNet5
import torch.nn as nn

import torch.optim as optim
def main():

    batchsz=32
    # 这个batch_size数值不宜太大也不宜过小

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        # .Compose相当于一个数据转换的集合
        # 进行数据转换,首先将图片统一为32*32
        transforms.ToTensor()
        # 将数据转化到Tensor中

    ]), download=True)
    # 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中

    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    # 按照其要求,这里的参数需要有batch_size,
    # 在该部分代码前面定义batch_size
    # 再使数据加载的随机化



    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)


    x, label = iter(cifar_train).next()
    # 通过.iter方法输出一个数据进行查看
    # print('s.shape:', x.shape, 'label.shape:', label.shape)
    # 输出shape进行查看




    device = torch.device('cuda')
    model = LeNet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    model.train()
    for epoch in range(1000):

        for batchidx, (x, label) in enumerate(cifar_train):
            # batchidx代表了有多少个batch,
            x, label = x.to(device), label.to(device)

            logits = model(x)
            loss = criteon(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        # print(epoch, loss.item())

        model.eval()
        total_correct = 0
        total_num = 0

        for x, label in cifar_test:
            x, label = x.to(device), label.to(device)

            logits = model(x)
            pred = logits.argmax(dim=1)
            # 用argmax选出可能性最大的值的索引
            # 进行比对
            total_correct += torch.eq(pred, label).float().sum().item()
            # torch.eq函数用于对比,同时要转为numpy数据
            total_num += x.size(0)
        acc = total_correct / total_num
        print('acc:', acc)

输出为

可以看出正确率在逐渐上升

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
05-PyTorch自定义数据集Datasets、Loader和tranform
对于机器学习中的许多不同问题,我们采取的步骤都是相似的。PyTorch 有许多内置数据集,用于大量机器学习基准测试。除此之外也可以自定义数据集,本问将使用我们自己的披萨、牛排和寿司图像数据集,而不是使用内置的 PyTorch 数据集。具体来说,我们将使用 torchvision.datasets 以及我们自己的自定义 Dataset 类来加载食物图像,然后我们将构建一个 PyTorch 计算机视觉模型,希望对三种物体进行分类。
renhai
2023/11/24
1.1K0
05-PyTorch自定义数据集Datasets、Loader和tranform
深度学习实战之手写签名识别(100%准确率、语音播报)
在完成了上述的环境搭建后,即可进入到准备阶段了。这里准备的有数据集的准备、以及相关代码的主备。
陶陶name
2022/05/13
1.7K0
CIFAR-10数据集实战——构建LeNet5神经网络
如果从官网下载数据集很慢,可以使用国内的地址http://ai-atest.bj.bcebos.com/cifar-10-python.tar.gz
mathor
2020/01/22
9700
CIFAR10数据集实战-LeNet5神经网络(中)
本节介绍在LeNet5中求loss的操作。 本结构使用CrossEntropyLoss进行求loss 首先引入工具包 import torch.nn.functional as F 加入代码 self.criteon = nn.CrossEntropyLoss() 返回logits return logits 下面开始写运行函数 返回main.py文件中 为加快运算速度,定义硬件加速 device = torch.device('cuda') 设置迭代次数 for epoch in range(1000):
用户6719124
2019/12/19
6170
PyTorch实现ResNet18
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/141287.html原文链接:https://javaforall.cn
全栈程序员站长
2022/09/01
9891
PyTorch实现ResNet18
【论文复现】LeNet-5
LeNet是最早的卷积神经网络之一。1998年,Yann LeCun第一次将LeNet卷积神经网络应用到图像分类上,在手写数字识别任务中取得了巨大成功。LeNet通过连续使用卷积和池化层的组合提取图像特征。 出自论文 《Gradient-Based Learning Applied to Document Recognition》 。
Eternity._
2024/11/30
3510
【论文复现】LeNet-5
CIFAR-10 数据集实战——构建ResNet18神经网络
Block中进行了正则化处理,以使train过程更快更稳定。同时要考虑,如果两元素的ch_in和ch_out不匹配,进行加法时会报错,因此需要判断一下,如果不想等,就用1×1的卷积调整一下
mathor
2020/01/22
1.7K0
深度学习实战之垃圾分类
垃圾分类,指按一定规定或标准将垃圾分类储存、分类投放和分类搬运,从而转变成公共资源的一系列活动的总称。分类的目的是提高垃圾的资源价值和经济价值,力争物尽其用;然而我们在日常生活中认为对垃圾分类还是有些不知所措的,对干垃圾、湿垃圾……分的不是很清楚,由此我们就想到了使用深度学习的方法进行分类。简介 本篇博文主要会带领大家进行数据的预处理、网络搭建、模型训练、模型测试 1. 获取数据集 这里笔者已经为大家提供了一个比较完整的数据集,所以大家不必再自己去收集数据了 数据集链接:https://pan.baidu
陶陶name
2022/05/13
6420
LeNet-5(论文复现)
LeNet是最早的卷积神经网络之一。1998年,Yann LeCun第一次将LeNet卷积神经网络应用到图像分类上,在手写数字识别任务中取得了巨大成功。LeNet通过连续使用卷积和池化层的组合提取图像特征。 出自论文《Gradient-Based Learning Applied to Document Recognition》。
Srlua
2024/11/30
2040
LeNet-5(论文复现)
我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
红色石头
2022/01/10
1.3K0
我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!
PyTorch 实战(模型训练、模型加载、模型测试)
https://github.com/huzixuan1/Loader_DateSet
陶陶name
2022/05/12
2.6K0
基于Pytorch构建LeNet网络对cifar-10进行分类
LeNet5诞生于1994年,是最早的卷积神经网络之一,是Yann LeCun等人在多次研究后提出的最终卷积神经网络结构,是一种用于手写体字符识别非常高效的网络。一般LeNet即指代LeNet5。LeNet5 这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层,是其他深度学习模型的基础。
python与大数据分析
2023/09/03
5240
基于Pytorch构建LeNet网络对cifar-10进行分类
CIFAR10数据集实战-数据读取部分(下)
这里设置了batch_size=32,对于一般硬件配置来说32是个较合理的数值,若硬件性能够强可设更高。
用户6719124
2019/12/19
1.7K0
Pytorch小项目-基于卷积神经网络的CIFAR10分类器
今天我们来讲一篇入门级必做的项目,如何使用pytorch进行CIFAR10分类,即利用CIFAR10数据集训练一个简单的图片分类器。
AI深度学习求索
2018/12/11
3.1K0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
红色石头
2022/01/10
1.6K0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
深度学习实战:1.LeNet实现CIFAR-10图像分类
利用torchvision.datasets函数可以在线导入pytorch中的数据集,包含一些常见的数据集如MNIST、CIFAR-10等。本次使用的是CIFAR10数据集,也是一个很经典的图像分类数据集,由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集,一共包含 10 个类别的 RGB 彩色图片。
AiCharm
2023/05/15
5130
深度学习实战:1.LeNet实现CIFAR-10图像分类
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
详细介绍了卷积神经网络 LeNet-5 的理论部分。今天我们将使用 Pytorch 来实现 LeNet-5 模型,并用它来解决 MNIST数据集的识别。
红色石头
2022/01/10
2.4K0
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
PyTorch中的LeNet-5入门
LeNet-5是一个经典的卷积神经网络(CNN)模型,由Yann LeCun等人在1998年提出。它在手写数字识别任务上取得了很好的性能,并被广泛应用于图像分类问题。本文将介绍如何使用PyTorch实现LeNet-5模型,并在MNIST手写数字数据集上进行训练和测试。
大盘鸡拌面
2023/10/18
5930
神经网络中测试部分的编写
上下两张图中蓝色的曲线分别代表training过程中accuracy和loss,可以看到,随着epoch的增加,accuracy在逐渐变大,loss也在逐渐变小。由图来看貌似训练过程良好,但实际上被骗了
mathor
2020/01/14
7490
神经网络中测试部分的编写
【深度学习入门篇 ⑧】关于卷积神经网络
关于卷积神经网络,你还有哪些不知道的知识点呢,之前我们介绍了大部分,今天再来补充一下~
@小森
2024/07/25
1480
【深度学习入门篇 ⑧】关于卷积神经网络
推荐阅读
相关推荐
05-PyTorch自定义数据集Datasets、Loader和tranform
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验