Add tests for multiagent hierarchies
This commit is contained in:
parent
4fa8255377
commit
0824785b7a
|
@ -195,7 +195,7 @@ This is not the only way to build the tool: you can directly define it as a subc
|
|||
Let's see how it works for both options:
|
||||
|
||||
<hfoptions id="build-a-tool">
|
||||
<hfoption id="@tool decorator">
|
||||
<hfoption id="Decorate a function with @tool">
|
||||
|
||||
```py
|
||||
from smolagents import tool
|
||||
|
@ -214,10 +214,10 @@ def model_download_tool(task: str) -> str:
|
|||
```
|
||||
|
||||
The function needs:
|
||||
- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's put `model_download_tool`.
|
||||
- A clear name. The name should be descriptive enough of what this tool does to help the LLM brain powering the agent. Since this tool returns the model with the most downloads for a task, let's name it `model_download_tool`.
|
||||
- Type hints on both inputs and output
|
||||
- A description, that includes an 'Args:' part where each argument is described (without a type indication this time, it will be pulled from the type hint).
|
||||
All these will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible!
|
||||
- A description, that includes an 'Args:' part where each argument is described (without a type indication this time, it will be pulled from the type hint). Same as for the tool name, this description is an instruction manual for the LLM powering you agent, so do not neglect it.
|
||||
All these elements will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible!
|
||||
|
||||
> [!TIP]
|
||||
> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).
|
||||
|
@ -237,6 +237,13 @@ class ModelDownloadTool(Tool):
|
|||
most_downloaded_model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return most_downloaded_model.id
|
||||
```
|
||||
|
||||
The subclass needs the following attributes:
|
||||
- A clear `name`. The name should be descriptive enough of what this tool does to help the LLM brain powering the agent. Since this tool returns the model with the most downloads for a task, let's name it `model_download_tool`.
|
||||
- A `description`. Same as for the `name`, this description is an instruction manual for the LLM powering you agent, so do not neglect it.
|
||||
- Input types and descriptions
|
||||
- Output type
|
||||
All these attributes will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible!
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
|
|
@ -372,7 +372,7 @@ class MultiStepAgent:
|
|||
except Exception as e:
|
||||
return f"Error in generating final LLM output:\n{e}"
|
||||
|
||||
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
||||
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
|
||||
"""
|
||||
Execute tool with the provided input and returns the result.
|
||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||
|
|
|
@ -361,3 +361,86 @@ class AgentTests(unittest.TestCase):
|
|||
agent.run("Count to 3")
|
||||
str_output = capture.get()
|
||||
assert "import under additional_authorized_imports" in str_output
|
||||
|
||||
def test_multiagents(self):
|
||||
class FakeModelMultiagentsManagerAgent:
|
||||
def __call__(
|
||||
self, messages, stop_sequences=None, grammar=None
|
||||
):
|
||||
if len(messages) < 3:
|
||||
return """
|
||||
Thought: Let's call our search agent.
|
||||
Code:
|
||||
```py
|
||||
result = search_agent("Who is the current US president?")
|
||||
```<end_code>
|
||||
"""
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return """
|
||||
Thought: Let's return the report.
|
||||
Code:
|
||||
```py
|
||||
final_answer("Final report.")
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
def get_tool_call(
|
||||
self, messages, available_tools, stop_sequences=None, grammar=None
|
||||
):
|
||||
if len(messages) < 3:
|
||||
return (
|
||||
"search_agent",
|
||||
"Who is the current US president?",
|
||||
"call_0",
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return (
|
||||
"final_answer",
|
||||
{"prompt": "Final report."},
|
||||
"call_0",
|
||||
)
|
||||
manager_model = FakeModelMultiagentsManagerAgent()
|
||||
|
||||
class FakeModelMultiagentsManagedAgent:
|
||||
def get_tool_call(
|
||||
self, messages, available_tools, stop_sequences=None, grammar=None
|
||||
):
|
||||
return (
|
||||
"final_answer",
|
||||
{"prompt": "Report on the current US president"},
|
||||
"call_0",
|
||||
)
|
||||
managed_model = FakeModelMultiagentsManagedAgent()
|
||||
|
||||
web_agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=managed_model,
|
||||
max_steps=10,
|
||||
)
|
||||
|
||||
managed_web_agent = ManagedAgent(
|
||||
agent=web_agent,
|
||||
name="search_agent",
|
||||
description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
|
||||
)
|
||||
|
||||
manager_code_agent = CodeAgent(
|
||||
tools=[],
|
||||
model=manager_model,
|
||||
managed_agents=[managed_web_agent],
|
||||
additional_authorized_imports=["time", "numpy", "pandas"],
|
||||
)
|
||||
|
||||
report = manager_code_agent.run("Fake question.")
|
||||
assert report == "Final report."
|
||||
|
||||
manager_toolcalling_agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=manager_model,
|
||||
managed_agents=[managed_web_agent],
|
||||
)
|
||||
|
||||
report = manager_toolcalling_agent.run("Fake question.")
|
||||
assert report == "Final report."
|
||||
|
|
|
@ -110,9 +110,10 @@ class TestDocs:
|
|||
code_blocks = self.extractor.extract_python_code(content)
|
||||
excluded_snippets = [
|
||||
"ToolCollection",
|
||||
"image_generation_tool",
|
||||
"from_langchain",
|
||||
"while llm_should_continue(memory):",
|
||||
"image_generation_tool", # We don't want to run this expensive operation
|
||||
"from_langchain", # Langchain is not a dependency
|
||||
"while llm_should_continue(memory):", # This is pseudo code
|
||||
"ollama_chat/llama3.2" # Exclude ollama building in guided tour
|
||||
]
|
||||
code_blocks = [
|
||||
block
|
||||
|
|
Loading…
Reference in New Issue