Loading [MathJax]/jax/output/CommonHTML/config.js
部署DeepSeek模型,进群交流最in玩法!
立即加群
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >基于PyTorch实现MNIST手写字识别

基于PyTorch实现MNIST手写字识别

作者头像
Awesome_Tang
发布于 2019-04-17 07:59:55
发布于 2019-04-17 07:59:55
2.1K00
代码可运行
举报
文章被收录于专栏:FSocietyFSociety
运行总次数:0
代码可运行

本篇不涉及模型原理,只是分享下代码。想要了解模型原理的可以去看网上很多大牛的博客。

目前代码实现了CNN和LSTM两个网络,整个代码分为四部分:

  • Config:项目中涉及的参数;
  • CNN卷积神经网络结构;
  • LSTM:长短期记忆网络结构;
  • TrainProcess: 模型训练及评估,参数model控制训练何种模型(CNN or LSTM)。
完整代码

Talk is cheap, show me the code.

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-04-05
# @version: python3.7

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime


class Config:
    batch_size = 64
    epoch = 10
    alpha = 1e-3

    print_per_step = 100  # 控制输出


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        """
        Conv2d参数:
        第一位:input channels  输入通道数
        第二位:output channels 输出通道数
        第三位:kernel size 卷积核尺寸
        第四位:stride 步长,默认为1
        第五位:padding size 默认为0,不补
        """
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(64 * 5 * 5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),  # 加快收敛速度的方法(注:批标准化一般放在全连接层后面,激活函数层的前面)
            nn.ReLU()
        )

        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.output = nn.Linear(64, 10)

    def forward(self, x):
        r_out, (_, _) = self.lstm(x, None)

        out = self.output(r_out[:, -1, :])
        return out


class TrainProcess:

    def __init__(self, model="CNN"):
        self.train, self.test = self.load_data()
        self.model = model
        if self.model == "CNN":
            self.net = CNN()
        elif self.model == "LSTM":
            self.net = LSTM()
        else:
            raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
        self.criterion = nn.CrossEntropyLoss()  # 定义损失函数
        self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)

    @staticmethod
    def load_data():
        print("Loading Data......")
        """加载MNIST数据集,本地数据不存在会自动下载"""
        train_data = datasets.MNIST(root='./data/',
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)

        test_data = datasets.MNIST(root='./data/',
                                   train=False,
                                   transform=transforms.ToTensor())

        # 返回一个数据迭代器
        # shuffle:是否打乱顺序
        train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                   batch_size=Config.batch_size,
                                                   shuffle=True)

        test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                                  batch_size=Config.batch_size,
                                                  shuffle=False)
        return train_loader, test_loader

    def train_step(self):
        steps = 0
        start_time = datetime.now()

        print("Training & Evaluating based on '%s'......" % self.model)
        for epoch in range(Config.epoch):
            print("Epoch {:3}.".format(epoch + 1))

            for data, label in self.train:
                data, label = Variable(data.cpu()), Variable(label.cpu())
                # LSTM输入为3维,CNN输入为4if self.model == "LSTM":
                    data = data.view(-1, 28, 28)
                self.optimizer.zero_grad()  # 将梯度归零
                outputs = self.net(data)  # 将数据传入网络进行前向运算
                loss = self.criterion(outputs, label)  # 得到损失函数
                loss.backward()  # 反向传播
                self.optimizer.step()  # 通过梯度做一步参数更新

                # 每100次打印一次结果
                if steps % Config.print_per_step == 0:
                    _, predicted = torch.max(outputs, 1)
                    correct = int(sum(predicted == label))  # 计算预测正确个数
                    accuracy = correct / Config.batch_size  # 计算准确率
                    end_time = datetime.now()
                    time_diff = (end_time - start_time).seconds
                    time_usage = '{:3}m{:3}s'.format(int(time_diff / 60), time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
                    print(msg.format(steps, loss, accuracy, time_usage))

                steps += 1

        test_loss = 0.
        test_correct = 0
        for data, label in self.test:
            data, label = Variable(data.cpu()), Variable(label.cpu())
            if self.model == "LSTM":
                data = data.view(-1, 28, 28)
            outputs = self.net(data)
            loss = self.criterion(outputs, label)
            test_loss += loss * Config.batch_size
            _, predicted = torch.max(outputs, 1)
            correct = int(sum(predicted == label))
            test_correct += correct

        accuracy = test_correct / len(self.test.dataset)
        loss = test_loss / len(self.test.dataset)
        print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))

        end_time = datetime.now()
        time_diff = (end_time - start_time).seconds
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))


if __name__ == "__main__":
    p = TrainProcess(model='CNN')
    p.train_step()

Peace~~

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019.04.06 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Vue中的nexTick()
获取更新后的DOM言外之意就是,操作需要用到了更新后的DOM而不能使用之前的DOM或者使用更新前的DOM会出问题,所以就衍生出了这个获取更新后的 DOM的Vue方法。所以放在Vue.nextTick()回调函数中的执行的应该是会对DOM进行操作的 JS代码。
刘亦枫
2020/03/19
1.7K0
Vue.js 中 nextTick | 笔记
引言 对 Vue 组件数据(props 或状态)的更改不会立即反映在 DOM 中。 相反,Vue 异步更新 DOM。 你可以使用 Vue.nextTick() 或 vm.$nextTick() 函数捕获 Vue 更新 DOM 的时刻。 让我们详细了解这些函数的工作原理。 nextTick() 当 Vue 组件数据发生变化时,DOM 会异步更新。 Vue 会收集来自所有组件的多个虚拟 DOM 更新,然后创建一个批处理来更新DOM。 在单个批次中更新 DOM 比进行多个小的更新更高效。 例如,让我们考虑一个切换元素显示的组件:
yiyun
2023/09/30
3920
Vue.js 中 nextTick | 笔记
浅析$nextTick和$forceUpdate
vue中提供了数十种api供我们开发者日常使用,而常用的其实也就十多种,比如setup, mount, forceupdate, nextTick, compute, ref等,这些参数有的是在生命周期中进行管理,有的是在页面执行过程中,更新参数,有些是可以用来检测页面数据,这些随着项目的推进或多或少都是会使用到。其中nextTick和forceUpdate都是用来更新参数的,那这两个参数有什么差异呢?这还真值得仔细琢磨。
Yerik
2022/04/05
1.9K0
vue.js-详解三大流行框架VUE_快速进阶前端大咖-Vue基础
MVX模式简介: MVX框架模式:MVC+MVP+MVVM MVC: Model模型+View视图+Controller控制器
达达前端
2019/12/13
4.2K0
vue.js-详解三大流行框架VUE_快速进阶前端大咖-Vue基础
Vue1.x 写法示例
常见内置过滤器 capitalize, uppercase, lowercase, json, limitBy, filterBy。所有见这里。
前端GoGoGo
2018/08/24
7690
Vue.nextTick核心原理
相信大家在写vue项目的时候,一定会发现一个神奇的api,Vue.nextTick。为什么说它神奇呢,那是因为在你做某些操作不生效时,将操作写在Vue.nextTick内,就神奇的生效了。那这是什么原因呢?
yyds2026
2022/10/18
5860
面试官:Vue中的$nextTick怎么理解?
我们可以理解成,Vue 在更新 DOM 时是异步执行的。当数据发生变化,Vue将开启一个异步更新队列,视图需要等队列中所有数据变化完成之后,再统一进行更新
@超人
2021/02/26
1.5K0
面试官:Vue中的$nextTick怎么理解?
Vue开发实战(03)-组件化开发
Vue官方的示例图对组件化开发的形象展示。左边是一个网页,可以按照功能模块抽象成很多组件,这些组件就像积木一样拼接成网页。
JavaEdge
2023/07/09
2900
Vue开发实战(03)-组件化开发
Vue异步更新队列及nextTick
vue通常鼓励开发人员沿着“数据驱动”的方式思考,避免直接接触 DOM。Vue的dom更新是异步的,当数据发生变化,vue并不是里面去更新dom,而是开启一个队列。跟JavaScript原生的同步任务和异步任务相同。
wade
2020/04/24
8030
(24)打鸡儿教你Vue.js
1、使用Vue2.0版本实现响应式编程 2、理解Vue编程理念与直接操作Dom的差异 3、Vue常用的基础语法 4、使用Vue编写TodoList功能 5、什么是Vue的组件和实例 6、Vue-cli脚手架工具的使用 7、但文件组件,全局样式与局部样式
达达前端
2019/07/08
4070
Vue.js简介
摘要总结:本文介绍了Vue.js的核心特性,包括数据绑定、指令、过滤器、组件、模板和生命周期等。同时,还通过一个HelloWorld的实例来演示了如何使用Vue.js创建一个简单的单页面应用。
IMWeb前端团队
2017/12/29
3.1K0
Reactjs vs. Vuejs
这里要讨论的话题,不是前端框架哪家强,因为在 Vue 官网就已经有了比较全面客观的介绍,并且是中文的。 上图是二月份前端框架排名,React 位居第一,Vue 排名第三。还清晰记得,16 年十月份
纪俊
2017/05/02
6.6K0
Reactjs  vs. Vuejs
Vue.$nextTick的原理是什么-vue面试进阶
原理性的东西就会文字较多,请耐下心来,细细品味Vue中DOM更新机制当你气势汹汹地使用Vue大展宏图的时候,突然发现,咦,我明明对这个数据进行更改了,但是当我获取它的时候怎么是上一次的值(本人比较懒,就不具体举例了👵)此时,Vue就会说:“小样,这你就不懂了吧,我的DOM是异步更新的呀!!!”简单的说,Vue的响应式并不是只数据发生变化之后,DOM就立刻发生变化,而是按照一定的策略进行DOM的更新。这样的好处是可以避免一些对DOM不必要的操作,提高渲染性能。在Vue官方文档中是这样说明的:可能你还没有注意到
bb_xiaxia1998
2022/10/04
3240
Vue中 v-for 指令深入解析:原理、实践与性能优化
Vue.js 是一个渐进式 JavaScript 框架,它允许开发者使用声明式方式编写可重用的 UI 组件。在 Vue.js 中,v-for 是一个非常重要的指令,它用于基于一个数组来渲染一个列表。本文将深入探讨 v-for 指令的工作原理,并通过实践来展示如何使用它。
Front_Yue
2024/07/28
7860
Vue中 v-for 指令深入解析:原理、实践与性能优化
Vue.js 的一些小技巧
比如一个 <my-button> 上暴露了一个 width 属性,我们既可以传 100px,也可以传 100 :
步履不停凡
2019/09/11
8200
Vue零基础开发入门
讲解部分 Vue 基础语法,通过 TodoList 功能编写,在熟悉基础语法的基础上,扩展解析 MVVM 模式及前端组件化的概念及优势。
JavaEdge
2022/11/30
3.5K0
Vue零基础开发入门
vue生命周期钩子函数
1. created与mounted都常见于ajax请求,前者如果请求响应时间过长,容易白屏
用户1149564
2018/12/04
9820
vue生命周期钩子函数
【Vuejs】1146- 这些 Vue 的技巧你都掌握了吗?
故事第 1 集:CSS预处理器,你还是只会嵌套么 ?[2] 故事第 2 集:【自适应】px 转 rem,你还在手算么?[3]
pingan8787
2021/12/01
1.8K0
【Vuejs】1146- 这些 Vue 的技巧你都掌握了吗?
Vue 2 常见面试题速查
当一个 Vue 实例创建时,Vue 会遍历 data 中的属性,用 Object.defineProperty 将它们转为 getter / setter,并且在内部追踪相关依赖,在属性被访问和修改时通知变化。每个组件实例都有相应的 watcher 程序实例,它会在组件渲染的过程中把属性记录为依赖,之后当依赖项的 setter 被调用时,会通知 watcher 重新计算,从而致使它关联的组件得以更新。
Cellinlab
2023/05/17
1.2K0
Vue 2 常见面试题速查
前端常考vue面试题(必备)_2023-03-15
computed 内部实现了一个惰性的 watcher,也就是 computed watcher,computed watcher 不会立刻求值,同时持有一个 dep 实例。
yyds2026
2023/03/15
1.1K0
相关推荐
Vue中的nexTick()
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验