# Agentic RAG example: a smolagents CodeAgent answers questions about the `transformers`
# documentation by calling a BM25-based retriever tool over pre-chunked documentation pages.

# Uncomment to authenticate with the Hugging Face Hub (e.g. for gated models such as Llama).
# from huggingface_hub import login
# login()

import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever

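# Load the knowledge base of Hugging Face documentation pages and keep only the pages
# that come from the `transformers` repository.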
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))

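# Wrap each page in a LangChain Document, keeping the repository name as metadata.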
source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
    for doc in knowledge_base
]

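# Split the documents into ~500-character chunks with a 50-character overlap, so each
# chunk stays small enough to rank and read while preserving local context.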
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)

from smolagents import Tool


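# A smolagents Tool wrapping a BM25 retriever over the processed documentation chunks.
# Note that BM25 is a lexical (keyword-based) ranker, even though the description shown
# to the agent talks about semantic search.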
class RetrieverTool(Tool):
    name = "retriever"
    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

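    # Build the BM25 index over the pre-chunked documents; the retriever returns the
    # top k=10 chunks for each query.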
    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=10)

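    # Called by the agent: run the query against the BM25 index and return the retrieved
    # chunks concatenated into a single string.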
    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            [
                f"\n\n===== Document {str(i)} =====\n" + doc.page_content
                for i, doc in enumerate(docs)
            ]
        )


from smolagents import HfApiModel, CodeAgent

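# Build the agent: a CodeAgent writes and executes Python code at each step and can call
# the retriever tool from that code.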
retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
    tools=[retriever_tool],
    model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),
    max_steps=4,  # cap on reasoning/tool-calling steps
    verbosity_level=2,  # detailed step-by-step logging
)

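# Ask a question; the agent queries the retriever tool as needed before answering.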
agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")

print("Final output:")
print(agent_output)