Enable support for tool calling agents

Aymeric 2024-12-23 17:10:07 +01:00
parent 24d9cf9e3d
commit 30cb6111b3
13 changed files with 224 additions and 139 deletions

View File

@@ -40,7 +40,8 @@ This library offers:
🤗 **Hub integrations**: you can share and load tools to/from the Hub, and more is to come!

-Quick demo:
+## Quick demo

First install the package.
```bash
pip install agents

View File

@@ -17,7 +17,7 @@
  - local: conceptual_guides/intro_agents
    title: 🤖 An introduction to agentic systems
  - local: conceptual_guides/react
-   title: 🤔 ReAct agents
+   title: 🤔 How do Multi-step agents work?
- title: Examples
  sections:
  - local: examples/text_to_sql

View File

@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->

-# ReAct agents
+# How do multi-step agents work?

The ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) is currently the main approach to building agents.
@@ -22,7 +22,7 @@ The name is based on the concatenation of two words, "Reason" and "Act." Indeed,
React process involves keeping a memory of past steps.

> [!TIP]
-> Read the [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about ReAct agents.
+> Read the [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about multi-step agents.

Here is a video overview of how that works:
@@ -44,4 +44,4 @@ We implement two versions of JsonAgent:
- [`CodeAgent`] is a new type of JsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.

> [!TIP]
-> We also provide an option to run agents in one-shot: just pass `oneshot=True` when launching the agent, like `agent.run(your_task, oneshot=True)`
+> We also provide an option to run agents in one-shot: just pass `single_step=True` when launching the agent, like `agent.run(your_task, single_step=True)`
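For example, mirroring the updated test at the bottom of this commit (a minimal sketch; import paths and model choice are assumptions based on this repo's layout):

```py
# Sketch only: import paths and model choice are assumptions.
from agents import CodeAgent
from agents.default_tools import PythonInterpreterTool
from agents.llm_engines import HfApiEngine

agent = CodeAgent(
    tools=[PythonInterpreterTool()],
    llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
)
# single_step=True answers in one LLM call instead of looping through ReAct steps.
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
print(output)  # "7.2904"
```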

View File

@@ -122,14 +122,14 @@ def sql_engine(query: str) -> str:
Now let us create an agent that leverages this tool.

-We use the ReactCodeAgent, which is transformers.agents' main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
+We use the CodeAgent, which is transformers.agents' main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.

The llm_engine is the LLM that powers the agent system. HfEngine allows you to call LLMs using HF's Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.

```py
-from transformers.agents import ReactCodeAgent, HfApiEngine
-agent = ReactCodeAgent(
+from transformers.agents import CodeAgent, HfApiEngine
+agent = CodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
)
@@ -185,7 +185,7 @@ Since this request is a bit harder than the previous one, we'll switch the LLM
```py
sql_engine.description = updated_description

-agent = ReactCodeAgent(
+agent = CodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
)

View File

@@ -27,7 +27,7 @@ contains the API docs for the underlying classes.
## Agents

-Our agents inherit from [`ReactAgent`], which means they can act in multiple steps, each step consisting of one thought, then one tool call and execution. Read more in [this conceptual guide](../conceptual_guides/react).
+Our agents inherit from [`MultiStepAgent`], which means they can act in multiple steps, each step consisting of one thought, then one tool call and execution. Read more in [this conceptual guide](../conceptual_guides/react).

We provide two types of agents, based on the main [`Agent`] class.
- [`JsonAgent`] writes its tool calls in JSON.
@@ -40,7 +40,7 @@ We provide two types of agents, based on the main [`Agent`] class.
### React agents

-[[autodoc]] ReactAgent
+[[autodoc]] MultiStepAgent

[[autodoc]] JsonAgent

View File

@@ -89,8 +89,9 @@ class AgentGenerationError(AgentError):
@dataclass
class ToolCall:
-    tool_name: str
-    tool_arguments: Any
+    name: str
+    arguments: Any
+    id: str

class AgentStep:
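The renamed fields mirror the OpenAI tool-call payload (`id`, `name`, `arguments`), which is what lets the memory serialization below emit tool calls nearly verbatim. Constructing one by hand (illustrative values; the import path is an assumption):

```py
from agents.agents import ToolCall  # import path is an assumption

call = ToolCall(name="final_answer", arguments={"answer": "7.2904"}, id="call_0")
print(call.name, call.arguments, call.id)
```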
@@ -306,26 +307,49 @@ class BaseAgent:
                }
                memory.append(thought_message)

-            if step_log.tool_call is not None and summary_mode:
+            if step_log.tool_call is not None:
                tool_call_message = {
                    "role": MessageRole.ASSISTANT,
-                    "content": f"[STEP {i} TOOL CALL]: "
-                    + str(step_log.tool_call).strip(),
+                    "content": str(
+                        [
+                            {
+                                "id": step_log.tool_call.id,
+                                "type": "function",
+                                "function": {
+                                    "name": step_log.tool_call.name,
+                                    "arguments": step_log.tool_call.arguments,
+                                },
+                            }
+                        ]
+                    ),
                }
                memory.append(tool_call_message)

-            if step_log.error is not None or step_log.observations is not None:
+            if step_log.tool_call is None and step_log.error is not None:
+                message_content = (
+                    "Error:\n"
+                    + str(step_log.error)
+                    + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
+                )
+                tool_response_message = {
+                    "role": MessageRole.ASSISTANT,
+                    "content": message_content,
+                }
+            if step_log.tool_call is not None and (
+                step_log.error is not None or step_log.observations is not None
+            ):
                if step_log.error is not None:
                    message_content = (
-                        f"[OUTPUT OF STEP {i}] -> Error:\n"
+                        "Error:\n"
                        + str(step_log.error)
                        + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
                    )
                elif step_log.observations is not None:
-                    message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}"
+                    message_content = f"Observation:\n{step_log.observations}"
                tool_response_message = {
                    "role": MessageRole.TOOL_RESPONSE,
-                    "content": message_content,
+                    "content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n"
+                    + message_content,
                }
                memory.append(tool_response_message)
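Concretely, one successful tool-call step now contributes a pair of messages shaped roughly like the following (a hand-written sketch based on the code above; the literal `MessageRole` string values are assumptions):

```py
# Illustrative sketch of the two messages appended for one tool-call step.
tool_call_message = {
    "role": "assistant",  # MessageRole.ASSISTANT; literal value assumed
    "content": "[{'id': 'call_0', 'type': 'function', 'function': "
    "{'name': 'sql_engine', 'arguments': {'query': 'SELECT ...'}}}]",
}
tool_response_message = {
    "role": "tool-response",  # MessageRole.TOOL_RESPONSE; literal value assumed
    "content": "Call id: call_0\nObservation:\n7 | Alan Payne | 12.0 | 2.0",
}
```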
@@ -362,7 +386,7 @@ class BaseAgent:
        raise NotImplementedError

-class ReactAgent(BaseAgent):
+class MultiStepAgent(BaseAgent):
    """
    This agent solves the given task step by step, using the ReAct framework:
    While the objective is not reached, the agent will perform a cycle of thinking and acting.
@@ -474,7 +498,7 @@ class ReactAgent(BaseAgent):
        task: str,
        stream: bool = False,
        reset: bool = True,
-        oneshot: bool = False,
+        single_step: bool = False,
        **kwargs,
    ):
        """
@@ -484,7 +508,7 @@ class ReactAgent(BaseAgent):
            task (`str`): The task to perform.
            stream (`bool`): Whether to run in a streaming way.
            reset (`bool`): Whether to reset the conversation or keep it going from previous run.
-            oneshot (`bool`): Should the agent run in one shot or in multi-step fashion?
+            single_step (`bool`): Should the agent run in one shot or in multi-step fashion?
        Example:
        ```py
@@ -516,7 +540,7 @@ class ReactAgent(BaseAgent):
        console.print(Group(Rule("[bold]New task", characters="="), Text(self.task)))
        self.logs.append(TaskStep(task=self.task))

-        if oneshot:
+        if single_step:
            step_start_time = time.time()
            step_log = ActionStep(start_time=step_start_time)
            step_log.end_time = time.time()
@@ -548,7 +572,7 @@
                self.planning_step(
                    task, is_first_step=(iteration == 0), iteration=iteration
                )
-            console.rule("[bold]New step")
+            console.rule(f"[bold]Step {iteration}")

            # Run one step!
            final_answer = self.step(step_log)
@@ -594,7 +618,7 @@
                self.planning_step(
                    task, is_first_step=(iteration == 0), iteration=iteration
                )
-            console.rule("[bold]New step")
+            console.rule(f"[bold]Step {iteration}")

            # Run one step!
            final_answer = self.step(step_log)
@@ -741,7 +765,7 @@ Now begin!""",
        )

-class JsonAgent(ReactAgent):
+class JsonAgent(MultiStepAgent):
    """
    In this agent, the tool calls will be formulated by the LLM in JSON format, then parsed and executed.
    """
@@ -784,18 +808,6 @@ class JsonAgent(ReactAgent):
        # Add new step in logs
        log_entry.agent_memory = agent_memory.copy()

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Calling LLM engine with this last message:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(str(self.prompt_messages[-1])),
-                )
-            )
        try:
            additional_args = (
                {"grammar": self.grammar} if self.grammar is not None else {}
@@ -827,25 +839,27 @@ class JsonAgent(ReactAgent):
        )

        try:
-            tool_name, arguments = self.tool_parser(action)
+            tool_name, tool_arguments = self.tool_parser(action)
        except Exception as e:
            raise AgentParsingError(f"Could not parse the given action: {e}.")

-        log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
+        log_entry.tool_call = ToolCall(
+            tool_name=tool_name, tool_arguments=tool_arguments
+        )

        # Execute
        console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
        console.print(
-            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))
+            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
        )
        if tool_name == "final_answer":
-            if isinstance(arguments, dict):
-                if "answer" in arguments:
-                    answer = arguments["answer"]
+            if isinstance(tool_arguments, dict):
+                if "answer" in tool_arguments:
+                    answer = tool_arguments["answer"]
                else:
-                    answer = arguments
+                    answer = tool_arguments
            else:
-                answer = arguments
+                answer = tool_arguments
            if (
                isinstance(answer, str) and answer in self.state.keys()
            ):  # if the answer is a state variable, return the value
@@ -853,9 +867,9 @@ class JsonAgent(ReactAgent):
                log_entry.action_output = answer
            return answer
        else:
-            if arguments is None:
-                arguments = {}
-            observation = self.execute_tool_call(tool_name, arguments)
+            if tool_arguments is None:
+                tool_arguments = {}
+            observation = self.execute_tool_call(tool_name, tool_arguments)
            observation_type = type(observation)
            if observation_type in [AgentImage, AgentAudio]:
                if observation_type == AgentImage:
@@ -871,9 +885,10 @@ class JsonAgent(ReactAgent):
        log_entry.observations = updated_information
        return None

-class ToolCallingAgent(ReactAgent):
+class ToolCallingAgent(MultiStepAgent):
    """
-    In this agent, the tool calls will be formulated and parsed using the underlying library, before execution.
+    This agent uses JSON-like tool calls, but unlike JsonAgent, it relies on the underlying library's tool-calling facilities.
    """

    def __init__(
@@ -912,53 +927,29 @@ class ToolCallingAgent(ReactAgent):
        # Add new step in logs
        log_entry.agent_memory = agent_memory.copy()

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Calling LLM engine with this last message:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(str(self.prompt_messages[-1])),
-                )
-            )
        try:
-            llm_output = self.llm_engine(
-                self.prompt_messages,
-            )
-            log_entry.llm_output = llm_output
+            tool_name, tool_arguments, tool_call_id = self.llm_engine.get_tool_call(
+                self.prompt_messages, available_tools=list(self.toolbox._tools.values())
+            )
        except Exception as e:
            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Output message of the LLM:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(llm_output),
-                )
-            )
-        log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
+        log_entry.tool_call = ToolCall(
+            name=tool_name, arguments=tool_arguments, id=tool_call_id
+        )

        # Execute
-        console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
        console.print(
-            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))
+            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
        )
        if tool_name == "final_answer":
-            if isinstance(arguments, dict):
-                if "answer" in arguments:
-                    answer = arguments["answer"]
+            if isinstance(tool_arguments, dict):
+                if "answer" in tool_arguments:
+                    answer = tool_arguments["answer"]
                else:
-                    answer = arguments
+                    answer = tool_arguments
            else:
-                answer = arguments
+                answer = tool_arguments
            if (
                isinstance(answer, str) and answer in self.state.keys()
            ):  # if the answer is a state variable, return the value
@@ -966,9 +957,9 @@ class ToolCallingAgent(ReactAgent):
                log_entry.action_output = answer
            return answer
        else:
-            if arguments is None:
-                arguments = {}
-            observation = self.execute_tool_call(tool_name, arguments)
+            if tool_arguments is None:
+                tool_arguments = {}
+            observation = self.execute_tool_call(tool_name, tool_arguments)
            observation_type = type(observation)
            if observation_type in [AgentImage, AgentAudio]:
                if observation_type == AgentImage:
@@ -985,7 +976,7 @@ class ToolCallingAgent(ReactAgent):
        return None

-class CodeAgent(ReactAgent):
+class CodeAgent(MultiStepAgent):
    """
    In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.
    """
@@ -1058,18 +1049,6 @@ class CodeAgent(ReactAgent):
        # Add new step in logs
        log_entry.agent_memory = agent_memory.copy()

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Calling LLM engine with these last messages:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(str(self.prompt_messages[-2:])),
-                )
-            )
        try:
            additional_args = (
                {"grammar": self.grammar} if self.grammar is not None else {}
@@ -1220,7 +1199,7 @@ __all__ = [
    "AgentError",
    "BaseAgent",
    "ManagedAgent",
-    "ReactAgent",
+    "MultiStepAgent",
    "CodeAgent",
    "JsonAgent",
    "Toolbox",

View File

@@ -112,9 +112,9 @@ class FinalAnswerTool(Tool):
    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {
-        "answer": {"type": "any", "description": "The final answer to the problem"}
+        "answer": {"type": "object", "description": "The final answer to the problem"}
    }
-    output_type = "any"
+    output_type = "object"

    def forward(self, answer):
        return answer

View File

@@ -24,13 +24,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
    if isinstance(step_log, ActionStep):
        yield gr.ChatMessage(role="assistant", content=step_log.llm_output)
        if step_log.tool_call is not None:
-            used_code = step_log.tool_call.tool_name == "code interpreter"
-            content = step_log.tool_call.tool_arguments
+            used_code = step_log.tool_call.name == "code interpreter"
+            content = step_log.tool_call.arguments
            if used_code:
                content = f"```py\n{content}\n```"
            yield gr.ChatMessage(
                role="assistant",
-                metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"},
+                metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
                content=str(content),
            )
        if step_log.observations is not None:

View File

@@ -16,14 +16,14 @@
# limitations under the License.
from copy import deepcopy
from enum import Enum
-from typing import Dict, List, Optional
-from huggingface_hub import InferenceClient
+from typing import Dict, List, Optional, Tuple
from transformers import AutoTokenizer, Pipeline
import logging
import os
from openai import OpenAI
+from huggingface_hub import InferenceClient
+from agents import Tool

logger = logging.getLogger(__name__)
@@ -50,13 +50,37 @@ class MessageRole(str, Enum):
        return [r.value for r in cls]

-openai_role_conversions = {
-    MessageRole.TOOL_RESPONSE: MessageRole.USER,
-}
-
llama_role_conversions = {
+    MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
    MessageRole.TOOL_RESPONSE: MessageRole.USER,
}

+def get_json_schema(tool: Tool) -> Dict:
+    return {
+        "type": "function",
+        "function": {
+            "name": tool.name,
+            "description": tool.description,
+            "parameters": {
+                "type": "object",
+                "properties": tool.inputs,
+                "required": list(tool.inputs.keys()),
+            },
+        },
+    }
+
+def get_json_schema_anthropic(tool: Tool) -> Dict:
+    return {
+        "name": tool.name,
+        "description": tool.description,
+        "input_schema": {
+            "type": "object",
+            "properties": tool.inputs,
+            "required": list(tool.inputs.keys()),
+        },
+    }
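To make the wire format concrete, here is what `get_json_schema` would produce for the `FinalAnswerTool` defined earlier in this commit (hand-derived from the code above, not captured output):

```py
# Hand-derived result of get_json_schema(FinalAnswerTool()); shown for illustration.
{
    "type": "function",
    "function": {
        "name": "final_answer",
        "description": "Provides a final answer to the given problem.",
        "parameters": {
            "type": "object",
            "properties": {
                "answer": {
                    "type": "object",
                    "description": "The final answer to the problem",
                }
            },
            "required": ["answer"],
        },
    },
}
```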
def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:

@@ -78,8 +102,8 @@ def get_clean_message_list(
    final_message_list = []
    message_list = deepcopy(message_list)  # Avoid modifying the original list
    for message in message_list:
-        if not set(message.keys()) == {"role", "content"}:
-            raise ValueError("Message should contain only 'role' and 'content' keys!")
+        # if not set(message.keys()) == {"role", "content"}:
+        #     raise ValueError("Message should contain only 'role' and 'content' keys!")

        role = message["role"]
        if role not in MessageRole.roles():
@@ -206,6 +230,7 @@ class HfApiEngine(HfEngine):
        grammar: Optional[str] = None,
        max_tokens: int = 1500,
    ) -> str:
+        """Generates a text completion for the given message list"""
        messages = get_clean_message_list(
            messages, role_conversions=llama_role_conversions
        )
@@ -219,7 +244,7 @@ class HfApiEngine(HfEngine):
                max_tokens=max_tokens,
            )
        else:
-            response = self.client.chat_completion(
+            response = self.client.chat.completions.create(
                messages, stop=stop_sequences, max_tokens=max_tokens
            )
@@ -228,6 +253,25 @@ class HfApiEngine(HfEngine):
        self.last_output_token_count = response.usage.completion_tokens
        return response

+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+    ):
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        response = self.client.chat.completions.create(
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="auto",
+        )
+        tool_call = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return tool_call.function.name, tool_call.function.arguments, tool_call.id
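A quick sketch of exercising the new method directly; the import paths are assumptions based on this repo's layout, and the returned triple is `(tool name, arguments, call id)`:

```py
# Sketch only: import paths and model choice are assumptions.
from agents.llm_engines import HfApiEngine
from agents.default_tools import FinalAnswerTool

engine = HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [{"role": "user", "content": "Give 7.2904 as your final answer."}]

name, arguments, call_id = engine.get_tool_call(
    messages, available_tools=[FinalAnswerTool()]
)
print(name, arguments, call_id)  # e.g. final_answer '{"answer": "7.2904"}' call_abc123
```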
class TransformersEngine(HfEngine):
    """This engine uses a pre-initialized local text-generation pipeline."""

@@ -305,6 +349,8 @@ class OpenAIEngine:
            base_url=base_url,
            api_key=api_key,
        )
+        self.last_input_token_count = 0
+        self.last_output_token_count = 0

    def __call__(
        self,
@@ -328,6 +374,26 @@ class OpenAIEngine:
        self.last_output_token_count = response.usage.completion_tokens
        return response.choices[0].message.content

+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+    ):
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="required",
+        )
+        tool_call = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return tool_call.function.name, tool_call.function.arguments, tool_call.id
class AnthropicEngine:
    def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):

@@ -345,17 +411,19 @@ class AnthropicEngine:
            self.client = Anthropic(
                api_key=os.getenv("ANTHROPIC_API_KEY"),
            )
+        self.last_input_token_count = 0
+        self.last_output_token_count = 0

-    def __call__(
+    def separate_messages_system_prompt(
        self,
-        messages: List[Dict[str, str]],
-        stop_sequences: Optional[List[str]] = None,
-        grammar: Optional[str] = None,
-        max_tokens: int = 1500,
-    ) -> str:
-        messages = get_clean_message_list(
-            messages, role_conversions=openai_role_conversions
-        )
+        messages: List[Dict[str, str]],
+    ) -> Tuple[List[Dict[str, str]], str]:
+        """Gets the system prompt and the rest of messages as separate elements."""
        index_system_message, system_prompt = None, None
        for index, message in enumerate(messages):
            if message["role"] == MessageRole.SYSTEM:

@@ -370,7 +438,21 @@ class AnthropicEngine:
        if len(filtered_messages) == 0:
            print("Error, no user message:", messages)
            assert False
+        return filtered_messages, system_prompt
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        filtered_messages, system_prompt = self.separate_messages_system_prompt(
+            messages
+        )
        response = self.client.messages.create(
            model=self.model_name,
            system=system_prompt,
@@ -385,6 +467,32 @@ class AnthropicEngine:
            full_response_text += content_block.text
        return full_response_text
+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+        max_tokens: int = 1500,
+    ):
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        filtered_messages, system_prompt = self.separate_messages_system_prompt(
+            messages
+        )
+        response = self.client.messages.create(
+            model=self.model_name,
+            system=system_prompt,
+            messages=filtered_messages,
+            tools=[get_json_schema_anthropic(tool) for tool in available_tools],
+            tool_choice={"type": "any"},
+            max_tokens=max_tokens,
+        )
+        tool_call = response.content[0]
+        self.last_input_token_count = response.usage.input_tokens
+        self.last_output_token_count = response.usage.output_tokens
+        return tool_call.name, tool_call.input, tool_call.id

__all__ = [
    "MessageRole",

View File

@@ -51,7 +51,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
        return f.read()

-ONESHOT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
+SINGLE_STEP_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.

@@ -618,7 +618,7 @@ And even if your task resolution is not successful, please return as much context
__all__ = [
    "USER_PROMPT_PLAN_UPDATE",
    "PLAN_UPDATE_FINAL_PLAN_REDACTION",
-    "ONESHOT_CODE_SYSTEM_PROMPT",
+    "SINGLE_STEP_CODE_SYSTEM_PROMPT",
    "CODE_SYSTEM_PROMPT",
    "JSON_SYSTEM_PROMPT",
    "MANAGED_AGENT_PROMPT",

View File

@@ -114,6 +114,7 @@ AUTHORIZED_TYPES = [
    "image",
    "audio",
    "any",
+    "object",
]

CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}

View File

@@ -143,10 +143,6 @@ class ImportFinder(ast.NodeVisitor):
        self.packages.add(base_package)

-import ast
-from typing import Dict

def get_method_source(method):
    """Get source code for a method, including bound methods."""
    if isinstance(method, types.MethodType):

View File

@@ -150,7 +150,7 @@ final_answer(res)
"""

-def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
+def fake_code_llm_single_step(messages, stop_sequences=None, grammar=None) -> str:
    return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:

@@ -173,11 +173,11 @@ print(result)
class AgentTests(unittest.TestCase):
-    def test_fake_oneshot_code_agent(self):
+    def test_fake_single_step_code_agent(self):
        agent = CodeAgent(
-            tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
+            tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_single_step
        )
-        output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True)
+        output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
        assert isinstance(output, str)
        assert output == "7.2904"