最近业余时间做些创新探索,在微信小程序上实现找到纸或笔记本,定位,然后取到纸上的简笔画,之后进行简笔画识别,找到对应位置(之后可以在此位置上加载对应3d模型,实现ar效果, 对应ar官方案例:https://github.com/bbSpider/miniprogramThree)
后续可以自己训练模型识别白纸和简笔画图形 2)也可以直接用tf.loadGraphModel加载自己训练的实物检测模型,不过只能得到识别结果信息,没有位置信息 在微信小程序中接入tensorflow,自己训练实物检测模型,实现识别摄像头数据流中的眼镜、老虎、纸、简笔画的花、简笔画的T-shirt,并分别给出可信度
使用的是tensorflow 的 layerModel格式的模型
有H5版的手绘图片识别:https://medium.com/tensorflow/train-on-google-colab-and-run-on-the-browser-a-case-study-8a45f9b1474e
本来计划用此手绘识别模型来识别摄像头数据源(即人在纸上画好的简笔画),但是发现识别准确率很差,后来用H5版的手绘画布转换成图片来识别也发现准确率跟摄像头数据识别一样差,而用像素数据则准确率高 原因在数据集的介绍里面也有说到:https://github.com/googlecreativelab/quickdraw-dataset
原本用于训练的数据集里每张手绘图的轮廓信息就是用坐标标识的,所以传入画布绘画api的坐标像素数据才会比较准确
因此此模型比较适用于画布的原始绘画api来画简笔画,再通过获取画布像素数据来做模型识别的传参比较合适,所以实现了此手绘图片识别的小程序版,如下
其实此种方式直接在画布交互反而比摄像头找纸笔绘画的交互好得多,用户送礼物时也可选择送简笔画图片或者是识别出来的实物图任意选择
究其深层原因,是因为画布手绘图是灰度图,用python api
np.array(image.getdata())
可以看出得到其数据是一维数组(即一阶矩阵),而彩色图片或者摄像头数据源是rgb图片,np.array(image.getdata())
得到的是三阶矩阵。所以可以通过Numpy转换:https://zhuanlan.zhihu.com/p/136754904,js调用python教程:https://zhuanlan.zhihu.com/p/448356773
1)用three.js实现识别结果的实物3d模型的生成 2)做大家送礼物的统一展示页面 3)可选项:可实现背景替换为摄像头数据,将实物置于摄像头背景之上,供用户导出图片,更具逼真性
https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb 这是coco-ssd小样本训练模型,具体获取路径为 https://github.com/tensorflow/tfjs-models/tree/master/coco-ssd
下面操作运行colab示例获取对应model过程:
https://github.com/tensorflow/models/tree/master/research/object_detection/colab_tutorials
将
https://colab.research.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb
中.ipynb前改成对应文件名打开即可链接到对应示例,比如我们打开eager_few_shot_od_training_tflite.ipynb需要更换链接为
https://colab.research.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tflite.ipynb
可以点击单元格修改对应代码
此时会提示保存副本,点击确认即可
点击全部运行
若报错# IndexError: invalid index to scalar variable. #9694 需替换代码,解决方法见 https://github.com/tensorflow/models/issues/9694
运行完这个示例默认下载的是model.tflite,我们要借助这个示例做修改转换成浏览器可用的模型,需要下载转换成tflite版本前的saveModel,即saved_model.pb
双击即可下载,之后放到对应目录,用完全路径执行以下命令即可生成我们想要的graph_model模型
tensorflowjs_converter --input_format=tf_saved_model \
--output_format=tfjs_graph_model \
--output_node_names='Postprocessor/ExpandDims_1,Postprocessor/Slice' \
/Users/echo/projects/vegetables_tf2.3/models/ \
./web_model
此示例训练的模型可以识别物体的位置轮廓,但需要训练时自己标注训练的图片中物体的轮廓
所以训练标注文字轮廓的模型会麻烦得多
采用的原项目为:https://gitee.com/song-laogou/vegetables_tf2.3
需要搭建anaconda虚拟环境:https://blog.csdn.net/ECHOSON/article/details/117220445
即安装python对应版本的miniconda:https://docs.conda.io/en/latest/miniconda.html
查看已有虚拟环境:conda env list
创建虚拟环境(创建python对应版本的虚拟环境,创建一次即可):conda create -n tf3.9 python==3.9
python 3.5才支持安装tf1.x 激活虚拟环境(即后续可在base和tf3.9之间切换环境):
conda activate tf3.9
激活后即可在此环境下在项目中执行对应命令
进入项目vegetables_tf2.3,执行python get_data.py
即百度搜索对应图片类别进行爬虫下载,数据置于data文件夹中
执行python data_split.py
进行数据划分,主要就是把data图片按比例在split_data文件夹下划分到对应流程文件夹中
test:用于测试 val:用于验证 train:用于训练
有两种模型:cnn、mobilenet,分别对应train_cnn.py、train_mobilenet.py文件 以训练mobilenet模型为例
执行python train_mobilnet.py
注意: 1、报错input empty是因为图像中有加载错误的,即size为0,此命令find ./split_data/train/ -size 0找出来是否有错误的图片
在对应文件夹全部删掉此文件,也可自己去data文件中对应数据源找出错误图片(size为0)删掉 2、报错图片类型无效的
from pathlib import Path
import imghdr
data_dir = "/home/user/datasets/samples/"
image_extensions = [".png", ".jpg"] # add there all your images file extensions
img_type_accepted_by_tf = ["bmp", "gif", "jpeg", "png"]
for filepath in Path(data_dir).rglob("*"):
if filepath.suffix.lower() in image_extensions:
img_type = imghdr.what(filepath)
if img_type is None:
print(f"{filepath} is not an image")
elif img_type not in img_type_accepted_by_tf:
print(f"{filepath} is a {img_type}, not accepted by TensorFlow")
其中,此类名数组后面测试模型的时候需要用到,表示分类,而.DS_Store是mac的隐藏文件(具体作用查看),需要删掉
在train、test、val目录下执行sudo ls
即可看到此隐藏文件,执行rmdir .DS_Store
即可删除
最后,训练好的模型生成与model.save所指定的路径中,如下
执行python test_model.py
,记得在test_model.py中修改为最新生成的model路径
model = tf.keras.models.load_model("models/mobilenet_sketch.h5")
结果置于results文件夹中
训练过程图
测试结果图,区分识别准确率
执行python window.py
,记得修改模型类名数组以及模型路径
此时便可以上传图片检验自己模型的正确性啦!
当需要在网页上检测时就需要把上面生成的.h5后缀的Keras模型转换格式为以下两种tensorflowjs支持的模型
--input_format {keras,tfjs_layers_model,tf_saved_model,tf_frozen_model,tf_hub,keras_saved_model}
Input format. For "keras", the input path can be one of the two following formats: - A topology+weights
combined HDF5 (e.g., generated with `tf.keras.model.save_model()` method). - A weights-only HDF5 (e.g.,
generated with Keras Model's `save_weights()` method). For "keras_saved_model", the input_path must point
to a subfolder under the saved model folder that is passed as the argument to
tf.contrib.save_model.save_keras_model(). The subfolder is generated automatically by tensorflow when
saving keras model in the SavedModel format. It is usually named as a Unix epoch time (e.g., 1542212752).
For "tf" formats, a SavedModel, frozen model, or TF-Hub module is expected.
--output_format {tfjs_graph_model,keras_saved_model,keras,tfjs_layers_model}
手绘图片识别的模型格式即为layerModel,需要用tfl.loadLayersModel来读取模型数据,tfl对应库为@tensorflow/tfjs-layers
处理图像数据的方式为
const imgData = {
data: new Uint8Array(frame.data),
height: frame.width,
width: frame.height,
};
preprocess(imgData) {
return tf.tidy(() => {
var tensor = tf.browser.fromPixels(imgData, 1)
const resized = tf.image.resizeBilinear(tensor, [28, 28]).toFloat()
const offset = tf.scalar(255.0);
const normalized = tf.scalar(1.0).sub(resized.div(offset));
const batched = normalized.expandDims(0)
return batched
})
}
coco ssd的模型格式为graphModel,需要用cocoSsd.load()来读取模型数据,cocoSsd对应库为@tensorflow-models/coco-ssd
处理图形数据的方式为
tf.tidy(() => {
const imgData = {
data: new Uint8Array(frame.data),
width: frame.width,
height: frame.height
}
const temp = tf.browser.fromPixels(imgData, 4)
const sliceOptions = getFrameSliceOptions(frame.width, frame.height, this.displaySize.width, this.displaySize.height)
return temp.slice(sliceOptions.start, sliceOptions.size).resizeBilinear([this.displaySize.height, this.displaySize.width]).asType('int32')
})
以下是相关coco-ssd的介绍,
https://github.com/tensorflow/tfjs-models/tree/master/coco-ssd
并且可实现原始模型数据转换对应格式的模型,如转换为graphModel方式如下
tensorflowjs_converter --input_format=keras \
--output_format=tfjs_graph_model \
--output_node_names='Postprocessor/ExpandDims_1,Postprocessor/Slice' \
./pokemon.h5 \
./web_model
小程序tensorflow插件文档: https://mp.weixin.qq.com/wxopen/plugindevdoc?appid=wx6afed118d9e81df9&token=378013697&lang=zh_CN 具体步骤为:
无需在小程序后台添加tensorflow插件,用测试号,并加入对应代码即可
// app.json
"plugins": {
"tfjsPlugin": {
"version": "0.2.0",
"provider": "wx6afed118d9e81df9"
}
配置环境,加载npm依赖,执行npm init初始化生成package.json,并添加以下依赖
"dependencies": {
"@tensorflow/tfjs-core": "^1.2.6",
"@tensorflow/tfjs-layers": "^1.2.2",
"fetch-wechat": "^0.0.3"
}
之后执行npm i --legacy-peer-deps
因为版本依赖错误,会报错说让加上--forece或者--legacy-peer-deps,我们选择后者,注意安装完依赖需要在小程序工具-构建npm,每次安装依赖都需要构建npm生成miniprogram_npm里对应依赖;如果报错建议固定依赖版本为以上版本号,去掉^即可
最后在app.js配置插件
// app.js 控制台正确戴银1,2,3,4就是引入成功
var fetchWechat = require('fetch-wechat');
var tf = require('@tensorflow/tfjs-core');
var plugin = requirePlugin('tfjsPlugin');
App({
onLaunch: function () {
plugin.configPlugin({
// polyfill fetch function
fetchFunc: fetchWechat.fetchFunc(),
// inject tfjs runtime
tf,
// inject webgl backend
// webgl,
// provide webgl canvas
canvas: wx.createOffscreenCanvas()
});
console.log('tf',tf);
tf.tensor([1,2,3,4]).print()
}
})
TensorFlow.js v2.0 有一个联合包 - @tensorflow/tfjs,包含了六个分npm包:
对于小程序而言,由于有2M的app大小限制,不建议直接使用联合包,而是按照需求加载分包。
引入所需依赖
const tf = require('@tensorflow/tfjs-core')
const tfl = require('@tensorflow/tfjs-layers')
加载layersmodel格式的模型
this.net = await tfl.loadLayersModel('http://192.168.3.5:8080/model2/model.json')
console.log(this.net);
this.net.summary() // 打印模型结构
模型预测,需要先处理图像数据 其中图像数据res为wx.canvasGetImageData获取的画布像素成功回调的数据,res.data为Uint8ClampedArray的buffer数据,但是小程序获取的像素数据跟h5获取的有些许不一样(参考此issue: https://developers.weixin.qq.com/community/develop/doc/000406550b8478e773c63883f5bc00 ),是翻转过的,需要翻转一次,即
wx.canvasGetImageData({
canvasId: 'firstCanvas',
x: 0,
y: 0,
width: 400,
height: 200,
success: function (res) {
const formatRes = that.revertImage(res.data, res.width, res.height)
// 需要转换一下,否则tf.browser.fromPixels(imgData, 1)会报参数错误
const imgData = {
data: new Uint8Array(formatRes.data),
height: formatRes.width,
width: formatRes.height,
};
}
})
// ...
revertImage(data, width, height) {
if (!data) return data;
let dataViews = [];
let len = width * 4;
for (let i = 0; i < height; i++) {
let start = i * width * 4;
let newBuff = data.slice(start, start + len);
dataViews.unshift(newBuff);
}
let result = this.concatArrayBuffer(...dataViews);
return {data: result, width, height};
},
getMinBox() {
//get coordinates
let coords = drawInfos[0].drawArr
var coorX = coords.map(function(p) {
return p.x
});
var coorY = coords.map(function(p) {
return p.y
});
//find top left and bottom right corners
var min_coords = {
x: Math.min.apply(null, coorX),
y: Math.min.apply(null, coorY)
}
var max_coords = {
x: Math.max.apply(null, coorX),
y: Math.max.apply(null, coorY)
}
//return as strucut
return {
min: min_coords,
max: max_coords
}
}
对图像数据归一化等处理
preprocess(imgData) {
return tf.tidy(() => {
var tensor = tf.browser.fromPixels(imgData, 1)
const resized = tf.image.resizeBilinear(tensor, [28, 28]).toFloat()
const offset = tf.scalar(255.0);
const normalized = tf.scalar(1.0).sub(resized.div(offset));
const batched = normalized.expandDims(0)
return batched
})
}
预测结果为
const pred = this.net.predict(formatData).dataSync()
之后便可以跟对应模型识别的类数据数组对应起来找出前5个识别出来的实物以及对应置信率,具体可参考H5版绘图识别的代码:https://github.com/zaidalyafeai/zaidalyafeai.github.io/tree/master/sketcher ,在线例子为:https://zaidalyafeai.github.io/sketcher/
本人写的对应小程序例子为:https://git.woa.com/yiqiuzheng/painting-ar-gifts ,其中的canvas-painting分支
以coco-ssd模型实物识别为例,其模型格式为GraphModel app.js需要做插件配置、环境修改
"dependencies": {
"@tensorflow-models/coco-ssd": "^2.1.0",
"@tensorflow-models/posenet": "^2.2.1",
"@tensorflow/tfjs": "^3.18.0",
"@tensorflow/tfjs-backend-cpu": "^2.7.0",
"@tensorflow/tfjs-backend-webgl": "^2.7.0",
"@tensorflow/tfjs-converter": "^2.7.0",
"@tensorflow/tfjs-core": "^2.7.0",
"@vant/weapp": "^1.6.1",
"fetch-wechat": "^0.0.3"
}
var fetchWechat = require('fetch-wechat');
var tf = require('@tensorflow/tfjs-core');
var webgl = require('@tensorflow/tfjs-backend-webgl');
var cpu = require('@tensorflow/tfjs-backend-cpu');
var plugin = requirePlugin('tfjsPlugin');
App({
onLaunch: function () {
this.getDeviceInfo();
tf.ENV.flagRegistry.WEBGL_VERSION.evaluationFn = () => { return 1 };
plugin.configPlugin({
// polyfill fetch function
fetchFunc: fetchWechat.fetchFunc(),
// inject tfjs runtime
tf,
// inject webgl backend
webgl,
// inject cpu backend
cpu,
// provide webgl canvas
canvas: wx.createOffscreenCanvas()
});
// tf.tensor([1, 2, 3, 4]).print();
}
})
小程序开启摄像流数据
// index.wxml
<camera class="camera" device-position="back" flash="off" frame-size="medium">
<canvas class="canvas" canvas-id="ssd"></canvas>
</camera>
// index.js
onReady: function () {
this.ctx = wx.createCanvasContext('ssd');
const context = wx.createCameraContext(this);
this.initClassifier();
let count = 0;
const listener = context.onCameraFrame(frame => {
count++;
if (count === 2) { // 控制帧数
if (this.classifier && this.classifier.isReady()) {
this.executeClassify(frame);
}
count = 0;
}
});
listener.start();
},
cocos-ssd模型识别
import * as tf from '@tensorflow/tfjs-core'
import * as cocoSsd from '@tensorflow-models/coco-ssd'
// 加载模型
cocoSsd.load({
modelUrl: SSD_NET_URL,
})
.then((model) => {
console.log('model', model);
this.ssdNet = model;
this.ready = true;
})
.catch((err) => {
});
处理图像的关键逻辑
detect(frame) {
return new Promise((resolve, reject) => {
const tensor = tf.tidy(() => {
const imgData = {
data: new Uint8Array(frame.data),
width: frame.width,
height: frame.height
}
const temp = tf.browser.fromPixels(imgData, 4)
const sliceOptions = getFrameSliceOptions(frame.width, frame.height, this.displaySize.width, this.displaySize.height)
return temp.slice(sliceOptions.start, sliceOptions.size).resizeBilinear([this.displaySize.height, this.displaySize.width]).asType('int32')
})
this.ssdNet.detect(tensor).then(res => {
tensor.dispose()
resolve(res)
}).catch(err => {
console.log(err)
tensor.dispose()
reject()
})
})
}
原始项目为:https://github.com/yaun369/tensorflow-wxapp
本人写的对应小程序例子为:https://git.woa.com/yiqiuzheng/painting-ar-gifts ,其中的coco-ssd-position分支
temp.slice(sliceOptions.start, sliceOptions.size).resizeBilinear([this.displaySize.height, this.displaySize.width]).asType('float32')中改为.asType('int32')
开发者工具调试没问题,但是真机预览的时候报错,报tfjs-converter找不到,但是明明路径都正确了
实践此手绘识别库遇到问题https://zaidalyafeai.github.io/sketcher/ ,即 使用手绘canvas识别准确
使用图片识别不准
经测试,下载手绘的图片之后再识别也有问题,因此怀疑是转换图片的方法有误,必须是canvas api绘制的图形才能检测
temp.slice(sliceOptions.start, sliceOptions.size).resizeBilinear([this.displaySize.height, this.displaySize.width]).asType('int32').expandDims(0)中.expandDims(0)去掉
通过对比正常加载的模型合自己训练的模型,调试源码断点发现,result得是一个数组,因此修改coco-ssd源码做兼容
至于result为什么是一个数组,调试发现其包含score以及bbox,即置信度和预测模型的位置信息
意味着训练模型要记录模型的位置轮廓信息
https://github.com/tensorflow/models/issues/10558 重置colab并替换第一行代码运行
!pip install tensorflow==2.8
!apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2
意思是数据输入的宽高是300,但要求是320,即需要改以下数据
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有