Enable support for tool calling agents

This commit is contained in:
Aymeric 2024-12-23 17:10:07 +01:00
parent 24d9cf9e3d
commit 30cb6111b3
13 changed files with 224 additions and 139 deletions

View File

@ -40,7 +40,8 @@ This library offers:
🤗 **Hub integrations**: you can share and load tools to/from the Hub, and more is to come!
Quick demo:
## Quick demo
First install the package.
```bash
pip install agents
```
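A minimal usage sketch might then look like the following (the class names `CodeAgent` and `HfApiEngine` are taken from this commit; the import path and model choice are illustrative assumptions):

```py
from agents import CodeAgent, HfApiEngine  # import path assumed

# Build a code-writing agent backed by a hosted LLM (model choice illustrative)
agent = CodeAgent(
    tools=[],
    llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
)
agent.run("What is the 10th Fibonacci number?")
```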

View File

@ -17,7 +17,7 @@
- local: conceptual_guides/intro_agents
title: 🤖 An introduction to agentic systems
- local: conceptual_guides/react
title: 🤔 ReAct agents
title: 🤔 How do Multi-step agents work?
- title: Examples
sections:
- local: examples/text_to_sql

View File

@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
# ReAct agents
# How do multi-step agents work?
The ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) is currently the main approach to building agents.
@ -22,7 +22,7 @@ The name is based on the concatenation of two words, "Reason" and "Act." Indeed,
The ReAct process involves keeping a memory of past steps.
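In code, the idea behind the loop might be sketched like this (a simplified illustration of the framework, not this library's actual implementation):

```py
def react_loop(task: str, llm, tools: dict, max_steps: int = 5):
    """Simplified ReAct-style loop: reason, act, observe, repeat."""
    memory = [f"Task: {task}"]
    for _ in range(max_steps):
        # "Reason": the LLM reads the memory of past steps and picks the next action
        thought, tool_name, tool_input = llm(memory)
        if tool_name == "final_answer":
            return tool_input
        # "Act": execute the chosen tool and record the observation
        observation = tools[tool_name](tool_input)
        memory.append(f"{thought}\nObservation: {observation}")
    return None
```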
> [!TIP]
> Read the [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about ReAct agents.
> Read the [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about multi-step agents.
Here is a video overview of how that works:
@ -44,4 +44,4 @@ We implement two versions of JsonAgent:
- [`CodeAgent`] is a new type of JsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
> [!TIP]
> We also provide an option to run agents in one-shot: just pass `oneshot=True` when launching the agent, like `agent.run(your_task, oneshot=True)`
> We also provide an option to run agents in one-shot: just pass `single_step=True` when launching the agent, like `agent.run(your_task, single_step=True)`

View File

@ -122,14 +122,14 @@ def sql_engine(query: str) -> str:
Now let us create an agent that leverages this tool.
We use the ReactCodeAgent, which is transformers.agents' main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
We use the CodeAgent, which is transformers.agents' main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
The llm_engine is the LLM that powers the agent system. HfEngine allows you to call LLMs using HF's Inference API, either via a serverless or a dedicated endpoint, but you could also use any proprietary API.
```py
from transformers.agents import ReactCodeAgent, HfApiEngine
from transformers.agents import CodeAgent, HfApiEngine
agent = ReactCodeAgent(
agent = CodeAgent(
tools=[sql_engine],
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
)
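The agent can then be queried in natural language, for instance (prompt illustrative):

```py
agent.run("Can you give me the name of the client who got the most expensive receipt?")
```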
@ -185,7 +185,7 @@ Since this request is a bit harder than the previous one, we'll switch the LLM
```py
sql_engine.description = updated_description
agent = ReactCodeAgent(
agent = CodeAgent(
tools=[sql_engine],
llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
)

View File

@ -27,7 +27,7 @@ contains the API docs for the underlying classes.
## Agents
Our agents inherit from [`ReactAgent`], which means they can act in multiple steps, each step consisting of one thought, then one tool call and execution. Read more in [this conceptual guide](../conceptual_guides/react).
Our agents inherit from [`MultiStepAgent`], which means they can act in multiple steps, each step consisting of one thought, then one tool call and execution. Read more in [this conceptual guide](../conceptual_guides/react).
We provide two types of agents, based on the main [`Agent`] class.
- [`JsonAgent`] writes its tool calls in JSON.
@ -40,7 +40,7 @@ We provide two types of agents, based on the main [`Agent`] class.
### React agents
[[autodoc]] ReactAgent
[[autodoc]] MultiStepAgent
[[autodoc]] JsonAgent
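For illustration, the two agent types are instantiated the same way (a sketch assuming the constructor signature shown elsewhere in this diff; import path assumed):

```py
from agents import JsonAgent, CodeAgent, HfApiEngine  # import path assumed

llm_engine = HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct")
json_agent = JsonAgent(tools=[], llm_engine=llm_engine)  # emits tool calls as JSON
code_agent = CodeAgent(tools=[], llm_engine=llm_engine)  # emits tool calls as code
```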

View File

@ -89,8 +89,9 @@ class AgentGenerationError(AgentError):
@dataclass
class ToolCall:
tool_name: str
tool_arguments: Any
name: str
arguments: Any
id: str
class AgentStep:
@ -306,26 +307,49 @@ class BaseAgent:
}
memory.append(thought_message)
if step_log.tool_call is not None and summary_mode:
if step_log.tool_call is not None:
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": f"[STEP {i} TOOL CALL]: "
+ str(step_log.tool_call).strip(),
"content": str(
[
{
"id": step_log.tool_call.id,
"type": "function",
"function": {
"name": step_log.tool_call.name,
"arguments": step_log.tool_call.arguments,
},
}
]
),
}
memory.append(tool_call_message)
if step_log.error is not None or step_log.observations is not None:
if step_log.tool_call is None and step_log.error is not None:
message_content = (
"Error:\n"
+ str(step_log.error)
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
tool_response_message = {
"role": MessageRole.ASSISTANT,
"content": message_content,
}
if step_log.tool_call is not None and (
step_log.error is not None or step_log.observations is not None
):
if step_log.error is not None:
message_content = (
f"[OUTPUT OF STEP {i}] -> Error:\n"
"Error:\n"
+ str(step_log.error)
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
elif step_log.observations is not None:
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}"
message_content = f"Observation:\n{step_log.observations}"
tool_response_message = {
"role": MessageRole.TOOL_RESPONSE,
"content": message_content,
"content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n"
+ message_content,
}
memory.append(tool_response_message)
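With this change, a tool call is written back into the agent's memory in an OpenAI-compatible message shape, roughly like this (values illustrative):

```py
{
    "role": "assistant",
    "content": str([{
        "id": "call_0",  # the recorded tool_call id, falling back to "call_0"
        "type": "function",
        "function": {"name": "sql_engine", "arguments": {"query": "SELECT ..."}},
    }]),
}
```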
@ -362,7 +386,7 @@ class BaseAgent:
raise NotImplementedError
class ReactAgent(BaseAgent):
class MultiStepAgent(BaseAgent):
"""
This agent solves the given task step by step, using the ReAct framework:
While the objective is not reached, the agent will perform a cycle of thinking and acting.
@ -474,7 +498,7 @@ class ReactAgent(BaseAgent):
task: str,
stream: bool = False,
reset: bool = True,
oneshot: bool = False,
single_step: bool = False,
**kwargs,
):
"""
@ -484,7 +508,7 @@ class ReactAgent(BaseAgent):
task (`str`): The task to perform.
stream (`bool`): Whether to run in a streaming way.
reset (`bool`): Whether to reset the conversation or keep it going from the previous run.
oneshot (`bool`): Should the agent run in a one-shot or multi-step fashion?
single_step (`bool`): Should the agent run in a one-shot or multi-step fashion?
Example:
```py
@ -516,7 +540,7 @@ class ReactAgent(BaseAgent):
console.print(Group(Rule("[bold]New task", characters="="), Text(self.task)))
self.logs.append(TaskStep(task=self.task))
if oneshot:
if single_step:
step_start_time = time.time()
step_log = ActionStep(start_time=step_start_time)
step_log.end_time = time.time()
@ -548,7 +572,7 @@ class ReactAgent(BaseAgent):
self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step")
console.rule(f"[bold]Step {iteration}")
# Run one step!
final_answer = self.step(step_log)
@ -594,7 +618,7 @@ class ReactAgent(BaseAgent):
self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step")
console.rule(f"[bold]Step {iteration}")
# Run one step!
final_answer = self.step(step_log)
@ -741,7 +765,7 @@ Now begin!""",
)
class JsonAgent(ReactAgent):
class JsonAgent(MultiStepAgent):
"""
In this agent, the tool calls will be formulated by the LLM in JSON format, then parsed and executed.
"""
@ -784,18 +808,6 @@ class JsonAgent(ReactAgent):
# Add new step in logs
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.print(
Group(
Rule(
"[italic]Calling LLM engine with this last message:",
align="left",
style="orange",
),
Text(str(self.prompt_messages[-1])),
)
)
try:
additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
@ -827,25 +839,27 @@ class JsonAgent(ReactAgent):
)
try:
tool_name, arguments = self.tool_parser(action)
tool_name, tool_arguments = self.tool_parser(action)
except Exception as e:
raise AgentParsingError(f"Could not parse the given action: {e}.")
log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
log_entry.tool_call = ToolCall(
tool_name=tool_name, tool_arguments=tool_arguments
)
# Execute
console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
console.print(
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
)
if tool_name == "final_answer":
if isinstance(arguments, dict):
if "answer" in arguments:
answer = arguments["answer"]
if isinstance(tool_arguments, dict):
if "answer" in tool_arguments:
answer = tool_arguments["answer"]
else:
answer = arguments
answer = tool_arguments
else:
answer = arguments
answer = tool_arguments
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
@ -853,9 +867,9 @@ class JsonAgent(ReactAgent):
log_entry.action_output = answer
return answer
else:
if arguments is None:
arguments = {}
observation = self.execute_tool_call(tool_name, arguments)
if tool_arguments is None:
tool_arguments = {}
observation = self.execute_tool_call(tool_name, tool_arguments)
observation_type = type(observation)
if observation_type in [AgentImage, AgentAudio]:
if observation_type == AgentImage:
@ -871,9 +885,10 @@ class JsonAgent(ReactAgent):
log_entry.observations = updated_information
return None
class ToolCallingAgent(ReactAgent):
class ToolCallingAgent(MultiStepAgent):
"""
In this agent, the tool calls will be formulated and parsed using the underlying library, before execution.
This agent uses JSON-like tool calls, but unlike JsonAgent, it leverages the underlying library's tool-calling facilities.
"""
def __init__(
@ -912,53 +927,29 @@ class ToolCallingAgent(ReactAgent):
# Add new step in logs
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.print(
Group(
Rule(
"[italic]Calling LLM engine with this last message:",
align="left",
style="orange",
),
Text(str(self.prompt_messages[-1])),
)
)
try:
llm_output = self.llm_engine(
self.prompt_messages,
tool_name, tool_arguments, tool_call_id = self.llm_engine.get_tool_call(
self.prompt_messages, available_tools=list(self.toolbox._tools.values())
)
log_entry.llm_output = llm_output
except Exception as e:
raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
if self.verbose:
console.print(
Group(
Rule(
"[italic]Output message of the LLM:",
align="left",
style="orange",
),
Text(llm_output),
log_entry.tool_call = ToolCall(
name=tool_name, arguments=tool_arguments, id=tool_call_id
)
)
log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
# Execute
console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
console.print(
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
)
if tool_name == "final_answer":
if isinstance(arguments, dict):
if "answer" in arguments:
answer = arguments["answer"]
if isinstance(tool_arguments, dict):
if "answer" in tool_arguments:
answer = tool_arguments["answer"]
else:
answer = arguments
answer = tool_arguments
else:
answer = arguments
answer = tool_arguments
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
@ -966,9 +957,9 @@ class ToolCallingAgent(ReactAgent):
log_entry.action_output = answer
return answer
else:
if arguments is None:
arguments = {}
observation = self.execute_tool_call(tool_name, arguments)
if tool_arguments is None:
tool_arguments = {}
observation = self.execute_tool_call(tool_name, tool_arguments)
observation_type = type(observation)
if observation_type in [AgentImage, AgentAudio]:
if observation_type == AgentImage:
@ -985,7 +976,7 @@ class ToolCallingAgent(ReactAgent):
return None
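A usage sketch for the new agent type (engine choice illustrative; any engine implementing `get_tool_call` should work):

```py
from agents import ToolCallingAgent, OpenAIEngine  # import path assumed

# Tools are ordinary Tool instances; the empty list here is only a placeholder
agent = ToolCallingAgent(
    tools=[],
    llm_engine=OpenAIEngine(),  # must expose get_tool_call, as added in this commit
)
agent.run("Answer this question using the available tools.")
```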
class CodeAgent(ReactAgent):
class CodeAgent(MultiStepAgent):
"""
In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.
"""
@ -1058,18 +1049,6 @@ class CodeAgent(ReactAgent):
# Add new step in logs
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.print(
Group(
Rule(
"[italic]Calling LLM engine with these last messages:",
align="left",
style="orange",
),
Text(str(self.prompt_messages[-2:])),
)
)
try:
additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
@ -1220,7 +1199,7 @@ __all__ = [
"AgentError",
"BaseAgent",
"ManagedAgent",
"ReactAgent",
"MultiStepAgent",
"CodeAgent",
"JsonAgent",
"Toolbox",

View File

@ -112,9 +112,9 @@ class FinalAnswerTool(Tool):
name = "final_answer"
description = "Provides a final answer to the given problem."
inputs = {
"answer": {"type": "any", "description": "The final answer to the problem"}
"answer": {"type": "object", "description": "The final answer to the problem"}
}
output_type = "any"
output_type = "object"
def forward(self, answer):
return answer

View File

@ -24,13 +24,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.llm_output)
if step_log.tool_call is not None:
used_code = step_log.tool_call.tool_name == "code interpreter"
content = step_log.tool_call.tool_arguments
used_code = step_log.tool_call.name == "code interpreter"
content = step_log.tool_call.arguments
if used_code:
content = f"```py\n{content}\n```"
yield gr.ChatMessage(
role="assistant",
metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"},
metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
content=str(content),
)
if step_log.observations is not None:

View File

@ -16,14 +16,14 @@
# limitations under the License.
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional
from huggingface_hub import InferenceClient
from typing import Dict, List, Optional, Tuple
from transformers import AutoTokenizer, Pipeline
import logging
import os
from openai import OpenAI
from huggingface_hub import InferenceClient
from agents import Tool
logger = logging.getLogger(__name__)
@ -50,13 +50,37 @@ class MessageRole(str, Enum):
return [r.value for r in cls]
openai_role_conversions = {
llama_role_conversions = {
MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}
llama_role_conversions = {
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}
def get_json_schema(tool: Tool) -> Dict:
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": tool.inputs,
"required": list(tool.inputs.keys()),
},
},
}
def get_json_schema_anthropic(tool: Tool) -> Dict:
return {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.inputs,
"required": list(tool.inputs.keys()),
},
}
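For example, applied to the `final_answer` tool updated earlier in this diff, `get_json_schema` produces roughly:

```py
{
    "type": "function",
    "function": {
        "name": "final_answer",
        "description": "Provides a final answer to the given problem.",
        "parameters": {
            "type": "object",
            "properties": {
                "answer": {"type": "object", "description": "The final answer to the problem"}
            },
            "required": ["answer"],
        },
    },
}
```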
def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
@ -78,8 +102,8 @@ def get_clean_message_list(
final_message_list = []
message_list = deepcopy(message_list) # Avoid modifying the original list
for message in message_list:
if not set(message.keys()) == {"role", "content"}:
raise ValueError("Message should contain only 'role' and 'content' keys!")
# if not set(message.keys()) == {"role", "content"}:
# raise ValueError("Message should contain only 'role' and 'content' keys!")
role = message["role"]
if role not in MessageRole.roles():
@ -206,6 +230,7 @@ class HfApiEngine(HfEngine):
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
"""Generates a text completion for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
@ -219,7 +244,7 @@ class HfApiEngine(HfEngine):
max_tokens=max_tokens,
)
else:
response = self.client.chat_completion(
response = self.client.chat.completions.create(
messages, stop=stop_sequences, max_tokens=max_tokens
)
@ -228,6 +253,25 @@ class HfApiEngine(HfEngine):
self.last_output_token_count = response.usage.completion_tokens
return response
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
):
"""Generates a tool call for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
response = self.client.chat.completions.create(
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tool_choice="auto",
)
tool_call = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return tool_call.function.name, tool_call.function.arguments, tool_call.id
class TransformersEngine(HfEngine):
"""This engine uses a pre-initialized local text-generation pipeline."""
@ -305,6 +349,8 @@ class OpenAIEngine:
base_url=base_url,
api_key=api_key,
)
self.last_input_token_count = 0
self.last_output_token_count = 0
def __call__(
self,
@ -328,6 +374,26 @@ class OpenAIEngine:
self.last_output_token_count = response.usage.completion_tokens
return response.choices[0].message.content
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
):
"""Generates a tool call for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tool_choice="required",
)
tool_call = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return tool_call.function.name, tool_call.function.arguments, tool_call.id
class AnthropicEngine:
def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):
@ -345,17 +411,19 @@ class AnthropicEngine:
self.client = Anthropic(
api_key=os.getenv("ANTHROPIC_API_KEY"),
)
self.last_input_token_count = 0
self.last_output_token_count = 0
def __call__(
def separate_messages_system_prompt(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions
)
messages: List[Dict[str, str]],
) -> Tuple[List[Dict[str, str]], str]:
"""Gets the system prompt and the rest of the messages as separate elements."""
index_system_message, system_prompt = None, None
for index, message in enumerate(messages):
if message["role"] == MessageRole.SYSTEM:
@ -370,7 +438,21 @@ class AnthropicEngine:
if len(filtered_messages) == 0:
print("Error, no user message:", messages)
assert False
return filtered_messages, system_prompt
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
filtered_messages, system_prompt = self.separate_messages_system_prompt(
messages
)
response = self.client.messages.create(
model=self.model_name,
system=system_prompt,
@ -385,6 +467,32 @@ class AnthropicEngine:
full_response_text += content_block.text
return full_response_text
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
max_tokens: int = 1500,
):
"""Generates a tool call for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
filtered_messages, system_prompt = self.separate_messages_system_prompt(
messages
)
response = self.client.messages.create(
model=self.model_name,
system=system_prompt,
messages=filtered_messages,
tools=[get_json_schema_anthropic(tool) for tool in available_tools],
tool_choice={"type": "any"},
max_tokens=max_tokens,
)
tool_call = response.content[0]
self.last_input_token_count = response.usage.input_tokens
self.last_output_token_count = response.usage.output_tokens
return tool_call.name, tool_call.input, tool_call.id
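All three engines now expose the same `get_tool_call` interface, so an agent can swap backends; a quick sketch (import path, engine, and tool choice illustrative):

```py
from agents import OpenAIEngine, FinalAnswerTool  # import path assumed

engine = OpenAIEngine()  # or HfApiEngine(...) / AnthropicEngine(...)
tool_name, tool_arguments, tool_call_id = engine.get_tool_call(
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    available_tools=[FinalAnswerTool()],  # any list of Tool instances
)
```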
__all__ = [
"MessageRole",

View File

@ -51,7 +51,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
return f.read()
ONESHOT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
SINGLE_STEP_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
@ -618,7 +618,7 @@ And even if your task resolution is not successful, please return as much contex
__all__ = [
"USER_PROMPT_PLAN_UPDATE",
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
"ONESHOT_CODE_SYSTEM_PROMPT",
"SINGLE_STEP_CODE_SYSTEM_PROMPT",
"CODE_SYSTEM_PROMPT",
"JSON_SYSTEM_PROMPT",
"MANAGED_AGENT_PROMPT",

View File

@ -114,6 +114,7 @@ AUTHORIZED_TYPES = [
"image",
"audio",
"any",
"object",
]
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}

View File

@ -143,10 +143,6 @@ class ImportFinder(ast.NodeVisitor):
self.packages.add(base_package)
import ast
from typing import Dict
def get_method_source(method):
"""Get source code for a method, including bound methods."""
if isinstance(method, types.MethodType):

View File

@ -150,7 +150,7 @@ final_answer(res)
"""
def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_llm_single_step(messages, stop_sequences=None, grammar=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
@ -173,11 +173,11 @@ print(result)
class AgentTests(unittest.TestCase):
def test_fake_oneshot_code_agent(self):
def test_fake_single_step_code_agent(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_single_step
)
output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True)
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
assert isinstance(output, str)
assert output == "7.2904"