Improve RAG example
commit eecd728668 (parent 9e288cefba)
@@ -21,7 +21,9 @@
 - title: Examples
   sections:
   - local: examples/text_to_sql
-    title: Text-to-SQL
+    title: Self-correcting Text-to-SQL
+  - local: examples/rag
+    title: Master your knowledge base with agentic RAG
 - title: Reference
   sections:
   - local: reference/agents
@@ -45,7 +45,7 @@ from huggingface_hub import login
 login()
 ```
 
-We first load a knowledge base on which we want to perform RAG: this dataset is a compilation of the documentation pages for many Hugging Face libraries, stored as markdown.
+We first load a knowledge base on which we want to perform RAG: this dataset is a compilation of the documentation pages for many Hugging Face libraries, stored as markdown. We will keep only the documentation for the `transformers` library.
 
 Then prepare the knowledge base by processing the dataset and storing it into a vector database to be used by the retriever.
 
@@ -58,11 +58,11 @@ from tqdm import tqdm
 from transformers import AutoTokenizer
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import FAISS
+from langchain_community.vectorstores import FAISS, DistanceStrategy
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.vectorstores.utils import DistanceStrategy
 
 knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
+knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
 
 source_docs = [
     Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
@@ -92,7 +92,7 @@ for doc in tqdm(source_docs):
             docs_processed.append(new_doc)
 
 print(
-    "Embedding documents... This should take a few minutes (5 minutes on MacBook with M1 Pro)"
+    "Embedding documents... This could take a few minutes."
 )
 t0 = time.time()
 embedding_model = HuggingFaceEmbeddings(
@@ -105,11 +105,13 @@ vectordb = FAISS.from_documents(
     distance_strategy=DistanceStrategy.COSINE,
 )
 t1 = time.time()
-print(f"VectorDB embedded in {int((t1-t0)/60)} minutes")
+print(f"VectorDB embedded in {(t1-t0):.2f} seconds")
 ```
 If you want to improve performance, head to the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select a bigger model for your embeddings: here we selected a small one for the sake of speed.
 
-Now the database is ready: let’s build our agentic RAG system!
+Now the database is ready. Building the embeddings for each document snippet took a few minutes, but now they're ready to be used in a split second.
+
+So let’s build our agentic RAG system!
 
 👉 We only need a RetrieverTool that our agent can leverage to retrieve information from the knowledge base.
 
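A side note on the leaderboard suggestion above: swapping the encoder only means changing the model name passed to `HuggingFaceEmbeddings` (and re-chunking with that model's tokenizer, since the text splitter is built from it). A minimal sketch, assuming `BAAI/bge-base-en-v1.5` as an illustrative larger model, not the one this commit uses:

```python
# Hypothetical swap: any sentence-embedding model id from the MTEB leaderboard.
# "BAAI/bge-base-en-v1.5" is an example choice, not the model used in this commit.
embedding_model = HuggingFaceEmbeddings(
    model_name="BAAI/bge-base-en-v1.5",
    show_progress=True,
)
```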
@@ -138,7 +140,7 @@ class RetrieverTool(Tool):
 
         docs = self.vectordb.similarity_search(
             query,
-            k=7,
+            k=10,
         )
 
         return "\nRetrieved documents:\n" + "".join(
@@ -156,7 +158,7 @@ The agent will need these arguments upon initialization:
 - `model`: the LLM that powers the agent.
 Our `model` must be a callable that takes as input a list of messages and returns text. It also needs to accept a `stop_sequences` argument that indicates when to stop its generation. For convenience, we directly use the `HfEngine` class provided in the package to get an LLM engine that calls Hugging Face's Inference API.
 
-And we use meta-llama/Llama-3.3-70B-Instruct as the llm engine because:
+And we use [meta-llama/Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) as the LLM engine because:
 - It has a long 128k context, which is helpful for processing long source documents
 - It is served for free at all times on HF's Inference API!
 
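To make that `model` contract concrete, here is a minimal sketch of a callable with the expected shape; the canned reply is a placeholder, not how `HfEngine` actually generates text:

```python
# Sketch of the `model` interface described above: it takes a list of chat
# messages and must return text, honoring an optional `stop_sequences` argument.
# The hardcoded completion is a dummy placeholder for a real LLM call.
def toy_model(messages: list[dict], stop_sequences: list[str] | None = None) -> str:
    completion = "This is where the LLM's answer would go."
    for stop in stop_sequences or []:
        completion = completion.split(stop)[0]  # cut generation at each stop sequence
    return completion
```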
@@ -167,7 +169,7 @@ from smolagents import HfApiModel, CodeAgent
 
 retriever_tool = RetrieverTool(vectordb)
 agent = CodeAgent(
-    tools=[retriever_tool], model=HfApiModel("Qwen/Qwen2.5-72B-Instruct"), max_iterations=4, verbose=True
+    tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
 )
 ```
 
@@ -0,0 +1,100 @@
+# from huggingface_hub import login
+
+# login()
+import time
+import datasets
+from tqdm import tqdm
+from transformers import AutoTokenizer
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS, DistanceStrategy
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
+knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
+embedding_model = "TaylorAI/gte-tiny"
+
+source_docs = [
+    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
+    for doc in knowledge_base
+]
+
+text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+    AutoTokenizer.from_pretrained(embedding_model),
+    chunk_size=200,
+    chunk_overlap=20,
+    add_start_index=True,
+    strip_whitespace=True,
+    separators=["\n\n", "\n", ".", " ", ""],
+)
+
+# Split docs and keep only unique ones
+print("Splitting documents...")
+docs_processed = []
+unique_texts = {}
+for doc in tqdm(source_docs):
+    new_docs = text_splitter.split_documents([doc])
+    for new_doc in new_docs:
+        if new_doc.page_content not in unique_texts:
+            unique_texts[new_doc.page_content] = True
+            docs_processed.append(new_doc)
+
+print(
+    "Embedding documents... This could take a few minutes."
+)
+t0 = time.time()
+embedding_model = HuggingFaceEmbeddings(
+    model_name=embedding_model,
+    show_progress=True
+)
+vectordb = FAISS.from_documents(
+    documents=docs_processed,
+    embedding=embedding_model,
+    distance_strategy=DistanceStrategy.COSINE,
+)
+t1 = time.time()
+print(f"VectorDB embedded in {(t1-t0):.2f} seconds")
+
+from smolagents import Tool
+
+class RetrieverTool(Tool):
+    name = "retriever"
+    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self, vectordb, **kwargs):
+        super().__init__(**kwargs)
+        self.vectordb = vectordb
+
+    def forward(self, query: str) -> str:
+        assert isinstance(query, str), "Your search query must be a string"
+
+        docs = self.vectordb.similarity_search(
+            query,
+            k=10,
+        )
+
+        return "\nRetrieved documents:\n" + "".join(
+            [
+                f"===== Document {str(i)} =====\n" + doc.page_content
+                for i, doc in enumerate(docs)
+            ]
+        )
+
+from smolagents import HfApiModel, CodeAgent
+
+retriever_tool = RetrieverTool(vectordb)
+agent = CodeAgent(
+    tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
+)
+
+agent_output = agent.run("For a transformers model training, which is faster, the forward or the backward pass?")
+
+print("Final output:")
+print(agent_output)
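Before running the new script's agent end to end, it can help to sanity-check retrieval by querying the index directly. A quick sketch using only objects defined in the script above; the query string is just an example:

```python
# Manual spot-check of the FAISS index before involving the agent.
# metadata["source"] was set when building source_docs above.
for i, doc in enumerate(vectordb.similarity_search("forward pass vs backward pass speed", k=3)):
    print(f"--- snippet {i} (source: {doc.metadata['source']}) ---")
    print(doc.page_content[:200])
```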