Add tool calling agent example

This commit is contained in:
Aymeric 2024-12-23 17:22:35 +01:00
parent 30cb6111b3
commit 32d7bc5e06
4 changed files with 33 additions and 11 deletions

View File

@@ -0,0 +1,22 @@
from agents.agents import ToolCallingAgent
from agents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine
# Choose which LLM engine to use!
# Exactly one engine should be active: uncomment the line for the backend you
# want. Instantiating all three (as the original did) wastes setup work and
# the unused clients may demand credentials (e.g. OPENAI_API_KEY /
# ANTHROPIC_API_KEY) even though only the last assignment takes effect.
# llm_engine = OpenAIEngine("gpt-4o")
# llm_engine = AnthropicEngine()
llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
@tool
# NOTE(review): the docstring below is presumably parsed by the @tool decorator
# to build the tool's schema (description + Args) — confirm before reformatting it.
def get_weather(location: str) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.
    Args:
        location: the location
    """
    # Demo stub: deliberately ignores `location` and returns a fixed forecast.
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
# Wire the weather tool into a tool-calling agent driven by the engine chosen
# above, then run one sample query and print the agent's answer.
agent = ToolCallingAgent(tools=[get_weather], llm_engine=llm_engine)
result = agent.run("What's the weather like in Paris?")
print(result)

View File

@ -40,7 +40,7 @@ class DockerPythonInterpreter:
Execute Python code in the container and return stdout and stderr Execute Python code in the container and return stdout and stderr
""" """
if tools != None: if tools is not None:
tool_instance = tools[0]() tool_instance = tools[0]()
import_code = f""" import_code = f"""

View File

@ -50,7 +50,7 @@ class MessageRole(str, Enum):
return [r.value for r in cls] return [r.value for r in cls]
llama_role_conversions = { tool_role_conversions = {
MessageRole.TOOL_CALL: MessageRole.ASSISTANT, MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
MessageRole.TOOL_RESPONSE: MessageRole.USER, MessageRole.TOOL_RESPONSE: MessageRole.USER,
} }
@ -232,7 +232,7 @@ class HfApiEngine(HfEngine):
) -> str: ) -> str:
"""Generates a text completion for the given message list""" """Generates a text completion for the given message list"""
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions messages, role_conversions=tool_role_conversions
) )
# Send messages to the Hugging Face Inference API # Send messages to the Hugging Face Inference API
@ -260,7 +260,7 @@ class HfApiEngine(HfEngine):
): ):
"""Generates a tool call for the given message list""" """Generates a tool call for the given message list"""
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions messages, role_conversions=tool_role_conversions
) )
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
messages=messages, messages=messages,
@ -302,7 +302,7 @@ class TransformersEngine(HfEngine):
max_tokens: int = 1500, max_tokens: int = 1500,
) -> str: ) -> str:
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions messages, role_conversions=tool_role_conversions
) )
# Get LLM output # Get LLM output
@ -360,7 +360,7 @@ class OpenAIEngine:
max_tokens: int = 1500, max_tokens: int = 1500,
) -> str: ) -> str:
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions messages, role_conversions=tool_role_conversions
) )
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
@ -381,7 +381,7 @@ class OpenAIEngine:
): ):
"""Generates a tool call for the given message list""" """Generates a tool call for the given message list"""
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions messages, role_conversions=tool_role_conversions
) )
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
@ -448,7 +448,7 @@ class AnthropicEngine:
max_tokens: int = 1500, max_tokens: int = 1500,
) -> str: ) -> str:
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions messages, role_conversions=tool_role_conversions
) )
filtered_messages, system_prompt = self.separate_messages_system_prompt( filtered_messages, system_prompt = self.separate_messages_system_prompt(
messages messages
@ -475,7 +475,7 @@ class AnthropicEngine:
): ):
"""Generates a tool call for the given message list""" """Generates a tool call for the given message list"""
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions messages, role_conversions=tool_role_conversions
) )
filtered_messages, system_prompt = self.separate_messages_system_prompt( filtered_messages, system_prompt = self.separate_messages_system_prompt(
messages messages
@ -496,7 +496,7 @@ class AnthropicEngine:
__all__ = [ __all__ = [
"MessageRole", "MessageRole",
"llama_role_conversions", "tool_role_conversions",
"get_clean_message_list", "get_clean_message_list",
"HfEngine", "HfEngine",
"TransformersEngine", "TransformersEngine",

View File

@ -232,7 +232,7 @@ Action:
def test_additional_args_added_to_task(self): def test_additional_args_added_to_task(self):
agent = CodeAgent(tools=[], llm_engine=fake_code_llm) agent = CodeAgent(tools=[], llm_engine=fake_code_llm)
output = agent.run( agent.run(
"What is 2 multiplied by 3.6452?", additional_instruction="Remember this." "What is 2 multiplied by 3.6452?", additional_instruction="Remember this."
) )
assert "Remember this" in agent.task assert "Remember this" in agent.task