Fix SQL example

This commit is contained in:
Aymeric 2024-12-24 11:58:36 +01:00
parent edb0be3adf
commit a3cd9158a7
2 changed files with 80 additions and 5 deletions

View File

@ -15,6 +15,8 @@ rendered properly in your Markdown viewer.
--> -->
# Text-to-SQL # Text-to-SQL
[[open-in-colab]]
In this tutorial, well see how to implement an agent that leverages SQL using `smolagents`. In this tutorial, well see how to implement an agent that leverages SQL using `smolagents`.
> Let's start with the goldnen question: why not keep it simple and use a standard text-to-SQL pipeline? > Let's start with the goldnen question: why not keep it simple and use a standard text-to-SQL pipeline?
@ -54,9 +56,7 @@ receipts = Table(
Column("tip", Float), Column("tip", Float),
) )
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
```
```py
rows = [ rows = [
{"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20}, {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
{"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24}, {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
@ -96,7 +96,7 @@ Now lets build our tool. It needs the following: (read [the tool doc](../tuto
- Type hints on both inputs and output. - Type hints on both inputs and output.
```py ```py
from transformers.agents import tool from smolagents import tool
@tool @tool
def sql_engine(query: str) -> str: def sql_engine(query: str) -> str:
@ -127,11 +127,11 @@ We use the CodeAgent, which is transformers.agents main agent class: an agent
The llm_engine is the LLM that powers the agent system. HfEngine allows you to call LLMs using HFs Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API. The llm_engine is the LLM that powers the agent system. HfEngine allows you to call LLMs using HFs Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
```py ```py
from transformers.agents import CodeAgent, HfApiEngine from smolagents import CodeAgent, HfApiEngine
agent = CodeAgent( agent = CodeAgent(
tools=[sql_engine], tools=[sql_engine],
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"), llm_engine=HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct"),
) )
agent.run("Can you give me the name of the client who got the most expensive receipt?") agent.run("Can you give me the name of the client who got the most expensive receipt?")
``` ```

75
examples/text_to_sql.py Normal file
View File

@ -0,0 +1,75 @@
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
Float,
insert,
inspect,
text,
)
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# create city SQL table
table_name = "receipts"
receipts = Table(
table_name,
metadata_obj,
Column("receipt_id", Integer, primary_key=True),
Column("customer_name", String(16), primary_key=True),
Column("price", Float),
Column("tip", Float),
)
metadata_obj.create_all(engine)
rows = [
{"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
{"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
{"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
{"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
for row in rows:
stmt = insert(receipts).values(**row)
with engine.begin() as connection:
cursor = connection.execute(stmt)
inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
print(table_description)
from smolagents import tool
@tool
def sql_engine(query: str) -> str:
"""
Allows you to perform SQL queries on the table. Returns a string representation of the result.
The table is named 'receipts'. Its description is as follows:
Columns:
- receipt_id: INTEGER
- customer_name: VARCHAR(16)
- price: FLOAT
- tip: FLOAT
Args:
query: The query to perform. This should be correct SQL.
"""
output = ""
with engine.connect() as con:
rows = con.execute(text(query))
for row in rows:
output += "\n" + str(row)
return output
from smolagents import CodeAgent, HfApiEngine
agent = CodeAgent(
tools=[sql_engine],
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct"),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")