Add tool calling agent example
This commit is contained in:
parent
30cb6111b3
commit
32d7bc5e06
|
@ -0,0 +1,22 @@
|
||||||
|
from agents.agents import ToolCallingAgent
|
||||||
|
from agents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine
|
||||||
|
|
||||||
|
# Choose which LLM engine to use!
|
||||||
|
llm_engine = OpenAIEngine("gpt-4o")
|
||||||
|
llm_engine = AnthropicEngine()
|
||||||
|
llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_weather(location: str) -> str:
|
||||||
|
"""
|
||||||
|
Get weather in the next days at given location.
|
||||||
|
Secretly this tool does not care about the location, it hates the weather everywhere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location: the location
|
||||||
|
"""
|
||||||
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
|
agent = ToolCallingAgent(tools=[get_weather], llm_engine=llm_engine)
|
||||||
|
|
||||||
|
print(agent.run("What's the weather like in Paris?"))
|
|
@ -40,7 +40,7 @@ class DockerPythonInterpreter:
|
||||||
Execute Python code in the container and return stdout and stderr
|
Execute Python code in the container and return stdout and stderr
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if tools != None:
|
if tools is not None:
|
||||||
tool_instance = tools[0]()
|
tool_instance = tools[0]()
|
||||||
|
|
||||||
import_code = f"""
|
import_code = f"""
|
||||||
|
|
|
@ -50,7 +50,7 @@ class MessageRole(str, Enum):
|
||||||
return [r.value for r in cls]
|
return [r.value for r in cls]
|
||||||
|
|
||||||
|
|
||||||
llama_role_conversions = {
|
tool_role_conversions = {
|
||||||
MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
|
MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
|
||||||
MessageRole.TOOL_RESPONSE: MessageRole.USER,
|
MessageRole.TOOL_RESPONSE: MessageRole.USER,
|
||||||
}
|
}
|
||||||
|
@ -232,7 +232,7 @@ class HfApiEngine(HfEngine):
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a text completion for the given message list"""
|
"""Generates a text completion for the given message list"""
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=llama_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send messages to the Hugging Face Inference API
|
# Send messages to the Hugging Face Inference API
|
||||||
|
@ -260,7 +260,7 @@ class HfApiEngine(HfEngine):
|
||||||
):
|
):
|
||||||
"""Generates a tool call for the given message list"""
|
"""Generates a tool call for the given message list"""
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=llama_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -302,7 +302,7 @@ class TransformersEngine(HfEngine):
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
) -> str:
|
) -> str:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=llama_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get LLM output
|
# Get LLM output
|
||||||
|
@ -360,7 +360,7 @@ class OpenAIEngine:
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
) -> str:
|
) -> str:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=openai_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
|
@ -381,7 +381,7 @@ class OpenAIEngine:
|
||||||
):
|
):
|
||||||
"""Generates a tool call for the given message list"""
|
"""Generates a tool call for the given message list"""
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=llama_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
|
@ -448,7 +448,7 @@ class AnthropicEngine:
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
) -> str:
|
) -> str:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=llama_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
filtered_messages, system_prompt = self.separate_messages_system_prompt(
|
filtered_messages, system_prompt = self.separate_messages_system_prompt(
|
||||||
messages
|
messages
|
||||||
|
@ -475,7 +475,7 @@ class AnthropicEngine:
|
||||||
):
|
):
|
||||||
"""Generates a tool call for the given message list"""
|
"""Generates a tool call for the given message list"""
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=llama_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
filtered_messages, system_prompt = self.separate_messages_system_prompt(
|
filtered_messages, system_prompt = self.separate_messages_system_prompt(
|
||||||
messages
|
messages
|
||||||
|
@ -496,7 +496,7 @@ class AnthropicEngine:
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MessageRole",
|
"MessageRole",
|
||||||
"llama_role_conversions",
|
"tool_role_conversions",
|
||||||
"get_clean_message_list",
|
"get_clean_message_list",
|
||||||
"HfEngine",
|
"HfEngine",
|
||||||
"TransformersEngine",
|
"TransformersEngine",
|
||||||
|
|
|
@ -232,7 +232,7 @@ Action:
|
||||||
|
|
||||||
def test_additional_args_added_to_task(self):
|
def test_additional_args_added_to_task(self):
|
||||||
agent = CodeAgent(tools=[], llm_engine=fake_code_llm)
|
agent = CodeAgent(tools=[], llm_engine=fake_code_llm)
|
||||||
output = agent.run(
|
agent.run(
|
||||||
"What is 2 multiplied by 3.6452?", additional_instruction="Remember this."
|
"What is 2 multiplied by 3.6452?", additional_instruction="Remember this."
|
||||||
)
|
)
|
||||||
assert "Remember this" in agent.task
|
assert "Remember this" in agent.task
|
||||||
|
|
Loading…
Reference in New Issue