首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

TensorFlow服务:使用REST Api向TFServing模型发送多个输入的字典?

在 TensorFlow Serving 中,你可以使用 REST API 向模型发送多个输入的字典。以下是一个详细的步骤指南,展示如何使用 REST API 向 TensorFlow Serving 模型发送包含多个输入的字典。

1. 启动 TensorFlow Serving

首先,确保你已经启动了 TensorFlow Serving 并加载了你的模型。假设你的模型名称是 my_model,并且它在端口 8501 上运行。

代码语言:javascript
复制
docker run -p 8501:8501 --name=tfserving_my_model --mount type=bind,source=/path/to/your/model,destination=/models/my_model -e MODEL_NAME=my_model -t tensorflow/serving

2. 准备输入数据

假设你的模型接受两个输入:input_1input_2。你需要将这些输入数据组织成一个字典,并将其转换为 JSON 格式。

代码语言:javascript
复制
import json

# 准备输入数据
data = {
    "signature_name": "serving_default",
    "instances": [
        {
            "input_1": [1.0, 2.0, 3.0],
            "input_2": [4.0, 5.0, 6.0]
        },
        {
            "input_1": [7.0, 8.0, 9.0],
            "input_2": [10.0, 11.0, 12.0]
        }
    ]
}

# 将数据转换为 JSON 格式
json_data = json.dumps(data)

3. 发送请求

使用 requests 库向 TensorFlow Serving 发送 POST 请求。

代码语言:javascript
复制
import requests

# TensorFlow Serving REST API URL
url = 'http://localhost:8501/v1/models/my_model:predict'

# 发送 POST 请求
response = requests.post(url, data=json_data, headers={"content-type": "application/json"})

# 打印响应
print(response.json())

4. 处理响应

响应将包含模型的预测结果。你可以根据需要处理这些结果。

代码语言:javascript
复制
# 处理响应
predictions = response.json()['predictions']
for i, prediction in enumerate(predictions):
    print(f"Prediction for instance {i}: {prediction}")

完整示例

以下是一个完整的示例代码,展示了如何准备输入数据、发送请求并处理响应。

代码语言:javascript
复制
import json
import requests

# 准备输入数据
data = {
    "signature_name": "serving_default",
    "instances": [
        {
            "input_1": [1.0, 2.0, 3.0],
            "input_2": [4.0, 5.0, 6.0]
        },
        {
            "input_1": [7.0, 8.0, 9.0],
            "input_2": [10.0, 11.0, 12.0]
        }
    ]
}

# 将数据转换为 JSON 格式
json_data = json.dumps(data)

# TensorFlow Serving REST API URL
url = 'http://localhost:8501/v1/models/my_model:predict'

# 发送 POST 请求
response = requests.post(url, data=json_data, headers={"content-type": "application/json"})

# 打印响应
print(response.json())

# 处理响应
predictions = response.json()['predictions']
for i, prediction in enumerate(predictions):
    print(f"Prediction for instance {i}: {prediction}")

注意事项

  1. 模型签名:确保 signature_name 与模型的签名名称匹配。默认情况下,TensorFlow Serving 使用 serving_default 签名。
  2. 输入格式:确保输入数据的格式与模型的输入格式匹配。输入数据应为 JSON 格式,并且每个输入字段的名称和数据类型应与模型的输入签名一致。
  3. 错误处理:在实际应用中,添加错误处理代码以处理可能的请求失败或响应错误。
页面内容是否对你有帮助?
有帮助
没帮助

相关·内容

领券