前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >基于PyTorch实现MNIST手写字识别

基于PyTorch实现MNIST手写字识别

作者头像
Awesome_Tang
发布于 2019-04-17 07:59:55
发布于 2019-04-17 07:59:55
2.1K00
代码可运行
举报
文章被收录于专栏:FSocietyFSociety
运行总次数:0
代码可运行

本篇不涉及模型原理,只是分享下代码。想要了解模型原理的可以去看网上很多大牛的博客。

目前代码实现了CNN和LSTM两个网络,整个代码分为四部分:

  • Config:项目中涉及的参数;
  • CNN卷积神经网络结构;
  • LSTM:长短期记忆网络结构;
  • TrainProcess: 模型训练及评估,参数model控制训练何种模型(CNN or LSTM)。
完整代码

Talk is cheap, show me the code.

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-04-05
# @version: python3.7

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime


class Config:
    batch_size = 64
    epoch = 10
    alpha = 1e-3

    print_per_step = 100  # 控制输出


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        """
        Conv2d参数:
        第一位:input channels  输入通道数
        第二位:output channels 输出通道数
        第三位:kernel size 卷积核尺寸
        第四位:stride 步长,默认为1
        第五位:padding size 默认为0,不补
        """
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(64 * 5 * 5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),  # 加快收敛速度的方法(注:批标准化一般放在全连接层后面,激活函数层的前面)
            nn.ReLU()
        )

        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.output = nn.Linear(64, 10)

    def forward(self, x):
        r_out, (_, _) = self.lstm(x, None)

        out = self.output(r_out[:, -1, :])
        return out


class TrainProcess:

    def __init__(self, model="CNN"):
        self.train, self.test = self.load_data()
        self.model = model
        if self.model == "CNN":
            self.net = CNN()
        elif self.model == "LSTM":
            self.net = LSTM()
        else:
            raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
        self.criterion = nn.CrossEntropyLoss()  # 定义损失函数
        self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)

    @staticmethod
    def load_data():
        print("Loading Data......")
        """加载MNIST数据集,本地数据不存在会自动下载"""
        train_data = datasets.MNIST(root='./data/',
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)

        test_data = datasets.MNIST(root='./data/',
                                   train=False,
                                   transform=transforms.ToTensor())

        # 返回一个数据迭代器
        # shuffle:是否打乱顺序
        train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                   batch_size=Config.batch_size,
                                                   shuffle=True)

        test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                                  batch_size=Config.batch_size,
                                                  shuffle=False)
        return train_loader, test_loader

    def train_step(self):
        steps = 0
        start_time = datetime.now()

        print("Training & Evaluating based on '%s'......" % self.model)
        for epoch in range(Config.epoch):
            print("Epoch {:3}.".format(epoch + 1))

            for data, label in self.train:
                data, label = Variable(data.cpu()), Variable(label.cpu())
                # LSTM输入为3维,CNN输入为4if self.model == "LSTM":
                    data = data.view(-1, 28, 28)
                self.optimizer.zero_grad()  # 将梯度归零
                outputs = self.net(data)  # 将数据传入网络进行前向运算
                loss = self.criterion(outputs, label)  # 得到损失函数
                loss.backward()  # 反向传播
                self.optimizer.step()  # 通过梯度做一步参数更新

                # 每100次打印一次结果
                if steps % Config.print_per_step == 0:
                    _, predicted = torch.max(outputs, 1)
                    correct = int(sum(predicted == label))  # 计算预测正确个数
                    accuracy = correct / Config.batch_size  # 计算准确率
                    end_time = datetime.now()
                    time_diff = (end_time - start_time).seconds
                    time_usage = '{:3}m{:3}s'.format(int(time_diff / 60), time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
                    print(msg.format(steps, loss, accuracy, time_usage))

                steps += 1

        test_loss = 0.
        test_correct = 0
        for data, label in self.test:
            data, label = Variable(data.cpu()), Variable(label.cpu())
            if self.model == "LSTM":
                data = data.view(-1, 28, 28)
            outputs = self.net(data)
            loss = self.criterion(outputs, label)
            test_loss += loss * Config.batch_size
            _, predicted = torch.max(outputs, 1)
            correct = int(sum(predicted == label))
            test_correct += correct

        accuracy = test_correct / len(self.test.dataset)
        loss = test_loss / len(self.test.dataset)
        print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))

        end_time = datetime.now()
        time_diff = (end_time - start_time).seconds
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))


if __name__ == "__main__":
    p = TrainProcess(model='CNN')
    p.train_step()

Peace~~

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019.04.06 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
详解 Pytorch 实现 MNIST[通俗易懂]
项目虽然简单,但是个人建议还是将各个模块分开创建,特别是对于新人而言,模块化的创建会让读者更加清晰、易懂。
全栈程序员站长
2022/07/23
1.1K0
详解 Pytorch 实现 MNIST[通俗易懂]
基于卷积神经网络的垃圾分类
自今年7月1日起,上海市将正式实施 《上海市生活垃圾管理条例》。垃圾分类,看似是微不足道的“小事”,实则关系到13亿多人生活环境的改善,理应大力提倡。
云微
2023/02/11
9650
基于卷积神经网络的垃圾分类
CNN使用MNIST手写数字识别实战的代码和心得
因为MNIST图片为长和宽相同的28像素,为黑白两色,所以图片的高度为1,为灰度通道。
flykiss
2021/09/13
1.7K0
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
详细介绍了卷积神经网络 LeNet-5 的理论部分。今天我们将使用 Pytorch 来实现 LeNet-5 模型,并用它来解决 MNIST数据集的识别。
红色石头
2022/01/10
2.5K0
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
pytorch:实践MNIST手写数字识别
在datasets.MNIST的中可以设置download=True,这样设置,系统会自动在root里面检测MNIST数据文件,如果存在则不下载,如果不存在则自动联网下载。我尝试自动联网下载,结果十几分钟之后,下载一半之后报错,网络出现问题。于是翻阅其它资源,将其手动下载下来添加到minst文件夹中自动创建的raw文件夹中。 (如果你也需要这个数据集,可以在微信公众号“我有一计”内回复“数据集”,即可获取下载链接)
zstar
2022/06/14
4920
手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)
详细介绍了 VGGNet 的网络结构,今天我们将使用 PyTorch 来复现VGGNet网络,并用VGGNet模型来解决一个经典的Kaggle图像识别比赛问题。
红色石头
2022/04/14
9070
手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)
用Pytorch自建6层神经网络训练Fashion-MNIST数据集,测试准确率达到 92%
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
Frank909
2019/09/18
3.4K0
用Pytorch自建6层神经网络训练Fashion-MNIST数据集,测试准确率达到 92%
CNN基础 & CV基本应用
直接上代码,kaiming初始化真的猛。与LeNet相比,这里的主要变化是使用更小的学习速率训练,这是因为网络更深更广、图像分辨率更高,训练卷积神经网络就更昂贵。
Sarlren
2022/10/28
4270
CNN基础 & CV基本应用
java落地AI模型-cnn手写体识别
第一层包含卷积、批量归一化、ReLU激活和最大池化操作; 第二层结构相同但输出通道数为32; 全连接层将前一层输出扁平化后接分类器。
AI拉呱
2024/10/01
2010
java落地AI模型-cnn手写体识别
利用pytorch实现图像识别demo
2.准备数据集,并将其转换为适合PyTorch使用的格式(例如,利用 torchvision 库中的 transform 处理图像数据,并将其转换为 tensor)。
疯狂的KK
2023/03/26
1.2K0
多注释:用PyTorch实现卷积神经网络对MNIST手写数字数据集的分类
本文将为尽可能多的代码作注释,用PyTorch实现对手写数字数据集MNIST的分类,我也是一个PyTorch的初学者,如果你也是一个刚学pytorch没多久的朋友,希望我的注释能够让您尽可能看明白。因个人水平有限,如有什么写错的地方,敬请指正。
用户7886150
2020/12/27
1.5K0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
红色石头
2022/01/10
1.6K0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
手撕 CNN 之 AlexNet(PyTorch 实战篇)
详细介绍了 AlexNet 的网络结构,今天我们将使用 PyTorch 来复现AlexNet网络,并用AlexNet模型来解决一个经典的Kaggle图像识别比赛问题。
红色石头
2022/04/14
1.9K0
手撕 CNN 之 AlexNet(PyTorch 实战篇)
轻松学Pytorch-使用STN网络实现旋转对象检测
Pytorch刚刚发布的最新版本1.10上面支持使用STN网络,帮助CNN网络获取旋转不变性特征。而且只需要在原来的CNN网络中改动十行左右代码即可获得加持,从而让训练生成的分类或者对象检测网络具有更好的稳定性。
OpenCV学堂
2021/12/17
1.3K0
轻松学Pytorch-使用STN网络实现旋转对象检测
pyTorch入门(三)——GoogleNet和ResNet训练
这是Minist训练的第三篇了,本篇主要是把GoogleNet和ResNet的模型写出来做一个测试,再就是train.py里面代码加入了图例显示。
Vaccae
2022/12/29
4860
pyTorch入门(三)——GoogleNet和ResNet训练
【机器学习实战】从零开始深度学习(通过GPU服务器进行深度学习)
0.1. 利用GPU加速深度学习   疫情期间没有办法用实验室的电脑来跑模型,用领取的腾讯云实例来弄刚刚好。发现如果没有GPU来跑的话真的是太慢了,非常推荐利用GPU加速深度学习的训练速度。     如果采用GPU的话,训练函数train_model(*)中数据的输入要改变一下,也就是需要将数据放在GPU上
汉堡888
2022/05/03
8.7K0
【机器学习实战】从零开始深度学习(通过GPU服务器进行深度学习)
基于Pytorch构建LeNet网络对cifar-10进行分类
LeNet5诞生于1994年,是最早的卷积神经网络之一,是Yann LeCun等人在多次研究后提出的最终卷积神经网络结构,是一种用于手写体字符识别非常高效的网络。一般LeNet即指代LeNet5。LeNet5 这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层,是其他深度学习模型的基础。
python与大数据分析
2023/09/03
5420
基于Pytorch构建LeNet网络对cifar-10进行分类
我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
红色石头
2022/01/10
1.3K0
我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!
基于Pytorch构建AlexNet网络对cifar-10进行分类
AlexNet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当年夺下了不少比赛的冠军。也是在那年之后,更多的更深的神经网络被提出,比如优秀的vgg,GoogLeNet。AlexNet和LeNet的设计非常类似,但AlexNet的结构比LeNet规模更大。
python与大数据分析
2023/09/03
8260
基于Pytorch构建AlexNet网络对cifar-10进行分类
用 PyTorch 从零创建 CIFAR-10 的图像分类器神经网络,并将测试准确率达到 85%
一般,深度学习的教材或者是视频,作者都会通过 MNIST 这个数据集,讲解深度学习的效果,但这个数据集太小了,而且是单色图片,随便弄些模型就可以取得比较好的结果,但如果我们不满足于此,想要训练一个神经网络来对彩色图像进行分类,可以不可以呢?
Frank909
2019/01/14
10.1K0
推荐阅读
相关推荐
详解 Pytorch 实现 MNIST[通俗易懂]
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验