社区首页 >问答首页 >不兼容的形状:[11,768]与[1,5,768] -在生产中使用huggingface保存的模型进行推断

不兼容的形状:[11,768]与[1,5,768] -在生产中使用huggingface保存的模型进行推断
EN

Stack Overflow用户
提问于 2020-08-29 00:47:39
回答 1查看 277关注 0票数 1

我从huggingface模型中保存了一个预训练版本的distilbert,distilbert-base-uncased-finetuned-sst-2-english,,我正试图通过Tensorflow服务和进行预测来提供它。目前所有的测试都在Colab进行。

我在通过TensorFlow Serve将预测转换为正确的模型格式时遇到了问题。Tensorflow服务已经启动并运行良好,为模型提供了服务,但是我的预测代码不正确,我需要一些帮助来理解如何通过API通过json进行预测。

代码语言:javascript
代码运行次数:0
复制
# tokenize and encode a simple positive instance
instances = tokenizer.tokenize('this is the best day of my life!')
instances = tokenizer.encode(instances)
data = json.dumps({"signature_name": "serving_default", "instances": instances, })
print(data)

{"signature_name":"serving_default",“实例”:101,2023,2003,1996,2190,2154,1997,2026,2166,999,102}

代码语言:javascript
代码运行次数:0
复制
# setup json_response object
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/my_model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)

预测

代码语言:javascript
代码运行次数:0
复制
{'error': '{{function_node __inference__wrapped_model_52602}} {{function_node __inference__wrapped_model_52602}} Incompatible shapes: [11,768] vs. [1,5,768]\n\t [[{{node tf_distil_bert_for_sequence_classification_3/distilbert/embeddings/add}}]]\n\t [[StatefulPartitionedCall/StatefulPartitionedCall]]'}

这里的任何方向都将不胜感激。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-08-31 00:48:40

能够通过为输入形状和注意掩码设置签名来找到解决方案,如下所示。这是一个简单的实现,它为保存的模型使用固定的输入形状,并要求您将输入填充到预期的输入形状384。我已经看到了调用自定义签名和创建模型来匹配预期输入形状的实现,但是下面的简单案例适用于我希望通过TF服务实现的huggingface模型。如果任何人有任何更好的例子或方法来更好地扩展此功能,请张贴以供将来使用。

代码语言:javascript
代码运行次数:0
复制
# create callable
from transformers import TFDistilBertForQuestionAnswering
distilbert = TFDistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad')
callable = tf.function(distilbert.call)

通过调用get_concrete_function,我们为输入签名跟踪编译模型的TensorFlow操作,该签名由两个形状为None,384的张量组成,第一个是输入ids,第二个是注意掩码。

代码语言:javascript
代码运行次数:0
复制
concrete_function = callable.get_concrete_function([tf.TensorSpec([None, 384], tf.int32, name="input_ids"), tf.TensorSpec([None, 384], tf.int32, name="attention_mask")])

保存带有签名的模型:

代码语言:javascript
代码运行次数:0
复制
# stored model path for TF Serve (1 = version 1) --> '/path/to/my/model/distilbert_qa/1/'
distilbert_qa_save_path = 'path_to_model'
tf.saved_model.save(distilbert, distilbert_qa_save_path, signatures=concrete_function)

检查它是否包含正确的签名:

代码语言:javascript
代码运行次数:0
复制
saved_model_cli show --dir 'path_to_model' --tag_set serve --signature_def serving_default

输出应如下所示:

代码语言:javascript
代码运行次数:0
复制
The given SavedModel SignatureDef contains the following input(s):
  inputs['attention_mask'] tensor_info:
      dtype: DT_INT32
      shape: (-1, 384)
      name: serving_default_attention_mask:0
  inputs['input_ids'] tensor_info:
      dtype: DT_INT32
      shape: (-1, 384)
      name: serving_default_input_ids:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['output_0'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 384)
      name: StatefulPartitionedCall:0
  outputs['output_1'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 384)
      name: StatefulPartitionedCall:1
Method name is: tensorflow/serving/predict

测试模型:

代码语言:javascript
代码运行次数:0
复制
from transformers import DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

question, text = "Who was Benjamin?", "Benjamin was a silly dog."
input_dict = tokenizer(question, text, return_tensors='tf')

start_scores, end_scores = distilbert(input_dict)

all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

对于TF服务(在colab中):(这是我的初衷)

代码语言:javascript
代码运行次数:0
复制
!echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
!apt update
代码语言:javascript
代码运行次数:0
复制
!apt-get install tensorflow-model-server
代码语言:javascript
代码运行次数:0
复制
import os
# path_to_model --> versions directory --> '/path/to/my/model/distilbert_qa/'
# actual stored model path version 1 --> '/path/to/my/model/distilbert_qa/1/'
MODEL_DIR = 'path_to_model'
os.environ["MODEL_DIR"] = os.path.abspath(MODEL_DIR)
代码语言:javascript
代码运行次数:0
复制
%%bash --bg
nohup tensorflow_model_server --rest_api_port=8501 --model_name=my_model --model_base_path="${MODEL_DIR}" >server.log 2>&1
代码语言:javascript
代码运行次数:0
复制
!tail server.log

发出POST请求:

代码语言:javascript
代码运行次数:0
复制
import json
!pip install -q requests
import requests
import numpy as np

max_length = 384  # must equal model signature expected input value
question, text = "Who was Benjamin?", "Benjamin was a good boy."

# padding='max_length' pads the input to the expected input length (else incompatible shapes error)
input_dict = tokenizer(question, text, return_tensors='tf', padding='max_length', max_length=max_length)

input_ids = input_dict["input_ids"].numpy().tolist()[0]
att_mask = input_dict["attention_mask"].numpy().tolist()[0]
features = [{'input_ids': input_ids, 'attention_mask': att_mask}]

data = json.dumps({ "signature_name": "serving_default", "instances": features})

headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/my_model:predict', data=data, headers=headers)
print(json_response)

predictions = json.loads(json_response.text)['predictions']

all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(predictions[0]['output_0']) : tf.math.argmax(predictions[0]['output_1'])+1])
print(answer)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63642440

复制
相关文章
Java反射探索-----从类加载说起
林炳文Evankaka原创作品。转载请注明出处http://blog.csdn.net/evankaka
bear_fish
2018/09/20
5220
Java反射探索-----从类加载说起
枚举帮助类
1 using System; 2 using System.Collections.Generic; 3 using System.ComponentModel; 4 using System.Linq; 5 6 namespace EnumHelper 7 { 8 /// <summary> 9 /// 枚举帮助类 10 /// 1、获取枚举的描述文本 11 /// 2、获取枚举名和描述信息的列表 12 /// </summary> 13
用户6362579
2019/09/29
5380
日志帮助类
 1.代码 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.IO; using System.Configuration; using System.Reflection; namespace LogHelper.Common { public class LogHelper { private string logFile
用户1055830
2018/01/18
6520
日志帮助类
从源码角度学习JVM类加载器及自定义类加载器
负责加载支撑JVM运行的位于JRE的lib目录下的核心类库,这个加载器是由C++写的,所以我们在java源码里面是找不到它的实现,如果尝试对它进行打印,输出将为空值。
AI码师
2022/12/22
3860
从源码角度学习JVM类加载器及自定义类加载器
内存泄漏 - 从Class类加载器说起
某公司技术人员针对企业应用系统12月10日内存溢出事件进行了广泛的技术探讨,并得到了一些建设性的建议和结论。
IT技术小咖
2019/09/24
2.9K0
内存泄漏 - 从Class类加载器说起
详细讲解!从JVM直到类加载器
整个过程是,x.java文件需要编译成x.class文件,通过类加载器加载到内存中,然后通过解释器或者即时编译器进行解释和编译,最后交给执行引擎执行,执行引擎操作OS硬件。
java技术爱好者
2020/09/22
4320
JVM | 从类加载到JVM内存结构
我在上篇文章:JVM | 基于类加载的一次完全实践 中为你讲解如何请“建筑工人”来做一些定制化的工作。但是,大型的Java应用程序时,材料(类)何止数万,我们直接堆放在工地上(JVM)上吗?相反,JVM有着一套精密的管理机制,来确保类的加载、验证、解析和初始化等任务能够有序且高效地完成。
kfaino
2023/10/02
2750
JVM | 从类加载到JVM内存结构
在成为CTO之前,程序员怎样赚外快?
作为一个码code的程序员,虽然可能没有朋友,比较宅,但是整体花销往往不比正常人少。VPS,域名,MAC还有一堆的收费软件,数码设备等,都是卖肾的节奏。 当然作为程序员,我们也可以有更多的赚钱姿势,如果你认为只有接私单,那么你就OUT了,我们看看有没有其他的方式呢? 私单 最理想的单子还是直接接海外的项目,比如freelance.com等网站。一方面是因为挣的是美刀比较划算,之前看到像给WordPress写支付+发送注册码这种大家一个周末就能做完的项目,也可以到200~300美刀;另一方面是在国外接单子比较
春哥大魔王
2018/04/16
1.8K0
原 数据接收和数据返回呈现,都用一个类代替
import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Set; import javax.servlet.http.HttpServletRequest; public class Record extends HashMap implements Map {     private static final lo
kinbug [进阶者]
2018/06/13
4500
Java类加载-加载
我们已经将整个Class的构成讲述完了,不清楚的同学可以看一下关于Class文件的介绍,但是空有Class并没有什么用,在Class中的各种描述信息都需要被加载到虚拟机以后才能运行使用。
shysh95
2021/02/25
1.3K0
Java类加载-加载
python 数据图表呈现
平时压力测试,生成一些数据后分析,直接看 log 不是很直观,前段时间看到公司同事分享了一个绘制图表python 模块 : plotly, 觉得很实用,利用周末时间熟悉下。
orientlu
2018/09/13
1.2K0
python 数据图表呈现
ajax请求完之前的loading加载
很多时候我们需要引入框架来开发项目,这时我们可能会遇到页面还没加载完源码出来了的问题,给用户一种不好的视觉体验,这是便需要loading加载了,来完善用户体验!
ProsperLee
2018/10/24
1.5K0
ajax请求完之前的loading加载
DevExpress数据绑定呈现
数据库这里为了方便演示,用的SQL Server 由于我数据库中的表有8列数据,这里添加8列,并设置列名和绑定的数据名称:
别团等shy哥发育
2023/02/27
1.6K0
DevExpress数据绑定呈现
类加载
其中类加载的过程包括了加载、验证、准备、解析、初始化五个阶段。在这五个阶段中,加载、验证、准备和初始化这四个阶段发生的顺序是确定的,而解析阶段则不一定,它在某些情况下可以在初始化阶段之后开始,这是为了支持Java语言的运行时绑定(也成为动态绑定或晚期绑定)。另外注意这里的几个阶段是按顺序开始,而不是按顺序进行或完成,因为这些阶段通常都是互相交叉地混合进行的,通常在一个阶段执行的过程中调用或激活另一个阶段。
码农戏码
2021/03/23
4980
类加载
我们知道在运行Java程序时,首先需要把源代码编译成二进制文件也就是class文件,然后虚拟机才能执行。那虚拟机在执行class文件时,都进行了哪些步骤呢。下面我们将详细分享一下。当类也就是class文件被加载到虚拟机内存开始,到卸载出内存为止。它将要执行以下7个步骤:
吉林乌拉
2019/08/14
4970
类加载
启动类加载器,Bootstrap ClassLoader,加载JACA_HOME\lib,或者被-Xbootclasspath参数限定的类 扩展类加载器,Extension ClassLoader,加载\lib\ext,或者被java.ext.dirs系统变量指定的类 应用程序类加载器,Application ClassLoader,加载ClassPath中的类库 自定义类加载器,通过继承ClassLoader实现,一般是加载我们的自定义类
葆宁
2019/04/18
4850
类加载
【Android 逆向】类加载器 ClassLoader ( 启动类加载器 | 扩展类加载器 | 应用类加载器 | 类加载的双亲委托机制 )
类加载器加载类流程 : Bootstrap ClassLoader 先加载系统的核心类库 , Extention ClassLoader 加载额外的 /lib/ext 类库 , Application ClassLoader 加载开发者自己开发的类库 ;
韩曙亮
2023/03/30
8920
从ng1看ng2 关于NgModule的简易归纳
最近开始折腾ng2,其实说是ng2,到目前为止,它已经发布了4.3版,就是这么的高产,高产似*,我连2都还木有完整的看完它竟然发布了4.的版本(鄙视脸)。
littlelyon
2018/10/19
9550
写一个Foreach帮助类,在razor中使用
esterday, during my ASP.NET MVC 3 talk at Mix 11, I wrote a useful helper method demonstrating an advanced feature of Razor, Razor Templated Delegates.
javascript.shop
2019/09/04
4980
jvm怎么加载类_jvm类加载器
原因: 1、存放在自定义路径上的类,需要通过自定义类加载器去加载。【注意:AppClassLoader加载classpath下的类】 2、类不一定从文件中加载,也可能从网络中的流中加载,这就需要自定义加载器去实现加密解密。 3、可以定义类的实现机制,实现类的热部署, 如OSGi中的bundle模块就是通过实现自己的ClassLoader实现的, 如tomcat实现的自定义类加载模型。
全栈程序员站长
2022/10/29
4670

相似问题

Shell脚本在特定单词之后打印一定数量的单词

34

在shell脚本中从文件中选择特定的单词

20

在特定单词之后提取单词

30

在shell脚本中搜索特定的单词

34

PHP在某个单词之后选择单词

25
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档