前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【LangChain系列】【与SQL交互时如何得到更好的结果&输出的查询结果验证方案】

【LangChain系列】【与SQL交互时如何得到更好的结果&输出的查询结果验证方案】

原创
作者头像
知冷煖
发布2024-09-21 23:53:29
570
发布2024-09-21 23:53:29
举报
文章被收录于专栏:Langchain

一、LangChain介绍

LangChain是一个框架,用于开发由大型语言模型(LLM)驱动的应用程序。

LangChain 简化了 LLM 应用程序生命周期的每个阶段:

  • 开发:使用LangChain的开源构建块和组件构建应用程序。使用第三方集成和模板开始运行。
  • 生产化:使用 LangSmith 检查、监控和评估您的链条,以便您可以自信地持续优化和部署。
  • 部署:使用 LangServe 将任何链转换为 API。

二、在SQL问答时如何更好的提示?

2-1、安装

代码语言:python
代码运行次数:0
复制
pip install --upgrade --quiet  langchain langchain-community langchain-experimental langchain-openai

2-2、SQLite 样例数据

参考:https://database.guide/2-sample-databases-sqlite/

Chinook 数据: 它代表了一个数字媒体商店,包括艺术家、专辑、媒体曲目、发票和客户的信息,以表格形式呈现。

1、创建数据库: 使用sqlite3 命令来创建

代码语言:python
代码运行次数:0
复制
sqlite3 Chinook.db

2、sql脚本下载、运行

sql脚本地址: https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql

代码语言:python
代码运行次数:0
复制
# 将脚本粘贴到Chinook_Sqlite.sql文件内后,执行以下命令可以创建数据库表。
.read Chinook_Sqlite.sql

2-3、使用langchain与其进行交互

我们可以使用SQLAlchemy驱动的SQLDatabase类与它交互:

代码语言:python
代码运行次数:0
复制
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db", sample_rows_in_table_info=3)
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))

输出:

*sqlite

'Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track'

(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')*

优化:

代码语言:python
代码运行次数:0
复制
from langchain_community.utilities import SQLDatabase
import os

db_path = os.path.join(os.path.dirname(__file__), 'Chinook.db')
db_full_path = os.path.abspath(db_path)
db = SQLDatabase.from_uri(f"sqlite:///{db_full_path}")

2-4、查看模型提示语

安装:

代码语言:python
代码运行次数:0
复制
pip install -qU langchain-openai
代码语言:python
代码运行次数:0
复制
import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass()

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
代码语言:python
代码运行次数:0
复制
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
chain.get_prompts()[0].pretty_print()

输出:

*You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.

Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.

Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.

Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here

SQLQuery: SQL Query to run

SQLResult: Result of the SQLQuery

Answer: Final answer here

Only use the following tables:

{table_info}

Question: {input}

None*

Notice:我这里使用的是阿里的模型,对传入的llm要做一个修改, 使用OpenAI的不需要修改。

代码语言:python
代码运行次数:0
复制
from langchain_community.chat_models.tongyi import ChatTongyi

# 环境变量设置,模型接口设置
os.environ["LANGCHAIN_TRACING_V2"] = ""
os.environ["LANGCHAIN_API_KEY"] = ""
os.environ["DASHSCOPE_API_KEY"] = ''
model = ChatTongyi(
    streaming=True,
)

2-5、提供表定义和示例行

概述: 在大多数SQL链中,我们至少需要向模型提供部分数据库大纲。没有这个,它将无法编写有效的查询。我们的数据库提供了一些方便的方法来提供相关的上下文。具体来说,我们可以从每个表中获取表名、表的概要和行示例。

代码语言:python
代码运行次数:0
复制
context = db.get_context()
print(list(context))
print(context["table_info"])

输出: 只截取部分。

2-6、将表信息插入到Prompt中去

代码语言:python
代码运行次数:0
复制
prompt_with_context = chain.get_prompts()[0].partial(table_info=context["table_info"])
print(prompt_with_context.pretty_repr()[:1500])

输出:

2-7、添加自然语言->SQL示例

概述: 在Prompt中包含将自然语言问题转换为针对数据库的有效SQL查询的示例,通常会提高模型性能,特别是对于复杂查询。

代码语言:python
代码运行次数:0
复制
examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of invoices.",
        "query": "SELECT COUNT(*) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "Which albums are from the year 2000?",
        "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]

构建提示词模板:

代码语言:python
代码运行次数:0
复制
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
prompt = FewShotPromptTemplate(
    examples=examples[:5],
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specificed, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)
print(prompt.format(input="How many artists are there?", top_k=3, table_info="foo"))

输出:

*You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specificed, do not return more than 3 rows.

Here is the relevant table info: foo

Below are a number of examples of questions and their corresponding SQL queries.

User input: List all artists.

SQL query: SELECT * FROM Artist;

User input: Find all albums for the artist 'AC/DC'.

SQL query: SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');

User input: List all tracks in the 'Rock' genre.

SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');

User input: Find the total duration of all tracks.

SQL query: SELECT SUM(Milliseconds) FROM Track;

User input: List all customers from Canada.

SQL query: SELECT * FROM Customer WHERE Country = 'Canada';

User input: How many artists are there?

SQL query:*

2-8、验证输出结果

SQL问答的二次验证:

  • 构建思维链
  • 构建提示词,让模型二次检查SQL语句的准确性
  • 构建完整思维链
代码语言:python
代码运行次数:0
复制
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(model, db)

system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query.
If there are no mistakes, just reproduce the original query with no further commentary.

Output the final SQL query only."""
prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)
validation_chain = prompt | model | StrOutputParser()
full_chain = {"query": chain} | validation_chain
query = full_chain.invoke(
    {
        "question": "How many artists are there?"
    }
)
print(query)

SQL问答的二次验证简化为一次:

代码语言:python
代码运行次数:0
复制
from langchain.chains import create_sql_query_chain

system = """You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Use format:

First draft: <<FIRST_DRAFT_QUERY>>
Final answer: <<FINAL_ANSWER_QUERY>>
"""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{input}")]
).partial(dialect=db.dialect)


def parse_final_answer(output: str) -> str:
    return output.split("Final answer: ")[1]


chain = create_sql_query_chain(model, db, prompt=prompt) | parse_final_answer
prompt.pretty_print()

query = chain.invoke(
    {
        "question": "How many artists are there?"
    }
)
print(query)

Notice: 并不是说二次验证不好,在一般情况下,结果通常会受到大模型理解能力的影响,换句话说,规模较小、理解能力较差的模型,使用二次验证的效果反而会更好,因为会调用两次模型。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、LangChain介绍
  • 二、在SQL问答时如何更好的提示?
    • 2-1、安装
      • 2-2、SQLite 样例数据
        • 2-3、使用langchain与其进行交互
          • 2-4、查看模型提示语
            • 2-5、提供表定义和示例行
              • 2-6、将表信息插入到Prompt中去
                • 2-7、添加自然语言->SQL示例
                  • 2-8、验证输出结果
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档