Make RAG example extremely fast with BM25

Aymeric 2024-12-26 16:19:31 +01:00
parent eecd728668
commit 1abaf69b67
6 changed files with 40 additions and 102 deletions


@@ -78,7 +78,7 @@ The `preview` command only works with existing doc files. When you add a complet
 Accepted files are Markdown (.md).
 Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
-the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/agents/blob/main/docs/source/_toctree.yml) file.
+the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/smolagents/blob/main/docs/source/_toctree.yml) file.
 ## Renaming section headers and moving sections
@@ -108,7 +108,7 @@ For an example of a rich moved section set please see the very end of [the trans
 ## Writing Documentation - Specification
-The `huggingface/agents` documentation follows the
+The `huggingface/smolagents` documentation follows the
 [Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
 although we can write them directly in Markdown.
@@ -123,7 +123,7 @@ Make sure to put your new file under the proper section. If you have a doubt, fe
 ### Translating
-When translating, refer to the guide at [./TRANSLATING.md](https://github.com/huggingface/agents/blob/main/docs/TRANSLATING.md).
+When translating, refer to the guide at [./TRANSLATING.md](https://github.com/huggingface/smolagents/blob/main/docs/TRANSLATING.md).
 ### Writing source documentation


@@ -52,14 +52,10 @@ Then prepare the knowledge base by processing the dataset and storing it into a
 We use [LangChain](https://python.langchain.com/docs/introduction/) for its excellent vector database utilities.
 ```py
-import time
 import datasets
-from tqdm import tqdm
-from transformers import AutoTokenizer
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import FAISS, DistanceStrategy
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.retrievers import BM25Retriever
 knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
 knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
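To see what the knowledge base holds after this filter, a quick inspection sketch (illustrative; it relies only on the `source` and `text` columns that the diff itself uses):

```py
import datasets

# Same load-and-filter as in the hunk above: keep only the transformers documentation pages.
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))

print(len(knowledge_base))  # number of documentation pages kept
print(knowledge_base[0]["source"])  # e.g. a path starting with "huggingface/transformers"
print(knowledge_base[0]["text"][:200])  # first characters of one page
```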
@@ -69,47 +65,17 @@ source_docs = [
     for doc in knowledge_base
 ]
-embedding_model = "TaylorAI/gte-tiny"
-text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
-    AutoTokenizer.from_pretrained(embedding_model),
-    chunk_size=200,
-    chunk_overlap=20,
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=500,
+    chunk_overlap=50,
     add_start_index=True,
     strip_whitespace=True,
     separators=["\n\n", "\n", ".", " ", ""],
 )
-# Split docs and keep only unique ones
-print("Splitting documents...")
-docs_processed = []
-unique_texts = {}
-for doc in tqdm(source_docs):
-    new_docs = text_splitter.split_documents([doc])
-    for new_doc in new_docs:
-        if new_doc.page_content not in unique_texts:
-            unique_texts[new_doc.page_content] = True
-            docs_processed.append(new_doc)
-print(
-    "Embedding documents... This could take a few minutes."
-)
-t0 = time.time()
-embedding_model = HuggingFaceEmbeddings(
-    model_name=embedding_model,
-    show_progress=True
-)
-vectordb = FAISS.from_documents(
-    documents=docs_processed,
-    embedding=embedding_model,
-    distance_strategy=DistanceStrategy.COSINE,
-)
-t1 = time.time()
-print(f"VectorDB embedded in {(t1-t0):.2f} seconds")
+docs_processed = text_splitter.split_documents(source_docs)
 ```
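Each chunk produced above is a LangChain `Document`, and because `add_start_index=True` is set, its metadata records the character offset where the chunk began. A quick sanity check on the result (an illustrative sketch reusing `docs_processed` from the diff above):

```py
print(f"{len(docs_processed)} chunks produced")  # total number of chunks
sample = docs_processed[0]
print(sample.page_content[:200])  # beginning of the chunk's text
print(sample.metadata)  # e.g. {'source': 'transformers', 'start_index': 0}
```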
-If you want to improve performance, head to the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select a bigger model for your embeddings: here we selected a small one for the sake of speed.
-Now the database is ready. Building the embeddings for each document snippet took a few minutes, but now they're ready to be used in a split second.
+Now the documents are ready.
 So let's build our agentic RAG system!
@@ -122,7 +88,7 @@ from smolagents import Tool
 class RetrieverTool(Tool):
     name = "retriever"
-    description = "Using semantic similarity, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
+    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
     inputs = {
         "query": {
             "type": "string",
@@ -131,27 +97,31 @@ class RetrieverTool(Tool):
     }
     output_type = "string"
-    def __init__(self, vectordb, **kwargs):
+    def __init__(self, docs, **kwargs):
         super().__init__(**kwargs)
-        self.vectordb = vectordb
+        self.retriever = BM25Retriever.from_documents(
+            docs, k=10
+        )
     def forward(self, query: str) -> str:
         assert isinstance(query, str), "Your search query must be a string"
-        docs = self.vectordb.similarity_search(
+        docs = self.retriever.invoke(
             query,
-            k=10,
         )
         return "\nRetrieved documents:\n" + "".join(
             [
-                f"===== Document {str(i)} =====\n" + doc.page_content
+                f"\n\n===== Document {str(i)} =====\n" + doc.page_content
                 for i, doc in enumerate(docs)
             ]
         )
-```
-Now its straightforward to create an agent that leverages this tool!
+retriever_tool = RetrieverTool(docs_processed)
+```
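Before handing the tool to an agent, it can be exercised directly; `forward` is the method smolagents invokes when the tool is called (an illustrative call, reusing the `retriever_tool` built above):

```py
# Returns one formatted string containing the top-10 BM25 matches.
results = retriever_tool.forward("How can I push a model to the Hub?")
print(results[:500])  # peek at the first matches
```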
+We have used BM25, a classic retrieval method, because it's lightning fast to set up.
+To improve retrieval accuracy, you could replace BM25 with semantic search using vector representations of documents: head to the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select a good embedding model.
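A hedged sketch of that swap, should you want it: embed the same chunks with a model of your choice (here the `TaylorAI/gte-tiny` checkpoint the old code used, purely as an example) and search the vector store instead of BM25. It assumes the `langchain_community` FAISS integration and `sentence-transformers` are installed:

```py
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Embed every chunk once, then retrieve by embedding similarity.
embeddings = HuggingFaceEmbeddings(model_name="TaylorAI/gte-tiny")  # or any MTEB model
vectordb = FAISS.from_documents(docs_processed, embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 10})

docs = retriever.invoke("How can I push a model to the Hub?")
```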
+Now it's straightforward to create an agent that leverages this `retriever_tool`!
 The agent will need these arguments upon initialization:
 - `tools`: a list of tools that the agent will be able to call.
@@ -167,7 +137,6 @@ _Note:_ The Inference API hosts models based on various criteria, and deployed m
 ```py
 from smolagents import HfApiModel, CodeAgent
-retriever_tool = RetrieverTool(vectordb)
 agent = CodeAgent(
     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
 )
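For readers reconstructing the example from this diff, the pieces now assemble as below. This is a sketch that simply mirrors the added lines, with each parameter annotated:

```py
from smolagents import HfApiModel, CodeAgent

retriever_tool = RetrieverTool(docs_processed)  # BM25-backed tool from the earlier hunk
agent = CodeAgent(
    tools=[retriever_tool],  # the tools the agent is allowed to call
    model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),  # LLM engine served via the Inference API
    max_iterations=4,  # hard cap on the reasoning/tool-calling loop
    verbose=True,  # log each step
)
```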
@@ -178,7 +147,7 @@ Upon initializing the CodeAgent, it has been automatically given a default syste
 Then when its `.run()` method is launched, the agent takes care of calling the LLM engine and executing the tool calls, all in a loop that ends only when the `final_answer` tool is called with the final answer as its argument.
 ```py
-agent_output = agent.run("How can I push a model to the Hub?")
+agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
 print("Final output:")
 print(agent_output)


@@ -2,4 +2,4 @@
 FROM e2bdev/code-interpreter:latest
 # Install dependencies and customize sandbox
-RUN pip install git+https://github.com/huggingface/agents.git
+RUN pip install git+https://github.com/huggingface/smolagents.git


@@ -1,59 +1,28 @@
 # from huggingface_hub import login
 # login()
-import time
 import datasets
-from tqdm import tqdm
-from transformers import AutoTokenizer
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import FAISS, DistanceStrategy
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.retrievers import BM25Retriever
 knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
 knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
-embedding_model = "TaylorAI/gte-tiny"
 source_docs = [
     Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
     for doc in knowledge_base
 ]
-text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
-    AutoTokenizer.from_pretrained(embedding_model),
-    chunk_size=200,
-    chunk_overlap=20,
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=500,
+    chunk_overlap=50,
     add_start_index=True,
     strip_whitespace=True,
     separators=["\n\n", "\n", ".", " ", ""],
 )
-# Split docs and keep only unique ones
-print("Splitting documents...")
-docs_processed = []
-unique_texts = {}
-for doc in tqdm(source_docs):
-    new_docs = text_splitter.split_documents([doc])
-    for new_doc in new_docs:
-        if new_doc.page_content not in unique_texts:
-            unique_texts[new_doc.page_content] = True
-            docs_processed.append(new_doc)
-print(
-    "Embedding documents... This could take a few minutes."
-)
-t0 = time.time()
-embedding_model = HuggingFaceEmbeddings(
-    model_name=embedding_model,
-    show_progress=True
-)
-vectordb = FAISS.from_documents(
-    documents=docs_processed,
-    embedding=embedding_model,
-    distance_strategy=DistanceStrategy.COSINE,
-)
-t1 = time.time()
-print(f"VectorDB embedded in {(t1-t0):.2f} seconds")
+docs_processed = text_splitter.split_documents(source_docs)
 from smolagents import Tool
@@ -68,33 +37,33 @@ class RetrieverTool(Tool):
     }
     output_type = "string"
-    def __init__(self, vectordb, **kwargs):
+    def __init__(self, docs, **kwargs):
         super().__init__(**kwargs)
-        self.vectordb = vectordb
+        self.retriever = BM25Retriever.from_documents(
+            docs, k=10
+        )
     def forward(self, query: str) -> str:
         assert isinstance(query, str), "Your search query must be a string"
-        docs = self.vectordb.similarity_search(
+        docs = self.retriever.invoke(
             query,
-            k=10,
         )
         return "\nRetrieved documents:\n" + "".join(
             [
-                f"===== Document {str(i)} =====\n" + doc.page_content
+                f"\n\n===== Document {str(i)} =====\n" + doc.page_content
                 for i, doc in enumerate(docs)
             ]
         )
 from smolagents import HfApiModel, CodeAgent
-retriever_tool = RetrieverTool(vectordb)
+retriever_tool = RetrieverTool(docs_processed)
 agent = CodeAgent(
     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
 )
-agent_output = agent.run("For a transformers model training, which is faster, the forward or the backward pass?")
+agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
 print("Final output:")
 print(agent_output)


@@ -910,7 +910,7 @@ class CodeAgent(MultiStepAgent):
                     align="left",
                     style="orange",
                 ),
-                Syntax(llm_output, lexer="markdown", theme="github-dark"),
+                Syntax(llm_output, lexer="markdown", theme="github-dark", word_wrap=True),
             )
         )
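The `word_wrap=True` flag makes Rich soft-wrap long lines inside the syntax panel instead of cropping them at the console edge. A standalone illustration (assumes only the `rich` package; the demo code string is made up):

```py
from rich.console import Console
from rich.syntax import Syntax

console = Console(width=60)  # narrow console to force wrapping
code = "result = " + " + ".join(f"term_{i}" for i in range(20))  # one long line

# Without word_wrap=True this line would be cropped; with it, it wraps.
console.print(Syntax(code, lexer="python", theme="github-dark", word_wrap=True))
```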


@@ -36,7 +36,7 @@ class E2BExecutor:
         # TODO: validate installing agents package or not
         # print("Installing agents package on remote executor...")
         # self.sbx.commands.run(
-        #     "pip install git+https://github.com/huggingface/agents.git",
+        #     "pip install git+https://github.com/huggingface/smolagents.git",
         #     timeout=300
         # )
         # print("Installation of agents package finished.")
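If the TODO above is ever resolved in favor of installing the package, uncommenting the block amounts to the call below; this is just the commented-out code made live, using the same `commands.run` API the comment already references:

```py
print("Installing agents package on remote executor...")
self.sbx.commands.run(
    "pip install git+https://github.com/huggingface/smolagents.git",
    timeout=300,  # builds inside the sandbox can be slow
)
print("Installation of agents package finished.")
```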