Build a Question-Answering System over SQL Data #

https://python.langchain.com/docs/tutorials/sql_qa/

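Prerequisites

This tutorial assumes some familiarity with the LangChain concepts it uses: chat models, chains of runnables, tools, and agents.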

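Enabling an LLM system to query structured data can be qualitatively different from querying unstructured text. With unstructured text, the usual approach is to generate text that can be searched against a vector database; with structured data, the LLM usually has to write and execute queries in a domain-specific language (DSL) such as SQL. This guide covers the basic ways to build a question-answering system over tabular data in a database, with implementations using both chains and agents. These systems let us ask a question about the data in a database and get back a natural-language answer. The main difference between the two is that an agent can query the database in a loop, as many times as it needs, to answer the question.

⚠️ Security note ⚠️

Building a question-answering system over a SQL database requires executing model-generated SQL queries, which carries inherent risk. Make sure your database connection permissions are always scoped as narrowly as possible for the needs of your chain or agent. This mitigates, but does not eliminate, the risks of building a model-driven system. For more on general security best practices, see the LangChain security documentation.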

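Architecture #

At a high level, the steps of the system are:

  1. Convert the question to a DSL query: the model converts the user input to a SQL query.

  2. Execute the SQL query: run the query against the database.

  3. Answer the question: the model responds to the user input using the query results.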
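The three steps can be sketched without any LLM at all. In the toy version below, write_query and answer are hypothetical stand-ins for the two model calls (one returns a canned query, the other just formats a string), and the Employee table holds made-up rows in an in-memory SQLite database. Only the shape of the pipeline is meant to carry over.

```python
import sqlite3


def write_query(question: str) -> str:
    """Stand-in for the model call that turns a question into SQL."""
    canned = {"How many employees are there": "SELECT COUNT(*) FROM Employee"}
    return canned[question]


def execute_query(conn: sqlite3.Connection, query: str) -> list:
    """Step 2: run the generated SQL and return the raw rows."""
    return conn.execute(query).fetchall()


def answer(question: str, result: list) -> str:
    """Stand-in for the model call that phrases the final answer."""
    return f"Question: {question} | SQL result: {result}"


# Toy data, not the real Chinook contents.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)",
    [("Adams",), ("Edwards",), ("Peacock",)],
)

question = "How many employees are there"
query = write_query(question)          # step 1: question -> SQL
result = execute_query(conn, query)    # step 2: SQL -> rows
final = answer(question, result)       # step 3: rows -> answer text
```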

Note that querying data in CSV files can follow a similar approach. For details, see the how-to guide on question answering over CSV data.

[Figure: sql_usecase]

Setup #

Dependencies #

We will use the following packages:

pip install --upgrade --quiet langchain langchain-community langchain-openai faiss-cpu

This tutorial uses an OpenAI model and a FAISS-powered vector store.

Load environment variables #

Configure OPENAI_API_KEY, OPENAI_BASE_URL, MODEL_NAME, and EMBEDDING_MODEL_NAME in a .env file:

pip install python-dotenv

from dotenv import load_dotenv
assert load_dotenv()

import os
MODEL_NAME = os.environ.get("MODEL_NAME")
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME")
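Because os.environ.get returns None for a missing key, a misspelled name in .env only surfaces later as a confusing model error. A small check like the following can fail earlier. It is a sketch: the helper name missing_vars is an assumption, and the example runs against a plain dict with placeholder values rather than the real environment.

```python
import os

REQUIRED_VARS = [
    "OPENAI_API_KEY",
    "OPENAI_BASE_URL",
    "MODEL_NAME",
    "EMBEDDING_MODEL_NAME",
]


def missing_vars(environ, required=REQUIRED_VARS):
    """Return the required names that are unset or empty in the given mapping."""
    return [name for name in required if not environ.get(name)]


# Demonstration with a plain dict; in real use pass os.environ after load_dotenv().
example_env = {"OPENAI_API_KEY": "sk-placeholder", "MODEL_NAME": "gpt-4o"}
missing = missing_vars(example_env)
print(missing)
```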

LangSmith tracing (optional) #

Omitted here; see the original tutorial linked at the top of this page.

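Using SQLite #

The examples below use a SQLite connection to the Chinook database. Follow these installation steps to create Chinook.db in the same directory as this tutorial:

  • Save the Chinook SQL script as Chinook.sql
  • Run sqlite3 Chinook.db
  • Run .read Chinook.sql
  • Test with SELECT * FROM Artist LIMIT 10;

Now that Chinook.db is in our directory, we can interact with it through the SQLAlchemy-driven SQLDatabase class: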

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

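Great! We now have a SQL database that we can query. Next, let's connect it to an LLM.

Chains #

Chains (that is, compositions of LangChain Runnables) support applications whose steps are predictable. We can create a simple chain that takes a question and does the following:

  • converts the question into a SQL query;
  • executes the query;
  • uses the result to answer the original question.

This arrangement cannot support some scenarios. For example, the system executes a SQL query for any user input, even "hello". More importantly, as we will see below, some questions need more than one query to answer. We address these scenarios in the Agents section.

Convert the question to a SQL query #

The first step in a SQL chain or agent is to convert the user input to a SQL query. LangChain ships a built-in chain for this purpose: create_sql_query_chain.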

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model=MODEL_NAME)

from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
print(response)

SELECT COUNT("EmployeeId") AS "TotalEmployees" FROM "Employee"\nLIMIT 1;

We can execute the query to make sure it is valid:

db.run(response)

'[(8,)]'

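We can look at the LangSmith trace to get a better understanding of what this chain is doing. We can also inspect the chain's prompt directly. Looking at the prompt below, we can see that it:

  • Is dialect-specific. In this case it references SQLite explicitly.
  • Contains definitions for all the available tables.
  • Provides three example rows for each table.

This technique is inspired by papers such as Evaluating the Text-to-SQL Capabilities of Large Language Models, which suggest that showing example rows and being explicit about tables improves performance. We can also inspect the full prompt like so: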

chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}
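To make the {table_info} placeholder less abstract, here is a rough sketch of how a string containing table definitions plus three sample rows could be assembled with the standard-library sqlite3 module. The function build_table_info and the toy Artist table are assumptions made for illustration; this is not LangChain's implementation, only the same idea in miniature.

```python
import sqlite3


def build_table_info(conn: sqlite3.Connection, sample_rows: int = 3) -> str:
    """Concatenate each table's CREATE statement with a few sample rows."""
    parts = []
    tables = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    for name, create_sql in tables:
        cursor = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        header = "\t".join(col[0] for col in cursor.description)
        body = "\n".join("\t".join(str(v) for v in row) for row in cursor.fetchall())
        parts.append(
            f"{create_sql}\n\n/*\n{sample_rows} rows from {name} table:\n"
            f"{header}\n{body}\n*/"
        )
    return "\n\n".join(parts)


# Toy table with four rows, so the LIMIT is visible in the output.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Artist" ("ArtistId" INTEGER PRIMARY KEY, "Name" TEXT)')
conn.executemany(
    'INSERT INTO "Artist" ("Name") VALUES (?)',
    [("AC/DC",), ("Accept",), ("Aerosmith",), ("Audioslave",)],
)
table_info = build_table_info(conn)
print(table_info)
```

Showing the schema together with a handful of real rows gives the model both the column names and a feel for the values stored in them.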

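Execute the SQL query #

Now that we have generated a SQL query, we want to execute it. **This is the most dangerous part of creating a SQL chain.** Consider carefully whether it is acceptable to run automated queries over your data. Minimize the database connection permissions as much as possible, and consider adding a human approval step to your chain before query execution.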
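A minimal sketch of such a gate, assuming a deliberately conservative keyword heuristic followed by a yes/no prompt. The names looks_read_only and approve are hypothetical. The heuristic is not a security boundary (restricted database permissions remain the real safeguard) and it will also reject some legitimate queries, for example ones that start with WITH or call SQLite's replace() function.

```python
import re

# Keywords that suggest the statement does more than read data.
FORBIDDEN = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|REPLACE|TRUNCATE|ATTACH|PRAGMA)\b",
    re.IGNORECASE,
)


def looks_read_only(query: str) -> bool:
    """Heuristic: one statement, starting with SELECT, with no write keywords."""
    statement = query.strip().rstrip(";").strip()
    if ";" in statement:  # more than one statement
        return False
    if not statement.upper().startswith("SELECT"):
        return False
    return FORBIDDEN.search(statement) is None


def approve(query: str, ask=input) -> bool:
    """Apply the heuristic, then ask a human before the query is executed."""
    if not looks_read_only(query):
        return False
    reply = ask(f"Run this query?\n{query}\n[y/N] ")
    return reply.strip().lower() == "y"
```

In a chain, approve would sit between query generation and query execution, so the query only runs when it returns True.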

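We can use QuerySQLDatabaseTool to easily add query execution to our chain:

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Assumption: the class is spelled QuerySQLDataBaseTool here, matching the
# toolkit output shown in the Agents section; newer langchain-community
# releases name it QuerySQLDatabaseTool.
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there"})

'[(8,)]'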

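Answer the question #

Now that we have a way to generate and execute queries automatically, we only need to combine the original question with the SQL query result to produce a final answer. We can do this by passing the question and the result to the LLM once more: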

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)

chain.invoke({"question": "How many employees are there"})

'There are a total of 8 employees.'

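Let's review what happens in the LCEL above. Suppose the chain is invoked.

  • After the first RunnablePassthrough.assign, we have a dict with two elements: {"question": question, "query": write_query.invoke(question)}, where write_query generates a SQL query that answers the question.

  • After the second RunnablePassthrough.assign, we have added a third element, "result", which contains execute_query.invoke(query), where query was computed in the previous step.

  • These three inputs are formatted into the prompt and passed to the LLM.

  • StrOutputParser() extracts the string content of the output message.

Note that we are composing LLMs, tools, prompts, and other chains, but because each component implements the Runnable interface, their inputs and outputs can be wired together in a sensible way.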
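The way keys accumulate across the two assign steps can be imitated in plain Python. The assign helper below is an illustration of the data flow only, not how RunnablePassthrough.assign is implemented, and write_query and execute_query are hypothetical stubs that return fixed values.

```python
def assign(state: dict, **steps) -> dict:
    """Return a copy of state with one new key per step, computed from the incoming state."""
    added = {key: fn(state) for key, fn in steps.items()}
    return {**state, **added}


def write_query(state: dict) -> str:
    """Stub for the query-writing step; a real one would call the model."""
    return 'SELECT COUNT(*) FROM "Employee"'


def execute_query(query: str) -> str:
    """Stub for the execution step; a real one would hit the database."""
    return "[(8,)]"


state = {"question": "How many employees are there"}
state = assign(state, query=write_query)                             # adds "query"
state = assign(state, result=lambda s: execute_query(s["query"]))   # adds "result"
print(state)
```

After both steps the dict holds question, query, and result, which are exactly the three variables the answer prompt expects.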

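Next steps #

For more complex query generation, we may want to create few-shot prompts or add query-checking steps. For these advanced techniques and more, see:

  • Prompting strategies: advanced prompt-engineering techniques.
  • Query checking: adding query validation and error handling.
  • Large databases: techniques for working with large databases.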

Agents #

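LangChain has a SQL Agent that provides a more flexible way of interacting with SQL databases than a chain. The main advantages of using the SQL Agent are:

  • It can answer questions based on the database's schema as well as on its content (for example, describing a specific table).
  • It can recover from errors by running a generated query, catching the traceback, and regenerating the query correctly.
  • It can query the database as many times as needed to answer the user's question.
  • It saves tokens by retrieving the schema only for the relevant tables.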
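The error-recovery behaviour can be pictured as a small retry loop. In this sketch propose_query is a hypothetical stand-in for the model: it first guesses a table name that does not exist and corrects itself once it is shown the error message. The database is a toy in-memory SQLite table, not Chinook.

```python
import sqlite3


def propose_query(question: str, last_error):
    """Stand-in for the model: wrong guess first, corrected after seeing an error."""
    if last_error is None:
        return "SELECT COUNT(*) FROM customers"
    return "SELECT COUNT(*) FROM Customer"


def answer_with_retries(conn, question, max_attempts=3):
    """Run proposed queries, feeding each error back, until one succeeds."""
    last_error = None
    attempts = []
    for _ in range(max_attempts):
        query = propose_query(question, last_error)
        attempts.append(query)
        try:
            return conn.execute(query).fetchall(), attempts
        except sqlite3.Error as exc:
            last_error = str(exc)  # what the agent would see as a tool error
    raise RuntimeError(f"no valid query after {max_attempts} attempts: {last_error}")


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT)")
conn.executemany("INSERT INTO Customer (Country) VALUES (?)", [("USA",), ("Brazil",)])

rows, attempts = answer_with_retries(conn, "How many customers are there?")
```

A chain has no such loop: a single failed query ends the run, which is the main reason to reach for an agent.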

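To initialize the agent, we will use SQLDatabaseToolkit to create a set of tools that can:

  • Create and execute queries
  • Check query syntax
  • Retrieve table descriptions
  • ... and more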

from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

tools

[QuerySQLDataBaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7c6f42ed4d90>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7c6f42ed4d90>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7c6f42ed4d90>),
 QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7c6f42ed4d90>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7c6f4216dbd0>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7c6f42627810>, root_client=<openai.OpenAI object at 0x7c6f42625010>, root_async_client=<openai.AsyncOpenAI object at 0x7c6f421aaf10>, model_name='gpt-4o', model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7c6f4216dbd0>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7c6f42627810>, root_client=<openai.OpenAI object at 0x7c6f42625010>, root_async_client=<openai.AsyncOpenAI object at 0x7c6f421aaf10>, model_name='gpt-4o', model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]

System prompt #

We also need a system prompt for our agent. It contains the instructions for how the agent should behave.

from langchain_core.messages import SystemMessage

SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables."""

system_message = SystemMessage(content=SQL_PREFIX)

Initialize the agent #

First, install the required package, LangGraph:

pip install -Uq langgraph

We will use a prebuilt LangGraph agent to build our agent:

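from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

# state_modifier matches the LangGraph release used here; newer releases
# take the system message through the prompt= argument instead.
agent_executor = create_react_agent(llm, tools, state_modifier=system_message)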


Consider how the agent responds to the following question:

for s in agent_executor.stream(
    {"messages": [HumanMessage(content="Which country's customers spent the most?")]}
):
    print(s)
    print("----")

{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_vnHKe3oul1xbpX0Vrb2vsamZ', 'function': {'arguments': '{"query":"SELECT c.Country, SUM(i.Total) AS Total_Spent FROM customers c JOIN invoices i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1"}', 'name': 'sql_db_query'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 557, 'total_tokens': 610}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-da250593-06b5-414c-a9d9-3fc77036dd9c-0', tool_calls=[{'name': 'sql_db_query', 'args': {'query': 'SELECT c.Country, SUM(i.Total) AS Total_Spent FROM customers c JOIN invoices i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1'}, 'id': 'call_vnHKe3oul1xbpX0Vrb2vsamZ'}])]}}
----
{'action': {'messages': [ToolMessage(content='Error: (sqlite3.OperationalError) no such table: customers\n[SQL: SELECT c.Country, SUM(i.Total) AS Total_Spent FROM customers c JOIN invoices i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)', name='sql_db_query', id='1a5c85d4-1b30-4af3-ab9b-325cbce3b2b4', tool_call_id='call_vnHKe3oul1xbpX0Vrb2vsamZ')]}}
----
{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_pp3BBD1hwpdwskUj63G3tgaQ', 'function': {'arguments': '{}', 'name': 'sql_db_list_tables'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 12, 'prompt_tokens': 699, 'total_tokens': 711}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-04cf0e05-61d0-4673-b5dc-1a9b5fd71fff-0', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'call_pp3BBD1hwpdwskUj63G3tgaQ'}])]}}
----
{'action': {'messages': [ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', id='c2668450-4d73-4d32-8d75-8aac8fa153fd', tool_call_id='call_pp3BBD1hwpdwskUj63G3tgaQ')]}}
----
{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_22Asbqgdx26YyEvJxBuANVdY', 'function': {'arguments': '{"query":"SELECT c.Country, SUM(i.Total) AS Total_Spent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1"}', 'name': 'sql_db_query'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 744, 'total_tokens': 797}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-bdd94241-ca49-4f15-b31a-b7c728a34ea8-0', tool_calls=[{'name': 'sql_db_query', 'args': {'query': 'SELECT c.Country, SUM(i.Total) AS Total_Spent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1'}, 'id': 'call_22Asbqgdx26YyEvJxBuANVdY'}])]}}
----
{'action': {'messages': [ToolMessage(content="[('USA', 523.0600000000003)]", name='sql_db_query', id='f647e606-8362-40ab-8d34-612ff166dbe1', tool_call_id='call_22Asbqgdx26YyEvJxBuANVdY')]}}
----
{'agent': {'messages': [AIMessage(content='Customers from the USA spent the most, with a total amount spent of $523.06.', response_metadata={'token_usage': {'completion_tokens': 20, 'prompt_tokens': 819, 'total_tokens': 839}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'stop', 'logprobs': None}, id='run-92e88de0-ff62-41da-8181-053fb5632af4-0')]}}
----

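Note that the agent makes several tool calls before it has the information it needs:

  1. It first queries guessed table names (customers, invoices) and gets a "no such table" error;
  2. It lists the available tables;
  3. It reruns the query as a JOIN over the actual Customer and Invoice tables.

The agent can then use the result of the final query to answer the original question.

The agent can handle qualitative questions in the same way: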

for s in agent_executor.stream(
    {"messages": [HumanMessage(content="Describe the playlisttrack table")]}
):
    print(s)
    print("----")

{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_WN0N3mm8WFvPXYlK9P7KvIEr', 'function': {'arguments': '{"table_names":"playlisttrack"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 554, 'total_tokens': 571}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-be278326-4115-4c67-91a0-6dc97e7bffa4-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'playlisttrack'}, 'id': 'call_WN0N3mm8WFvPXYlK9P7KvIEr'}])]}}
----
{'action': {'messages': [ToolMessage(content="Error: table_names {'playlisttrack'} not found in database", name='sql_db_schema', id='fe32b3d3-a40f-4802-a6b8-87a2453af8c2', tool_call_id='call_WN0N3mm8WFvPXYlK9P7KvIEr')]}}
----
{'agent': {'messages': [AIMessage(content='I apologize for the error. Let me first check the available tables in the database.', additional_kwargs={'tool_calls': [{'id': 'call_CzHt30847ql2MmnGxgYeVSL2', 'function': {'arguments': '{}', 'name': 'sql_db_list_tables'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 30, 'prompt_tokens': 592, 'total_tokens': 622}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-f6c107bb-e945-4848-a83c-f57daec1144e-0', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'call_CzHt30847ql2MmnGxgYeVSL2'}])]}}
----
{'action': {'messages': [ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', id='a4950f74-a0ad-4558-ba54-7bcf99539a02', tool_call_id='call_CzHt30847ql2MmnGxgYeVSL2')]}}
----
{'agent': {'messages': [AIMessage(content='The database contains a table named "PlaylistTrack". Let me retrieve the schema and sample rows from the "PlaylistTrack" table.', additional_kwargs={'tool_calls': [{'id': 'call_wX9IjHLgRBUmxlfCthprABRO', 'function': {'arguments': '{"table_names":"PlaylistTrack"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 658, 'total_tokens': 702}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-e8d34372-1159-4654-a185-1e7d0cb70269-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'PlaylistTrack'}, 'id': 'call_wX9IjHLgRBUmxlfCthprABRO'}])]}}
----
{'action': {'messages': [ToolMessage(content='\nCREATE TABLE "PlaylistTrack" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"TrackId" INTEGER NOT NULL, \n\tPRIMARY KEY ("PlaylistId", "TrackId"), \n\tFOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), \n\tFOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")\n)\n\n/*\n3 rows from PlaylistTrack table:\nPlaylistId\tTrackId\n1\t3402\n1\t3389\n1\t3390\n*/', name='sql_db_schema', id='f6ffc37a-188a-4690-b84e-c9f2c78b1e49', tool_call_id='call_wX9IjHLgRBUmxlfCthprABRO')]}}
----
{'agent': {'messages': [AIMessage(content='The "PlaylistTrack" table has the following schema:\n- PlaylistId: INTEGER (NOT NULL)\n- TrackId: INTEGER (NOT NULL)\n- Primary Key: (PlaylistId, TrackId)\n- Foreign Key: TrackId references Track(TrackId)\n- Foreign Key: PlaylistId references Playlist(PlaylistId)\n\nHere are 3 sample rows from the "PlaylistTrack" table:\n1. PlaylistId: 1, TrackId: 3402\n2. PlaylistId: 1, TrackId: 3389\n3. PlaylistId: 1, TrackId: 3390\n\nIf you have any specific questions or queries regarding the "PlaylistTrack" table, feel free to ask!', response_metadata={'token_usage': {'completion_tokens': 145, 'prompt_tokens': 818, 'total_tokens': 963}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_3b956da36b', 'finish_reason': 'stop', 'logprobs': None}, id='run-961a4552-3cbd-4d28-b338-4d2f1ac40ea0-0')]}}
----

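Handling high-cardinality columns #

A high-cardinality column is a column that contains a large number of distinct values. In other words, its values are widely spread and each one repeats only a few times. High-cardinality columns come up often in databases, data processing, and data analysis, especially with large datasets.

To filter on columns that contain proper nouns such as addresses, song names, or artist names, we first need to double-check the spelling so that the filter matches the stored data.

We can do this by creating a vector store that holds all the distinct proper nouns in the database. Whenever the user includes a proper noun in a question, the agent can query that vector store to find the correct spelling. This way, the agent can make sure it understands which entity the user means before it builds the target query.

First we need the unique values for each entity we want, so we define a function that parses the result into a list of elements: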

import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))


artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]

['Big Ones',
 'Cidade Negra - Hits',
 'In Step',
 'Use Your Illusion I',
 'Voodoo Lounge']
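To see what each parsing step does, here is the same logic applied to a hand-written result string instead of a live db.run call. The helper parse_result and the sample titles (including the made-up 'Greatest Hits 2') are assumptions for illustration, and the result is sorted only to make the output deterministic.

```python
import ast
import re


def parse_result(raw: str) -> list:
    """Same parsing steps as query_as_list, applied to an already-fetched string."""
    # The result arrives as the repr of a list of tuples; literal_eval rebuilds it.
    values = [el for row in ast.literal_eval(raw) for el in row if el]
    # Drop standalone numbers, then trim the whitespace they leave behind.
    values = [re.sub(r"\b\d+\b", "", value).strip() for value in values]
    # Deduplicate; sorting makes the order stable for display.
    return sorted(set(values))


raw = "[('Big Ones',), ('Greatest Hits 2',), ('Big Ones',), (None,), ('In Step',)]"
parsed = parse_result(raw)
print(parsed)
```

The None entry and the duplicate disappear, and the trailing number is stripped from the made-up title. Roman numerals, as in 'Use Your Illusion I', are untouched because the pattern only matches digits.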

Using this function, we can create a retriever tool that the agent can call at its own discretion.

from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings(model=EMBEDDING_MODEL_NAME))
retriever = vector_db.as_retriever(search_kwargs={"k": 5})
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)

Let's try it out:

print(retriever_tool.invoke("Alice Chains"))

Alice In Chains

Alanis Morissette

Pearl Jam

Pearl Jam

Audioslave

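This way, if the agent determines that it needs to write a filter based on an artist along the lines of "Alice Chains", it can first use the retriever tool to look at the relevant values in the column.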
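The lookup idea can be tried without an embedding model. The sketch below swaps the FAISS retriever for the standard library's difflib, which compares spellings rather than embeddings, so it is a stand-in for the technique and not the retriever used in this tutorial. The helper closest_names and the short artist list are assumptions.

```python
import difflib


def closest_names(query: str, names: list, k: int = 3, cutoff: float = 0.6) -> list:
    """Return up to k stored names whose spelling is closest to the query, ignoring case."""
    by_lower = {name.lower(): name for name in names}
    hits = difflib.get_close_matches(query.lower(), list(by_lower), n=k, cutoff=cutoff)
    return [by_lower[hit] for hit in hits]


artists_sample = ["AC/DC", "Alice In Chains", "Alanis Morissette", "Audioslave", "Pearl Jam"]
matches = closest_names("Alice Chains", artists_sample)
print(matches)
```

String similarity only catches misspellings, whereas an embedding-based retriever can also surface related names, which is why the tutorial's tool returns other artists alongside the best match.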

Putting this together:

system = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

You have access to the following tables: {table_names}

If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool!
Do not try to guess at the proper name - use this function to find similar ones.""".format(
    table_names=db.get_usable_table_names()
)

system_message = SystemMessage(content=system)

tools.append(retriever_tool)

agent = create_react_agent(llm, tools, state_modifier=system_message)

for s in agent.stream(
    {"messages": [HumanMessage(content="How many albums does alis in chain have?")]}
):
    print(s)
    print("----")

{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_86uzEat33FDtwwdd4wePPn8m', 'function': {'arguments': '{"query":"alis in chain"}', 'name': 'search_proper_nouns'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 19, 'prompt_tokens': 665, 'total_tokens': 684, 'completion_tokens_details': None}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_67802d9a6d', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-4f44bed3-cff5-4eb8-afe6-8a0c7c7197e9-0', tool_calls=[{'name': 'search_proper_nouns', 'args': {'query': 'alis in chain'}, 'id': 'call_86uzEat33FDtwwdd4wePPn8m', 'type': 'tool_call'}], usage_metadata={'input_tokens': 665, 'output_tokens': 19, 'total_tokens': 684, 'input_token_details': {}, 'output_token_details': {}})]}}
----
{'tools': {'messages': [ToolMessage(content='Alice In Chains\n\nIn Step\n\nIV\n\nTransmission\n\nCoda', name='search_proper_nouns', id='60ea710e-bb08-4918-bdd6-71fdc77379fd', tool_call_id='call_86uzEat33FDtwwdd4wePPn8m')]}}
----
{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Wh6i6bH4K7kgvaICDLbV2mJY', 'function': {'arguments': '{"table_names": "Album, Artist"}', 'name': 'sql_db_schema'}, 'type': 'function'}, {'id': 'call_331fXVKjNDgIpkRbzqKyZoia', 'function': {'arguments': '{"query": "Alice In Chains"}', 'name': 'search_proper_nouns'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 708, 'total_tokens': 761, 'completion_tokens_details': None}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_67802d9a6d', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-07a58711-5007-43ce-9460-7d1e0a922ab7-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'Album, Artist'}, 'id': 'call_Wh6i6bH4K7kgvaICDLbV2mJY', 'type': 'tool_call'}, {'name': 'search_proper_nouns', 'args': {'query': 'Alice In Chains'}, 'id': 'call_331fXVKjNDgIpkRbzqKyZoia', 'type': 'tool_call'}], usage_metadata={'input_tokens': 708, 'output_tokens': 53, 'total_tokens': 761, 'input_token_details': {}, 'output_token_details': {}})]}}
----
{'tools': {'messages': [ToolMessage(content='\nCREATE TABLE "Album" (\n\t"AlbumId" INTEGER NOT NULL, \n\t"Title" NVARCHAR(160) NOT NULL, \n\t"ArtistId" INTEGER NOT NULL, \n\tPRIMARY KEY ("AlbumId"), \n\tFOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")\n)\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE "Artist" (\n\t"ArtistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("ArtistId")\n)\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/', name='sql_db_schema', id='21a37a02-fd87-4600-a3fe-bc1e65395c43', tool_call_id='call_Wh6i6bH4K7kgvaICDLbV2mJY'), ToolMessage(content='Alice In Chains\n\nSoundgarden\n\nPearl Jam\n\nPearl Jam\n\nChris Cornell', name='search_proper_nouns', id='c287ed3a-00a4-4cd5-94ed-537e3cd56108', tool_call_id='call_331fXVKjNDgIpkRbzqKyZoia')]}}
----
{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_JePs9JuFrKvf7m1dsQfItMYU', 'function': {'arguments': '{"query":"SELECT COUNT(*) FROM Album INNER JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = \'Alice In Chains\'"}', 'name': 'sql_db_query'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 41, 'prompt_tokens': 1002, 'total_tokens': 1043, 'completion_tokens_details': None}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_67802d9a6d', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-c321d9b5-c313-4945-b6ca-f2ff2f42c6b2-0', tool_calls=[{'name': 'sql_db_query', 'args': {'query': "SELECT COUNT(*) FROM Album INNER JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alice In Chains'"}, 'id': 'call_JePs9JuFrKvf7m1dsQfItMYU', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1002, 'output_tokens': 41, 'total_tokens': 1043, 'input_token_details': {}, 'output_token_details': {}})]}}
----
{'tools': {'messages': [ToolMessage(content='[(1,)]', name='sql_db_query', id='56ec1ae7-a0e5-4793-8d34-9879b0d3842a', tool_call_id='call_JePs9JuFrKvf7m1dsQfItMYU')]}}
----
{'agent': {'messages': [AIMessage(content='Alice In Chains has 1 album in the database.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 12, 'prompt_tokens': 1056, 'total_tokens': 1068, 'completion_tokens_details': None}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_d54531d9eb', 'finish_reason': 'stop', 'logprobs': None}, id='run-aa41d7ae-a1dc-4838-9dad-f6214a62e75f-0', usage_metadata={'input_tokens': 1056, 'output_tokens': 12, 'total_tokens': 1068, 'input_token_details': {}, 'output_token_details': {}})]}}
----

As we can see, the agent used the search_proper_nouns tool to check how to correctly query the database for this specific artist.

© 2024 青蛙小白