Fix SQL example
This commit is contained in:
parent
edb0be3adf
commit
a3cd9158a7
|
@ -15,6 +15,8 @@ rendered properly in your Markdown viewer.
|
||||||
-->
|
-->
|
||||||
# Text-to-SQL
|
# Text-to-SQL
|
||||||
|
|
||||||
|
[[open-in-colab]]
|
||||||
|
|
||||||
In this tutorial, we’ll see how to implement an agent that leverages SQL using `smolagents`.
|
In this tutorial, we’ll see how to implement an agent that leverages SQL using `smolagents`.
|
||||||
|
|
||||||
> Let's start with the goldnen question: why not keep it simple and use a standard text-to-SQL pipeline?
|
> Let's start with the goldnen question: why not keep it simple and use a standard text-to-SQL pipeline?
|
||||||
|
@ -54,9 +56,7 @@ receipts = Table(
|
||||||
Column("tip", Float),
|
Column("tip", Float),
|
||||||
)
|
)
|
||||||
metadata_obj.create_all(engine)
|
metadata_obj.create_all(engine)
|
||||||
```
|
|
||||||
|
|
||||||
```py
|
|
||||||
rows = [
|
rows = [
|
||||||
{"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
|
{"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
|
||||||
{"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
|
{"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
|
||||||
|
@ -96,7 +96,7 @@ Now let’s build our tool. It needs the following: (read [the tool doc](../tuto
|
||||||
- Type hints on both inputs and output.
|
- Type hints on both inputs and output.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
from transformers.agents import tool
|
from smolagents import tool
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def sql_engine(query: str) -> str:
|
def sql_engine(query: str) -> str:
|
||||||
|
@ -127,11 +127,11 @@ We use the CodeAgent, which is transformers.agents’ main agent class: an agent
|
||||||
The llm_engine is the LLM that powers the agent system. HfEngine allows you to call LLMs using HF’s Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
|
The llm_engine is the LLM that powers the agent system. HfEngine allows you to call LLMs using HF’s Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
from transformers.agents import CodeAgent, HfApiEngine
|
from smolagents import CodeAgent, HfApiEngine
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[sql_engine],
|
tools=[sql_engine],
|
||||||
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
|
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct"),
|
||||||
)
|
)
|
||||||
agent.run("Can you give me the name of the client who got the most expensive receipt?")
|
agent.run("Can you give me the name of the client who got the most expensive receipt?")
|
||||||
```
|
```
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
from sqlalchemy import (
|
||||||
|
create_engine,
|
||||||
|
MetaData,
|
||||||
|
Table,
|
||||||
|
Column,
|
||||||
|
String,
|
||||||
|
Integer,
|
||||||
|
Float,
|
||||||
|
insert,
|
||||||
|
inspect,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
|
# create city SQL table
|
||||||
|
table_name = "receipts"
|
||||||
|
receipts = Table(
|
||||||
|
table_name,
|
||||||
|
metadata_obj,
|
||||||
|
Column("receipt_id", Integer, primary_key=True),
|
||||||
|
Column("customer_name", String(16), primary_key=True),
|
||||||
|
Column("price", Float),
|
||||||
|
Column("tip", Float),
|
||||||
|
)
|
||||||
|
metadata_obj.create_all(engine)
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
{"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
|
||||||
|
{"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
|
||||||
|
{"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
|
||||||
|
{"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
|
||||||
|
]
|
||||||
|
for row in rows:
|
||||||
|
stmt = insert(receipts).values(**row)
|
||||||
|
with engine.begin() as connection:
|
||||||
|
cursor = connection.execute(stmt)
|
||||||
|
|
||||||
|
inspector = inspect(engine)
|
||||||
|
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
|
||||||
|
|
||||||
|
table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
|
||||||
|
print(table_description)
|
||||||
|
|
||||||
|
from smolagents import tool
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def sql_engine(query: str) -> str:
|
||||||
|
"""
|
||||||
|
Allows you to perform SQL queries on the table. Returns a string representation of the result.
|
||||||
|
The table is named 'receipts'. Its description is as follows:
|
||||||
|
Columns:
|
||||||
|
- receipt_id: INTEGER
|
||||||
|
- customer_name: VARCHAR(16)
|
||||||
|
- price: FLOAT
|
||||||
|
- tip: FLOAT
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query to perform. This should be correct SQL.
|
||||||
|
"""
|
||||||
|
output = ""
|
||||||
|
with engine.connect() as con:
|
||||||
|
rows = con.execute(text(query))
|
||||||
|
for row in rows:
|
||||||
|
output += "\n" + str(row)
|
||||||
|
return output
|
||||||
|
|
||||||
|
from smolagents import CodeAgent, HfApiEngine
|
||||||
|
|
||||||
|
agent = CodeAgent(
|
||||||
|
tools=[sql_engine],
|
||||||
|
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct"),
|
||||||
|
)
|
||||||
|
agent.run("Can you give me the name of the client who got the most expensive receipt?")
|
Loading…
Reference in New Issue