首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >使用Tensorflow.js的MNIST数据集的损失曲线在1.0处饱和

使用Tensorflow.js的MNIST数据集的损失曲线在1.0处饱和
EN

Stack Overflow用户
提问于 2020-01-17 04:54:35
回答 1查看 83关注 0票数 0

data.js

代码语言:javascript
运行
AI代码解释
复制
    /**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}

script.js

代码语言:javascript
运行
AI代码解释
复制
    import {MnistData} from './data.js';
var canvas,ctx,saveButton,clearButton;
var pos={x:0,y:0};
var rawImage;
var model;

function getModel()
{
    model=tf.sequential();
    model.add(tf.layers.conv2d({inputShape:[28,28,1],kernelSize:3,filters:8,activation:'relu'}));
    model.add(tf.layers.maxPooling2d({poolSize:[2,2]}));
    model.add(tf.layers.conv2d({filters:16,kernelSize:3,activation:'relu'}));
    model.add(tf.layers.maxPooling2d({poolSize:[2,2]}));
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({units:128,activation:'sigmoid'}));
    model.add(tf.layers.dense({units:10,activation:'softmax'}));

    model.compile({optimizer:tf.train.adam(),loss:'categoricalCrossentropy',metrics:['accuracy']});

    return model;

}

async function train(model,data){
    const metrics=['loss', 'val_loss', 'acc', 'val_acc'];
    const container={name:'Model training',styles:{height:'640px'}};
    const fitCallbacks=tfvis.show.fitCallbacks(container,metrics);

    const BATCH_SIZE = 512;
    const TRAIN_DATA_SIZE = 5500;
    const TEST_DATA_SIZE = 1000;

    const [trainXs,trainYs]=tf.tidy(()=>
                                   {
        const d=data.nextTrainBatch(TRAIN_DATA_SIZE);
        return[
            d.xs.reshape([TRAIN_DATA_SIZE,28,28,1]),
            d.labels
        ];
    });

    const [testXs,testYs]=tf.tidy(()=>{
        const d=data.nextTestBatch(TEST_DATA_SIZE);
        return[
            d.xs.reshape([TEST_DATA_SIZE,28,28,1]),
            d.labels
        ];
    });

    return model.fit(trainXs,trainYs,{
        batchSize:BATCH_SIZE,
        validationData:[testXs,testYs],
        epochs:20,
        shuffle:true,
        callbacks:fitCallbacks
    });

}

function setPosition(e){
    pos.x=e.clientX-100;
    pos.y=e.clientY-100;
}

function draw(e)
{
    if(e.buttons!=1)return ;
    ctx.beginPath();
    ctx.lineWidth=24;
    ctx.lineCap='round';
    ctx.strokeStyle='white';
    ctx.moveTo(pos.x,pos.y);
    setPosition(e);
    ctx.lineTo(pos.x,pos.y)
    ctx.stroke();

    rawImage.src=canvas.toDataURL('image/png');
}

function erase()
{
    ctx.fillStyle="black";
    ctx.fillRect(0,0,280,280);
}

function save()
{
    var raw=tf.browser.fromPixels(rawImage,1);
    var resized=tf.image.resizeBilinear(raw,[28,28]);
    var tensor=resized.expandDims(0);

    var prediction=model.predict(tensor);
    var pIndex=tf.argMax(prediction,1).dataSync();

    alert(pIndex);
}

function init()
{
    canvas=document.getElementById('canvas');
    rawImage=document.getElementById('canvasimg');
    ctx=canvas.getContext("2d");
    ctx.fillStyle="black";
    ctx.fillRect(0,0,280,280);
    canvas.addEventListener("mousemove",draw);
    canvas.addEventListener("mousedown",setPosition);
    canvas.addEventListener("mouseenter",setPosition);
    saveButton=document.getElementById('sb');
    saveButton.addEventListener("click",save);
    clearButton=document.getElementById('cb');
    clearButton.addEventListener("click",erase);


}

async function run()
{
    const data=new MnistData();
    await data.load();
    const model=getModel();
    tfvis.show.modelSummary({name:'Model Architecture'},model);
    await train(model,data);
    init();
    alert("Training is done, try classifying...");
}

document.addEventListener('DOMContentLoaded', run);

mnist.htm

代码语言:javascript
运行
AI代码解释
复制
    <html>
<head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>
</head>
    <body>
    <h1> Handwritten character recognition</h1>
    <canvas id="canvas" width="280" height="280" style="position:absolute;top:100;left:100;border:8px solid;"></canvas>
    <img id="canvasimg" style="position:absolute;top:10%;left=52%;width:280;height=280;display:none;">
    <input type="button" value="classify" id="sb" size="48" style="position:absolute;top:400;left:100;">
    <input type="button" value="clear" id="cb" size="23" style="position:absolute;top:400;left:180;">

    <script src="data.js" type="module"></script>
    <script src="script.js" type="module"></script>
    </body>

</html>

我试图制作一个手写数字分类器,它可以根据我们在网页画布上绘制的内容来识别数字。但是在训练时,我的损失曲线饱和在1.0,我的准确率饱和在60%。因此,我尝试将128节点密集层的激活函数从relu更改为sigmoid。即使在改变这一点之后,我的损失也会饱和在1.0。请帮帮我。

EN

回答 1

Stack Overflow用户

发布于 2020-01-17 08:58:27

已经定义了用于对mnist数据集herethere进行分类的模型。如果你想重写你自己的模型,那么你需要将它与那些将作为基线的官方模型进行比较。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59781402

复制
相关文章
jquery选择器用法_jQuery属性选择器
一、 基本选择器 1. ID选择器 ID选择器#id就是利用DOM元素的id属性值来筛选匹配的元素,并以iQuery包装集的形式返回给对象。 使用公式:(“#id”) 示例:(“#box”) //获取id属性值为box的元素 2. 元素选择器 元素选择器是根据元素名称匹配相应的元素。元素选择器指向的是DOM元素的标记名,也就是说元素选择器是根据元素的标记名选择的。 使用公式:(“element”) 示例:(“div”) //获取所有div元素 3.类名选择器 类选择器是通过元素拥有的CSS类的名称查找匹配的DOM元素。在一个页面中,一个元素可以有多个CSS类,一个CSS类又可以匹配多个元素,如果有元素中有一个匹配类的名称就可以被类选择器选取到。简单地说类名选择器就是以元素具有的CSS类名称查找匹配的元素。 使用公式:(“.class”) 示例:(“.box”) //获取class属性值为box的所有元素 4.复合选择器 复合选择器将多个选择器(可以是ID选择器、元素选择器或是类名选择器)组合在一起,两个选择器之间以逗号”,”分隔,只要符合其中的任何一个筛选条件就会被匹配,返回的是一个集合形式的jQuery包装集,利用jQuery索引器可以取得集合中的jQuery对象。 注意:多种匹配条件的选择器并不是匹配同时满足这几个选择器的匹配条件的元素,而是将每个匹配的元素合并后一起返回。 使用公式:(“selector1,selector2,……,selectorN”) selector1:一个有效的选择器,可以是ID选择器、元素选择器或类名选择器等 selector2:另一个有效的选择器,可以是ID选择器、元素选择器或类名选择器等 selectorN:(可选择)任意多个选择器,可以是ID选择器、元素选择器或类名选择器等 示例:(“div,#btn”) //要查询文档中的全部的<div>元素和id属性为btn的元素 5.通配符选择器
全栈程序员站长
2022/11/16
12.3K0
jQuery的addClass、siblings、removeClass、each、html、eq、show/hide用法
addClass() siblings() removeClass() each()
江一铭
2022/06/17
1.5K0
jQuery 选择器
原生 JS 获取元素方式很多,很杂,而且兼容性情况不一致,因此 jQuery 给我们做了封装,使获取元素统一标准。
星辰_大海
2020/10/09
2.9K0
jQuery 选择器
jquery 选择器
<script type="text/javascript"> $(".demo").click(function(){ alert() }) </script>
用户5760343
2019/10/08
1.6K0
jquery 选择器
[jQuery笔记] jQuery选择器
jquery选择器允许对html中的元素组合单个元素进行操作,jquery的选择器和css的选择器几乎大同小异,大致分为元素选择器、id选择器和类选择器。jquery的选择器基于元素的id、类、类型、属性、属性值等查找或选择html元素,基于已经存在的css选择器,另外,jquery也支持自定义选择器。
行 者
2019/12/05
1.8K0
jQuery选择器
说明: 可以使用length属性来判断标签是否选择成功, 如果length大于0表示选择成功,否则选择失败。
落雨
2022/03/01
30.4K0
jQuery 选择器
基本选择器 基本选择器是最简单的选择器,可以通过元素id、class和标签名等来直接查找DOM元素。 元素选择器 根据给定元素名匹配元素。如下选择的是所有div元素。 $("div").css("
静默虚空
2018/01/05
7.5K0
JQuery选择器
jQuery常用的事件: load:当文档加载时运行脚本 blur:当窗口失去焦点时运行脚本 focus:当窗口获得焦点时运行脚本 change:当元素改变时运行脚本 submit:当提交表单时运行脚本 keydown:当按下按键时运行脚本 keypress:当按下并松开按键时运行脚本 keyup:当松开按键时运行脚本 click:当单击鼠标时运行脚本 dblclick:当双击鼠标时运行脚本 mousedown:当按下鼠标按钮时运行脚本 mousemove:当鼠标指针移动时运行脚本 mouseout:当鼠标指针移出元素时运行脚本 mouseover:当鼠标指针移至元素之上时运行脚本 mouseup:当松开鼠标按钮时运行脚本 abort:当发生中止事件时运行脚本
我不是费圆
2020/09/21
7.5K0
JQuery选择器
1    $("*")      ---------选取所有元素 2   $(this)     --------选择当前HTML元素 3   $("p.a")   -----选取p元素下class为a的元素 4   $("p:first")  ----选取第一个p元素 5   $("ul li:first-child") ----选取ul下第一个li元素 6  $("tr:even")  -------选取偶数位置下的tr 7 $("tr :odd")   --------选取奇数位置的tr
用户3159471
2018/09/13
1.7K0
jQuery(选择器)
注意:但是:first-child选择器可以匹配多个:即为每个父级元素匹配第一 个子元素。这相当于:nth-child(1);
全栈开发日记
2022/05/12
1.5K0
jQuery 选择器
原生 JS 获取元素方式很多,很杂,而且兼容性情况不一致,因此 jQuery 给我们做了封装,使获取元素统一标准。
清出于兰
2020/10/26
1.8K0
jQuery 选择器
jquery选择器
jquery选择器可以快速地选择元素,选择规则和css样式相同,使用length属性判断是否选择成功。
Devops海洋的渔夫
2019/05/30
1.8K0
jQuery 选择器
jQuery网页脚本语言核心之一 概述: 1. 选择器是jQuery的基础 2. 对事件处理,遍历DOM和Ajax操作都依赖于选择器 3. 可简化代码 什么是jQuery选择器? 层叠样式表 良好地继承了css选择器语法,还继承了其获取页面元素便捷高效的特点 于css不同,jQuery选择器获取元素后,为该元素添加的是行为 有良好的兼容性 优势 1. 简洁的写法 (1) $(选择) 2. 完善的处理机制 (1) 简洁,避免某些错误 类型: 可通过css选择器和过滤选择器两种方式选择元素,每种又有不同的方法来
房上的猫
2018/04/18
2.7K0
jQuery 选择器
Jquery选择器
1、  基本选择器 选择器 描述 结果 示例 #id 根据id获取元素 单个 $(“#myid”)选取id的值为myid的元素 .class 根据class获取元素 集合 $(“.myclass”)选取class的值为myclass的元素 a,p,img等html标签 根据指定的html标签获取元素 集合 $(“img”)选取所有的img标签 * 获取所有的元素 集合 $(“*”)获取所有标签元素 a,.myclass,#id等 获取对应标签元素 集合 $(“a,.myclass,
苦咖啡
2018/05/07
2K0
jQuery常用的选择器
当我们想要操所页面中的元素时,首先要做的就是选取元素。选取页面中元素可以使用jQuery给我们提供的$()方法,该方法需要提供选择器作为参数,方法执行完成后会返回给我们一个jQuery对象,被选取的元素就包含在该对象中。
小周sir
2019/09/23
7450
jquery的基本选择器
关于基本选择器包括 “*” ,“.class”,"element","#id","selector1 selementN" "*" 选择器,可以找到文档中的所有的元素,包括 head body $(function(){ // $("#test").find("*").css("border","3px solid red"); //找到 #test 这个元素下面的所有的 元素 //}); ".class" 选择给定样式类名的所有的元素。 //$(function
用户1197315
2018/01/19
6460
锋利的JQuery —— 选择器
图片猛戳链接
用户1154259
2018/01/17
7230
锋利的JQuery —— 选择器
jQuery选择器总结
本文主要介绍了jQuery选择器的使用,包括基本选择器、层次选择器、过滤选择器、内容过滤器、可见选择器、属性选择器和子元素选择器等。通过实例讲解了如何在JavaScript中使用这些选择器,以及如何使用表单元素过滤选择器。
IMWeb前端团队
2018/01/08
1.2K0
jQuery选择器总结
jQuery层次选择器
jQuery是一种流行的JavaScript库,用于简化HTML文档的操作和动态交互。在jQuery中,层次选择器是一种非常有用的选择器,可以根据元素之间的层次关系选择特定的元素。
堕落飞鸟
2023/05/18
4780
jquery 层级选择器
关于层级选择器。 $("parent > child") 选择所有指定“parent”元素中指定的“child”的直接子项元素。 parent :任何有效的选择器。 child: 用来筛选子元素的选择器 $(function(){ $("ul.myul > li").css("border","2px solid red"); //将ul带有.myul选择下面的li 标签添加边框。}); $("ancestor descendant") 选择给定的祖先元素的所有后代
用户1197315
2018/01/19
9120

相似问题

jQuery :eq(索引)选择器

20

jQuery嵌套:eq选择器

31

jQuery选择器eq:()不工作

24

jQuery eq()选择器不工作

41

jQuery选择器(子项、eq和innerhtml)

229
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文