Enable support for tool calling agents
Commit 30cb6111b3 (parent 24d9cf9e3d)
@@ -40,7 +40,8 @@ This library offers:
 🤗 **Hub integrations**: you can share and load tools to/from the Hub, and more is to come!

-Quick demo:
+## Quick demo

 First install the package.
 ```bash
 pip install agents
@@ -17,7 +17,7 @@
     - local: conceptual_guides/intro_agents
       title: 🤖 An introduction to agentic systems
     - local: conceptual_guides/react
-      title: 🤔 ReAct agents
+      title: 🤔 How do Multi-step agents work?
   - title: Examples
     sections:
     - local: examples/text_to_sql
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
 rendered properly in your Markdown viewer.

 -->
-# ReAct agents
+# How do multi-step agents work?

 The ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) is currently the main approach to building agents.
@@ -22,7 +22,7 @@ The name is based on the concatenation of two words, "Reason" and "Act." Indeed,
 React process involves keeping a memory of past steps.

 > [!TIP]
-> Read [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about ReAct agents.
+> Read [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about multi-step agents.

 Here is a video overview of how that works:
@@ -44,4 +44,4 @@ We implement two versions of JsonAgent:
 - [`CodeAgent`] is a new type of JsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.

 > [!TIP]
-> We also provide an option to run agents in one-shot: just pass `oneshot=True` when launching the agent, like `agent.run(your_task, oneshot=True)`
+> We also provide an option to run agents in one-shot: just pass `single_step=True` when launching the agent, like `agent.run(your_task, single_step=True)`
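
As a quick illustration of the renamed flag, a single-step run might look like the following sketch (the imports, model name, and task here are illustrative assumptions, not part of the commit):

```py
from transformers.agents import CodeAgent, HfApiEngine

# assumed setup: any agent exposing run() accepts single_step
agent = CodeAgent(
    tools=[],
    llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
)

# one LLM call, no multi-step loop
result = agent.run("Convert 2 meters to feet.", single_step=True)
```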
@@ -122,14 +122,14 @@ def sql_engine(query: str) -> str:

 Now let us create an agent that leverages this tool.

-We use the ReactCodeAgent, which is transformers.agents’ main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
+We use the CodeAgent, which is transformers.agents’ main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.

 The llm_engine is the LLM that powers the agent system. HfApiEngine allows you to call LLMs using HF’s Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.

 ```py
-from transformers.agents import ReactCodeAgent, HfApiEngine
+from transformers.agents import CodeAgent, HfApiEngine

-agent = ReactCodeAgent(
+agent = CodeAgent(
     tools=[sql_engine],
     llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
 )
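
With the agent built, a query can then be run directly; a hypothetical invocation (the question is illustrative, not from the commit):

```py
agent.run("Can you give me the name of the client who got the most expensive receipt?")
```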
@@ -185,7 +185,7 @@ Since this request is a bit harder than the previous one, we’ll switch the LLM
 ```py
 sql_engine.description = updated_description

-agent = ReactCodeAgent(
+agent = CodeAgent(
     tools=[sql_engine],
     llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
 )
@@ -27,7 +27,7 @@ contains the API docs for the underlying classes.

 ## Agents

-Our agents inherit from [`ReactAgent`], which means they can act in multiple steps, each step consisting of one thought, then one tool call and execution. Read more in [this conceptual guide](../conceptual_guides/react).
+Our agents inherit from [`MultiStepAgent`], which means they can act in multiple steps, each step consisting of one thought, then one tool call and execution. Read more in [this conceptual guide](../conceptual_guides/react).

 We provide two types of agents, based on the main [`Agent`] class.
 - [`JsonAgent`] writes its tool calls in JSON.
@@ -40,7 +40,7 @@ We provide two types of agents, based on the main [`Agent`] class.

 ### React agents

-[[autodoc]] ReactAgent
+[[autodoc]] MultiStepAgent

 [[autodoc]] JsonAgent
@@ -89,8 +89,9 @@ class AgentGenerationError(AgentError):

 @dataclass
 class ToolCall:
-    tool_name: str
-    tool_arguments: Any
+    name: str
+    arguments: Any
+    id: str


 class AgentStep:
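
After the rename, downstream code addresses `tool_call.name` and `tool_call.arguments` and must supply an `id` (the field has no default). A minimal standalone sketch of the new shape, with an illustrative id value:

```py
from dataclasses import dataclass
from typing import Any


@dataclass
class ToolCall:
    name: str
    arguments: Any
    id: str


# as a ToolCallingAgent would record it; "call_0" is an illustrative id
call = ToolCall(name="final_answer", arguments={"answer": "42"}, id="call_0")
```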
@@ -306,26 +307,49 @@ class BaseAgent:
             }
             memory.append(thought_message)

-            if step_log.tool_call is not None and summary_mode:
+            if step_log.tool_call is not None:
                 tool_call_message = {
                     "role": MessageRole.ASSISTANT,
-                    "content": f"[STEP {i} TOOL CALL]: "
-                    + str(step_log.tool_call).strip(),
+                    "content": str(
+                        [
+                            {
+                                "id": step_log.tool_call.id,
+                                "type": "function",
+                                "function": {
+                                    "name": step_log.tool_call.name,
+                                    "arguments": step_log.tool_call.arguments,
+                                },
+                            }
+                        ]
+                    ),
                 }
                 memory.append(tool_call_message)

-            if step_log.error is not None or step_log.observations is not None:
+            if step_log.tool_call is None and step_log.error is not None:
+                message_content = (
+                    "Error:\n"
+                    + str(step_log.error)
+                    + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
+                )
+                tool_response_message = {
+                    "role": MessageRole.ASSISTANT,
+                    "content": message_content,
+                }
+                memory.append(tool_response_message)
+            if step_log.tool_call is not None and (
+                step_log.error is not None or step_log.observations is not None
+            ):
                 if step_log.error is not None:
                     message_content = (
-                        f"[OUTPUT OF STEP {i}] -> Error:\n"
+                        "Error:\n"
                         + str(step_log.error)
                         + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
                     )
                 elif step_log.observations is not None:
-                    message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}"
+                    message_content = f"Observation:\n{step_log.observations}"
                 tool_response_message = {
                     "role": MessageRole.TOOL_RESPONSE,
-                    "content": message_content,
+                    "content": f"Call id: {getattr(step_log.tool_call, 'id', None) or 'call_0'}\n"
+                    + message_content,
                 }
                 memory.append(tool_response_message)
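
For the ToolCall sketched earlier, the serialized assistant message content would look roughly like this OpenAI-style structure (values illustrative):

```py
[
    {
        "id": "call_0",
        "type": "function",
        "function": {"name": "final_answer", "arguments": {"answer": "42"}},
    }
]
```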
@@ -362,7 +386,7 @@ class BaseAgent:
         raise NotImplementedError


-class ReactAgent(BaseAgent):
+class MultiStepAgent(BaseAgent):
     """
     This agent solves the given task step by step, using the ReAct framework:
     While the objective is not reached, the agent will perform a cycle of thinking and acting.
@@ -474,7 +498,7 @@ class ReactAgent(BaseAgent):
         task: str,
         stream: bool = False,
         reset: bool = True,
-        oneshot: bool = False,
+        single_step: bool = False,
         **kwargs,
     ):
         """
@@ -484,7 +508,7 @@ class ReactAgent(BaseAgent):
             task (`str`): The task to perform.
             stream (`bool`): Whether to run in a streaming way.
             reset (`bool`): Whether to reset the conversation or keep it going from previous run.
-            oneshot (`bool`): Should the agent run in one shot or multi-step fashion?
+            single_step (`bool`): Should the agent run in one shot or multi-step fashion?

         Example:
         ```py
@@ -516,7 +540,7 @@ class ReactAgent(BaseAgent):
         console.print(Group(Rule("[bold]New task", characters="="), Text(self.task)))
         self.logs.append(TaskStep(task=self.task))

-        if oneshot:
+        if single_step:
             step_start_time = time.time()
             step_log = ActionStep(start_time=step_start_time)
             step_log.end_time = time.time()
@@ -548,7 +572,7 @@ class ReactAgent(BaseAgent):
                 self.planning_step(
                     task, is_first_step=(iteration == 0), iteration=iteration
                 )
-                console.rule("[bold]New step")
+                console.rule(f"[bold]Step {iteration}")

                 # Run one step!
                 final_answer = self.step(step_log)
@@ -594,7 +618,7 @@ class ReactAgent(BaseAgent):
                 self.planning_step(
                     task, is_first_step=(iteration == 0), iteration=iteration
                 )
-                console.rule("[bold]New step")
+                console.rule(f"[bold]Step {iteration}")

                 # Run one step!
                 final_answer = self.step(step_log)
@@ -741,7 +765,7 @@ Now begin!""",
         )


-class JsonAgent(ReactAgent):
+class JsonAgent(MultiStepAgent):
     """
     In this agent, the tool calls will be formulated by the LLM in JSON format, then parsed and executed.
     """
@@ -784,18 +808,6 @@ class JsonAgent(ReactAgent):
         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Calling LLM engine with this last message:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(str(self.prompt_messages[-1])),
-                )
-            )
-
         try:
             additional_args = (
                 {"grammar": self.grammar} if self.grammar is not None else {}
@@ -827,25 +839,27 @@ class JsonAgent(ReactAgent):
         )

         try:
-            tool_name, arguments = self.tool_parser(action)
+            tool_name, tool_arguments = self.tool_parser(action)
         except Exception as e:
             raise AgentParsingError(f"Could not parse the given action: {e}.")

-        log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
+        log_entry.tool_call = ToolCall(
+            name=tool_name, arguments=tool_arguments, id="call_0"  # placeholder id: the JSON parser yields none
+        )

         # Execute
         console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
         console.print(
-            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))
+            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
         )
         if tool_name == "final_answer":
-            if isinstance(arguments, dict):
-                if "answer" in arguments:
-                    answer = arguments["answer"]
+            if isinstance(tool_arguments, dict):
+                if "answer" in tool_arguments:
+                    answer = tool_arguments["answer"]
                 else:
-                    answer = arguments
+                    answer = tool_arguments
             else:
-                answer = arguments
+                answer = tool_arguments
             if (
                 isinstance(answer, str) and answer in self.state.keys()
             ):  # if the answer is a state variable, return the value
@@ -853,9 +867,9 @@ class JsonAgent(ReactAgent):
                 log_entry.action_output = answer
                 return answer
         else:
-            if arguments is None:
-                arguments = {}
-            observation = self.execute_tool_call(tool_name, arguments)
+            if tool_arguments is None:
+                tool_arguments = {}
+            observation = self.execute_tool_call(tool_name, tool_arguments)
             observation_type = type(observation)
             if observation_type in [AgentImage, AgentAudio]:
                 if observation_type == AgentImage:
@@ -871,9 +885,10 @@ class JsonAgent(ReactAgent):
         log_entry.observations = updated_information
         return None

-class ToolCallingAgent(ReactAgent):
+
+class ToolCallingAgent(MultiStepAgent):
     """
-    In this agent, the tool calls will be formulated and parsed using the underlying library, before execution.
+    This agent uses JSON-like tool calls, but unlike JsonAgent it relies on the underlying library's tool-calling facilities.
     """

     def __init__(
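
Construction should mirror the other agents; a sketch assuming the same constructor signature as JsonAgent and an llm_engine that implements get_tool_call (the tool and task are illustrative):

```py
from transformers.agents import ToolCallingAgent, HfApiEngine

agent = ToolCallingAgent(
    tools=[sql_engine],  # sql_engine as defined in the docs example above
    llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
)
agent.run("What is the total of all receipts?")  # illustrative task
```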
@@ -912,53 +927,29 @@ class ToolCallingAgent(ReactAgent):
         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Calling LLM engine with this last message:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(str(self.prompt_messages[-1])),
-                )
-            )
-
         try:
-            llm_output = self.llm_engine(
-                self.prompt_messages,
+            tool_name, tool_arguments, tool_call_id = self.llm_engine.get_tool_call(
+                self.prompt_messages, available_tools=list(self.toolbox._tools.values())
             )
-            log_entry.llm_output = llm_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Output message of the LLM:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(llm_output),
-                )
-            )
-
-        log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
+        log_entry.tool_call = ToolCall(
+            name=tool_name, arguments=tool_arguments, id=tool_call_id
+        )

         # Execute
-        console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
         console.print(
-            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))
+            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
         )
         if tool_name == "final_answer":
-            if isinstance(arguments, dict):
-                if "answer" in arguments:
-                    answer = arguments["answer"]
+            if isinstance(tool_arguments, dict):
+                if "answer" in tool_arguments:
+                    answer = tool_arguments["answer"]
                 else:
-                    answer = arguments
+                    answer = tool_arguments
             else:
-                answer = arguments
+                answer = tool_arguments
             if (
                 isinstance(answer, str) and answer in self.state.keys()
             ):  # if the answer is a state variable, return the value
@@ -966,9 +957,9 @@ class ToolCallingAgent(ReactAgent):
                 log_entry.action_output = answer
                 return answer
         else:
-            if arguments is None:
-                arguments = {}
-            observation = self.execute_tool_call(tool_name, arguments)
+            if tool_arguments is None:
+                tool_arguments = {}
+            observation = self.execute_tool_call(tool_name, tool_arguments)
             observation_type = type(observation)
             if observation_type in [AgentImage, AgentAudio]:
                 if observation_type == AgentImage:
@@ -985,7 +976,7 @@ class ToolCallingAgent(ReactAgent):
         return None


-class CodeAgent(ReactAgent):
+class CodeAgent(MultiStepAgent):
     """
     In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.
     """
@@ -1058,18 +1049,6 @@ class CodeAgent(ReactAgent):
         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()

-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Calling LLM engine with these last messages:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(str(self.prompt_messages[-2:])),
-                )
-            )
-
         try:
             additional_args = (
                 {"grammar": self.grammar} if self.grammar is not None else {}
@@ -1220,7 +1199,7 @@ __all__ = [
     "AgentError",
     "BaseAgent",
     "ManagedAgent",
-    "ReactAgent",
+    "MultiStepAgent",
     "CodeAgent",
     "JsonAgent",
     "Toolbox",
@@ -112,9 +112,9 @@ class FinalAnswerTool(Tool):
     name = "final_answer"
     description = "Provides a final answer to the given problem."
     inputs = {
-        "answer": {"type": "any", "description": "The final answer to the problem"}
+        "answer": {"type": "object", "description": "The final answer to the problem"}
     }
-    output_type = "any"
+    output_type = "object"

     def forward(self, answer):
         return answer
@@ -24,13 +24,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output)
         if step_log.tool_call is not None:
-            used_code = step_log.tool_call.tool_name == "code interpreter"
-            content = step_log.tool_call.tool_arguments
+            used_code = step_log.tool_call.name == "code interpreter"
+            content = step_log.tool_call.arguments
             if used_code:
                 content = f"```py\n{content}\n```"
             yield gr.ChatMessage(
                 role="assistant",
-                metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"},
+                metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
                 content=str(content),
             )
         if step_log.observations is not None:
@@ -16,14 +16,14 @@
 # limitations under the License.
 from copy import deepcopy
 from enum import Enum
-from typing import Dict, List, Optional
-
-from huggingface_hub import InferenceClient
+from typing import Dict, List, Optional, Tuple

 from transformers import AutoTokenizer, Pipeline
 import logging
 import os
 from openai import OpenAI

+from huggingface_hub import InferenceClient
+
+from agents import Tool
+
 logger = logging.getLogger(__name__)
@@ -50,13 +50,37 @@ class MessageRole(str, Enum):
         return [r.value for r in cls]


-openai_role_conversions = {
-    MessageRole.TOOL_RESPONSE: MessageRole.USER,
-}
-
 llama_role_conversions = {
+    MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
     MessageRole.TOOL_RESPONSE: MessageRole.USER,
 }
+
+
+def get_json_schema(tool: Tool) -> Dict:
+    return {
+        "type": "function",
+        "function": {
+            "name": tool.name,
+            "description": tool.description,
+            "parameters": {
+                "type": "object",
+                "properties": tool.inputs,
+                "required": list(tool.inputs.keys()),
+            },
+        },
+    }
+
+
+def get_json_schema_anthropic(tool: Tool) -> Dict:
+    return {
+        "name": tool.name,
+        "description": tool.description,
+        "input_schema": {
+            "type": "object",
+            "properties": tool.inputs,
+            "required": list(tool.inputs.keys()),
+        },
+    }


 def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
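
For a tool with a single string input, get_json_schema yields the standard OpenAI function-calling schema; a worked example (the tool attributes are illustrative):

```py
# assuming tool.name == "sql_engine", tool.description == "Runs a SQL query.",
# and tool.inputs == {"query": {"type": "string", "description": "The SQL query"}},
# get_json_schema(tool) returns:
{
    "type": "function",
    "function": {
        "name": "sql_engine",
        "description": "Runs a SQL query.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "The SQL query"}},
            "required": ["query"],
        },
    },
}
```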
@@ -78,8 +102,8 @@ def get_clean_message_list(
     final_message_list = []
     message_list = deepcopy(message_list)  # Avoid modifying the original list
     for message in message_list:
-        if not set(message.keys()) == {"role", "content"}:
-            raise ValueError("Message should contain only 'role' and 'content' keys!")
+        # if not set(message.keys()) == {"role", "content"}:
+        #     raise ValueError("Message should contain only 'role' and 'content' keys!")

         role = message["role"]
         if role not in MessageRole.roles():
@@ -206,6 +230,7 @@ class HfApiEngine(HfEngine):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
     ) -> str:
+        """Generates a text completion for the given message list"""
         messages = get_clean_message_list(
             messages, role_conversions=llama_role_conversions
         )
@@ -219,7 +244,7 @@ class HfApiEngine(HfEngine):
                 max_tokens=max_tokens,
             )
         else:
-            response = self.client.chat_completion(
+            response = self.client.chat.completions.create(
                 messages, stop=stop_sequences, max_tokens=max_tokens
             )
@@ -228,6 +253,25 @@ class HfApiEngine(HfEngine):
         self.last_output_token_count = response.usage.completion_tokens
         return response

+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+    ):
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        response = self.client.chat.completions.create(
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="auto",
+        )
+        tool_call = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return tool_call.function.name, tool_call.function.arguments, tool_call.id
+

 class TransformersEngine(HfEngine):
     """This engine uses a pre-initialized local text-generation pipeline."""
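
Callers receive a (name, arguments, id) triple; a usage sketch (the model name and tool are illustrative):

```py
engine = HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct")
tool_name, tool_arguments, tool_call_id = engine.get_tool_call(
    messages=[{"role": "user", "content": "What is 2 * 3.6452?"}],
    available_tools=[PythonInterpreterTool()],  # any list of Tool instances
)
```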
@@ -305,6 +349,8 @@ class OpenAIEngine:
             base_url=base_url,
             api_key=api_key,
         )
+        self.last_input_token_count = 0
+        self.last_output_token_count = 0

     def __call__(
         self,
@@ -328,6 +374,26 @@ class OpenAIEngine:
         self.last_output_token_count = response.usage.completion_tokens
         return response.choices[0].message.content

+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+    ):
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="required",
+        )
+        tool_call = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return tool_call.function.name, tool_call.function.arguments, tool_call.id
+

 class AnthropicEngine:
     def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):
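
One caveat worth noting: the OpenAI SDK returns `tool_call.function.arguments` as a JSON string, so callers may need to decode it before use; a hedged sketch:

```py
import json

name, arguments, call_id = engine.get_tool_call(messages, available_tools=tools)
if isinstance(arguments, str):
    # OpenAI encodes function arguments as a JSON string
    arguments = json.loads(arguments)
```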
@@ -345,17 +411,19 @@ class AnthropicEngine:
         self.client = Anthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
         )
+        self.last_input_token_count = 0
+        self.last_output_token_count = 0

-    def __call__(
+    def separate_messages_system_prompt(
         self,
-        messages: List[Dict[str, str]],
-        stop_sequences: Optional[List[str]] = None,
-        grammar: Optional[str] = None,
-        max_tokens: int = 1500,
-    ) -> str:
-        messages = get_clean_message_list(
-            messages, role_conversions=openai_role_conversions
-        )
+        messages: List[Dict[str, str]],
+    ) -> Tuple[List[Dict[str, str]], str]:
+        """Gets the system prompt and the rest of messages as separate elements."""
         index_system_message, system_prompt = None, None
         for index, message in enumerate(messages):
             if message["role"] == MessageRole.SYSTEM:
@@ -370,7 +438,21 @@ class AnthropicEngine:
         if len(filtered_messages) == 0:
             print("Error, no user message:", messages)
             assert False
+        return filtered_messages, system_prompt
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        filtered_messages, system_prompt = self.separate_messages_system_prompt(
+            messages
+        )
         response = self.client.messages.create(
             model=self.model_name,
             system=system_prompt,
@@ -385,6 +467,32 @@ class AnthropicEngine:
                 full_response_text += content_block.text
         return full_response_text

+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+        max_tokens: int = 1500,
+    ):
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )
+        filtered_messages, system_prompt = self.separate_messages_system_prompt(
+            messages
+        )
+        response = self.client.messages.create(
+            model=self.model_name,
+            system=system_prompt,
+            messages=filtered_messages,
+            tools=[get_json_schema_anthropic(tool) for tool in available_tools],
+            tool_choice={"type": "any"},
+            max_tokens=max_tokens,
+        )
+        tool_call = response.content[0]
+        self.last_input_token_count = response.usage.input_tokens
+        self.last_output_token_count = response.usage.output_tokens
+        return tool_call.name, tool_call.input, tool_call.id
+

 __all__ = [
     "MessageRole",
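
Taking `response.content[0]` assumes the tool_use block comes first; since Anthropic can emit a text block before it, a more defensive variant (a sketch, not part of the commit) would scan for it:

```py
# find the tool_use block explicitly instead of assuming it is first
tool_call = next(block for block in response.content if block.type == "tool_use")
name, arguments, call_id = tool_call.name, tool_call.input, tool_call.id
```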
@@ -51,7 +51,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
         return f.read()


-ONESHOT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
+SINGLE_STEP_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
 To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
 You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
 Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
@@ -618,7 +618,7 @@ And even if your task resolution is not successful, please return as much contex
 __all__ = [
     "USER_PROMPT_PLAN_UPDATE",
     "PLAN_UPDATE_FINAL_PLAN_REDACTION",
-    "ONESHOT_CODE_SYSTEM_PROMPT",
+    "SINGLE_STEP_CODE_SYSTEM_PROMPT",
     "CODE_SYSTEM_PROMPT",
     "JSON_SYSTEM_PROMPT",
     "MANAGED_AGENT_PROMPT",
@@ -114,6 +114,7 @@ AUTHORIZED_TYPES = [
     "image",
     "audio",
     "any",
+    "object",
 ]

 CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
@@ -143,10 +143,6 @@ class ImportFinder(ast.NodeVisitor):
         self.packages.add(base_package)


-import ast
-from typing import Dict
-
-
 def get_method_source(method):
     """Get source code for a method, including bound methods."""
     if isinstance(method, types.MethodType):
@@ -150,7 +150,7 @@ final_answer(res)
 """


-def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
+def fake_code_llm_single_step(messages, stop_sequences=None, grammar=None) -> str:
     return """
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
@@ -173,11 +173,11 @@ print(result)


 class AgentTests(unittest.TestCase):
-    def test_fake_oneshot_code_agent(self):
+    def test_fake_single_step_code_agent(self):
         agent = CodeAgent(
-            tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
+            tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_single_step
         )
-        output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True)
+        output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
         assert isinstance(output, str)
         assert output == "7.2904"