Simplify code by removing the dedicated CodeAgent

This commit is contained in:
Aymeric 2024-12-09 22:15:33 +01:00
parent 0c71f92039
commit 154d1e938e
9 changed files with 118 additions and 511 deletions

View File

@ -24,7 +24,7 @@ from transformers.utils import (
_import_structure = { _import_structure = {
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"],
"llm_engine": ["HfApiEngine", "TransformersEngine"], "llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"], "monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"], "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
@ -45,7 +45,7 @@ else:
_import_structure["translation"] = ["TranslationTool"] _import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, CodeAgent, JsonAgent, Toolbox
from .llm_engine import HfApiEngine, TransformersEngine from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
@ -57,12 +57,12 @@ if TYPE_CHECKING:
pass pass
else: else:
from .default_tools import FinalAnswerTool, PythonInterpreterTool from .default_tools import FinalAnswerTool, PythonInterpreterTool
from .document_question_answering import DocumentQuestionAnsweringTool from .tools.document_question_answering import DocumentQuestionAnsweringTool
from .image_question_answering import ImageQuestionAnsweringTool from .tools.image_question_answering import ImageQuestionAnsweringTool
from .search import DuckDuckGoSearchTool, VisitWebpageTool from .tools.search import DuckDuckGoSearchTool, VisitWebpageTool
from .speech_to_text import SpeechToTextTool from .tools.speech_to_text import SpeechToTextTool
from .text_to_speech import TextToSpeechTool from .tools.text_to_speech import TextToSpeechTool
from .translation import TranslationTool from .tools.translation import TranslationTool
else: else:
import sys import sys

View File

@ -14,30 +14,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import re
import time import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Union
from dataclasses import dataclass from dataclasses import dataclass
from transformers.utils import is_torch_available from transformers.utils import is_torch_available
import logging from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
from .utils import console, parse_code_blob, parse_json_tool_call
from .agent_types import AgentAudio, AgentImage from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole from .llm_engine import HfApiEngine, MessageRole
from .monitoring import Monitor from .monitoring import Monitor
from .prompts import ( from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT, CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT, JSON_SYSTEM_PROMPT,
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION, PLAN_UPDATE_FINAL_PLAN_REDACTION,
PROMPTS_FOR_INITIAL_PLAN,
PROMPTS_FOR_PLAN_UPDATE,
SUPPORTED_PLAN_TYPES,
SYSTEM_PROMPT_FACTS, SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_FACTS_UPDATE, SYSTEM_PROMPT_FACTS_UPDATE,
USER_PROMPT_FACTS_UPDATE, USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN_UPDATE,
USER_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
SYSTEM_PROMPT_PLAN,
) )
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import ( from .tools import (
@ -48,8 +45,6 @@ from .tools import (
Toolbox, Toolbox,
) )
LENGTH_TRUNCATE_REPORTS = 10000
HUGGINGFACE_DEFAULT_TOOLS = {} HUGGINGFACE_DEFAULT_TOOLS = {}
_tools_are_initialized = False _tools_are_initialized = False
@ -91,7 +86,7 @@ class AgentGenerationError(AgentError):
class ActionStep: class ActionStep:
tool_call: str | None = None tool_call: str | None = None
start_time: float | None = None start_time: float | None = None
end_time: float | None = None step_end_time: float | None = None
iteration: int | None = None iteration: int | None = None
final_answer: Any = None final_answer: Any = None
error: AgentError | None = None error: AgentError | None = None
@ -109,11 +104,10 @@ class TaskStep:
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions) prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
if "<<tool_names>>" in prompt: if "{{tool_names}}" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()] prompt = prompt.replace("{{tool_names}}", ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]))
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt return prompt
@ -142,7 +136,7 @@ def format_prompt_with_imports(prompt_template: str, authorized_imports: List[st
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports)) return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
class Agent: class BaseAgent:
def __init__( def __init__(
self, self,
tools: Union[List[Tool], Toolbox], tools: Union[List[Tool], Toolbox],
@ -160,7 +154,7 @@ class Agent:
monitor_metrics: bool = True, monitor_metrics: bool = True,
): ):
if system_prompt is None: if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_parser is None: if tool_parser is None:
tool_parser = parse_json_tool_call tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__ self.agent_name = self.__class__.__name__
@ -184,7 +178,7 @@ class Agent:
if not is_torch_available(): if not is_torch_available():
raise ImportError("Using the base tools requires torch to be installed.") raise ImportError("Using the base tools requires torch to be installed.")
self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == ReactJsonAgent)) self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == JsonAgent))
else: else:
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
self._toolbox.add_tool(FinalAnswerTool()) self._toolbox.add_tool(FinalAnswerTool())
@ -225,7 +219,7 @@ class Agent:
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
) )
self.logs = [TaskStep(system_prompt=self.system_prompt, task=self.task)] self.logs = [TaskStep(system_prompt=self.system_prompt, task=self.task)]
console.rule("New task", characters='=') console.rule("[bold]New task", characters='=')
console.print(self.task) console.print(self.task)
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
@ -372,128 +366,7 @@ class Agent:
raise NotImplementedError raise NotImplementedError
class CodeAgent(Agent): class ReactAgent(BaseAgent):
"""
A class for an agent that solves the given task using a single block of code. It plans all its actions, then executes all in one shot.
"""
def __init__(
self,
tools: List[Tool],
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
grammar=grammar,
**kwargs,
)
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
def parse_code_blob(self, result: str) -> str:
"""
Override this method if you want to change the way the code is
cleaned in the `run` method.
"""
return parse_code_blob(result)
def run(self, task: str, return_generated_code: bool = False, **kwargs):
"""
Runs the agent for the given task.
Args:
task (`str`): The task to perform
return_generated_code (`bool`, *optional*, defaults to `False`): Whether to return the generated code instead of running it
kwargs (additional keyword arguments, *optional*):
Any keyword argument to send to the agent when evaluating the code.
Example:
```py
from transformers.agents import CodeAgent
agent = CodeAgent(tools=[])
agent.run("What is the result of 2 power 3.7384?")
```
"""
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.initialize_for_run()
# Run LLM
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
task_message = {
"role": MessageRole.USER,
"content": "Task: " + self.task,
}
self.prompt = [prompt_message, task_message]
if self.verbose:
console.rule("Executing with this prompt")
console.print(self.prompt)
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
# Parse
try:
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
if self.verbose:
console.print(
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
)
rationale, code_action = "", llm_output
try:
code_action = self.parse_code_blob(code_action)
except Exception as e:
error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg
# Execute
self.log_rationale_code_action(rationale, code_action)
try:
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
output = self.python_evaluator(
code_action,
static_tools=available_tools,
custom_tools={},
state=self.state,
authorized_imports=self.authorized_imports,
)
if self.verbose:
console.print(self.state["print_outputs"])
return output
except Exception as e:
error_msg = f"Error in execution: {e}. Be sure to provide correct code."
console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg
class ReactAgent(Agent):
""" """
This agent solves the given task step by step, using the ReAct framework: This agent solves the given task step by step, using the ReAct framework:
While the objective is not reached, the agent will perform a cycle of thinking and acting. While the objective is not reached, the agent will perform a cycle of thinking and acting.
@ -507,20 +380,16 @@ class ReactAgent(Agent):
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None, tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
plan_type: Optional[str] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
if llm_engine is None: if llm_engine is None:
llm_engine = HfApiEngine() llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_description_template is None: if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
if plan_type is None:
plan_type = SUPPORTED_PLAN_TYPES[0]
else:
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
super().__init__( super().__init__(
tools=tools, tools=tools,
llm_engine=llm_engine, llm_engine=llm_engine,
@ -530,7 +399,6 @@ class ReactAgent(Agent):
**kwargs, **kwargs,
) )
self.planning_interval = planning_interval self.planning_interval = planning_interval
self.plan_type = plan_type
def provide_final_answer(self, task) -> str: def provide_final_answer(self, task) -> str:
""" """
@ -556,7 +424,7 @@ class ReactAgent(Agent):
console.print(f"[bold red]{error_msg}[/bold red]") console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg return error_msg
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs): def run(self, task: str, stream: bool = False, reset: bool = True, oneshot: bool = False, **kwargs):
""" """
Runs the agent for the given task. Runs the agent for the given task.
@ -578,6 +446,17 @@ class ReactAgent(Agent):
self.initialize_for_run() self.initialize_for_run()
else: else:
self.logs.append(TaskStep(task=task)) self.logs.append(TaskStep(task=task))
if oneshot:
step_start_time = time.time()
step_log = ActionStep(start_time=step_start_time)
step_log.step_end_time = time.time()
step_log.step_duration = step_log.step_end_time - step_start_time
# Run the agent's step
result = self.step(step_log)
return result
if stream: if stream:
return self.stream_run(task) return self.stream_run(task)
else: else:
@ -685,11 +564,11 @@ Now begin!""",
message_system_prompt_plan = { message_system_prompt_plan = {
"role": MessageRole.SYSTEM, "role": MessageRole.SYSTEM,
"content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["system"], "content": SYSTEM_PROMPT_PLAN,
} }
message_user_prompt_plan = { message_user_prompt_plan = {
"role": MessageRole.USER, "role": MessageRole.USER,
"content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format( "content": USER_PROMPT_PLAN.format(
task=task, task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
managed_agents_descriptions=( managed_agents_descriptions=(
@ -732,11 +611,11 @@ Now begin!""",
# Redact updated plan # Redact updated plan
plan_update_message = { plan_update_message = {
"role": MessageRole.SYSTEM, "role": MessageRole.SYSTEM,
"content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["system"].format(task=task), "content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task),
} }
plan_update_message_user = { plan_update_message_user = {
"role": MessageRole.USER, "role": MessageRole.USER,
"content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format( "content": USER_PROMPT_PLAN_UPDATE.format(
task=task, task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
managed_agents_descriptions=( managed_agents_descriptions=(
@ -761,7 +640,7 @@ Now begin!""",
console.print(final_plan_redaction) console.print(final_plan_redaction)
class ReactJsonAgent(ReactAgent): class JsonAgent(ReactAgent):
""" """
This agent solves the given task step by step, using the ReAct framework: This agent solves the given task step by step, using the ReAct framework:
While the objective is not reached, the agent will perform a cycle of thinking and acting. While the objective is not reached, the agent will perform a cycle of thinking and acting.
@ -781,7 +660,7 @@ class ReactJsonAgent(ReactAgent):
if llm_engine is None: if llm_engine is None:
llm_engine = HfApiEngine() llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT system_prompt = JSON_SYSTEM_PROMPT
if tool_description_template is None: if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__( super().__init__(
@ -808,8 +687,9 @@ class ReactJsonAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy() log_entry.agent_memory = agent_memory.copy()
if self.verbose: if self.verbose:
console.rule("Calling LLM with this last message:") console.rule("Calling LLM engine with this last message:", align="left")
console.print(self.prompt[-1]) console.print(self.prompt[-1])
console.rule()
try: try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {} additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
@ -818,12 +698,11 @@ class ReactJsonAgent(ReactAgent):
) )
except Exception as e: except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.") raise AgentGenerationError(f"Error in generating llm output: {e}.")
console.rule("===== Output message of the LLM: =====") console.rule("Output message of the LLM")
console.print(llm_output) console.print(llm_output)
log_entry.llm_output = llm_output log_entry.llm_output = llm_output
# Parse # Parse
console.rule("===== Extracting action =====")
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
try: try:
@ -835,8 +714,9 @@ class ReactJsonAgent(ReactAgent):
log_entry.tool_call = {"tool_name": tool_name, "tool_arguments": arguments} log_entry.tool_call = {"tool_name": tool_name, "tool_arguments": arguments}
# Execute # Execute
console.print("=== Agent thoughts:") console.rule("Agent thoughts:")
console.print(rationale) console.print(rationale)
console.rule()
console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
if tool_name == "final_answer": if tool_name == "final_answer":
if isinstance(arguments, dict): if isinstance(arguments, dict):
@ -872,7 +752,7 @@ class ReactJsonAgent(ReactAgent):
return log_entry return log_entry
class ReactCodeAgent(ReactAgent): class CodeAgent(ReactAgent):
""" """
This agent solves the given task step by step, using the ReAct framework: This agent solves the given task step by step, using the ReAct framework:
While the objective is not reached, the agent will perform a cycle of thinking and acting. While the objective is not reached, the agent will perform a cycle of thinking and acting.
@ -893,7 +773,7 @@ class ReactCodeAgent(ReactAgent):
if llm_engine is None: if llm_engine is None:
llm_engine = HfApiEngine() llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_description_template is None: if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__( super().__init__(
@ -926,8 +806,9 @@ class ReactCodeAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy() log_entry.agent_memory = agent_memory.copy()
if self.verbose: if self.verbose:
console.print("===== Calling LLM with these last messages: =====") console.rule("Calling LLM engine with these last messages:", align="left")
console.print(self.prompt[-2:]) console.print(self.prompt[-2:])
console.rule()
try: try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {} additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
@ -980,7 +861,7 @@ class ReactCodeAgent(ReactAgent):
if result is not None: if result is not None:
console.print("Last output from code snippet:") console.print("Last output from code snippet:")
console.print(str(result)) console.print(str(result))
observation += "Last output from code snippet:\n" + str(result)[:LENGTH_TRUNCATE_REPORTS] observation += "Last output from code snippet:\n" + truncate_content(str(result))
log_entry.observation = observation log_entry.observation = observation
except Exception as e: except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}" error_msg = f"Code execution failed due to the following error:\n{str(e)}"
@ -994,6 +875,7 @@ class ReactCodeAgent(ReactAgent):
log_entry.final_answer = result log_entry.final_answer = result
return result return result
class ManagedAgent: class ManagedAgent:
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False): def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
self.agent = agent self.agent = agent
@ -1034,14 +916,7 @@ And even if your task resolution is not successful, please return as much contex
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True): for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
content = message["content"] content = message["content"]
if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content): answer += "\n" + truncate_content(str(content)) + "\n---"
answer += "\n" + str(content) + "\n---"
else:
answer += (
"\n"
+ str(content)[:LENGTH_TRUNCATE_REPORTS]
+ "\n(...Step was truncated because too long)...\n---"
)
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'." answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
return answer return answer
else: else:

View File

@ -48,7 +48,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
return f.read() return f.read()
DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task. ONESHOT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns. To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python. You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so. Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
@ -57,7 +57,7 @@ You can use imports in your code, but only from the following list of modules: <
Be sure to provide a 'Code:' token, else the run will fail. Be sure to provide a 'Code:' token, else the run will fail.
Tools: Tools:
<<tool_descriptions>> {{tool_descriptions}}
Examples: Examples:
--- ---
@ -129,10 +129,10 @@ final_answer(caption)
```<end_action> ```<end_action>
--- ---
Above example were using tools that might not exist for you. You only have acces to those Tools: Above example were using tools that might not exist for you. You only have access to these tools:
<<tool_names>> {{tool_names}}
Remember to make sure that variables you use are all defined. Remember to make sure that variables you use are all defined. In particular don't import packages!
Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error. Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error.
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
@ -140,8 +140,8 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
""" """
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can. JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can.
To do so, you have been given access to the following tools: <<tool_names>> To do so, you have been given access to the following tools: {{tool_names}}
The way you use the tools is by specifying a json blob, ending with '<end_action>'. The way you use the tools is by specifying a json blob, ending with '<end_action>'.
Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool). Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
@ -256,8 +256,8 @@ Action:
}<end_action> }<end_action>
Above example were using notional tools that might not exist for you. You only have acces to those tools: Above example were using notional tools that might not exist for you. You only have access to these tools:
<<tool_descriptions>> {{tool_descriptions}}
Here are the rules you should always follow to solve your task: Here are the rules you should always follow to solve your task:
1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail. 1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail.
@ -269,7 +269,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
""" """
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can. CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code. To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
@ -348,9 +348,9 @@ pope_current_age = 85 ** 0.36
final_answer(pope_current_age) final_answer(pope_current_age)
```<end_action> ```<end_action>
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool): Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you only have access to these tools:
<<tool_descriptions>> {{tool_descriptions}}
<<managed_agents_descriptions>> <<managed_agents_descriptions>>
@ -473,299 +473,6 @@ After writing the final step of the plan, write the '\n<end_plan>' tag and stop
Now write your new plan below.""" Now write your new plan below."""
SYSTEM_PROMPT_PLAN_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer. Each step should be structured as follows:
Step #n: {
"description": <description of what the step does and its output>
"tool": <tool to use>,
"params": {
<parameters to pass to the tool as a valid dict>
}
"output_var": <output variable name>
}
Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
Below are some examples:
Example 1:
------
Inputs:
---
Task:
How many encoder blocks were in the first attention-only ML architecture published?
[FACTS LIST]:
### 1. Facts given in the task
- The paper first introduced an attention-only ML architecture.
- The specific information required is the page number where the number of encoder blocks is stated.
- No local files are provided for access.
### 2. Facts to look up
- The title and authors of the paper that first introduced an attention-only ML architecture.
- Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
- The full text of the identified paper.
- Source: Online academic repositories (e.g., arXiv, journal websites)
- The specific page number in the paper where the number of encoder blocks is mentioned.
- Source: The content of the identified paper
### 3. Facts to derive
- By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
- Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
```
[STEP 1 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}
[OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
**Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
[STEP 2 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}
[OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
---
Output plan:
---
Step #1: {
"description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of "encoder blocks"",
"tool": "inspect_file_as_text",
"params": {
"file_path": "https://arxiv.org/pdf/1706.03762.pdf",
"question": "On which page is the number of encoder blocks mentioned?"
},
"output_var": "page_number"
}
Step #2: {
"description": "Provide the final answer",
"tool": "final_answer",
"params": {
"answer": "{page_number}"
},
"output_var": ""
}
------
Example 2:
------
Inputs:
---
Task:
How many golf balls fit into a Boeing-747?
[FACTS LIST]:
### 1. Facts given in the task
- The task requires calculating the number of golf balls that fit into a Boeing-747
### 2. Facts to look up
- The volume of a golf ball
- The volume of a Boeing-747
### 3. Facts to derive
- Once the volumes are known the final answer can be calculated
---
Output plan:
---
Step #1: {
"description": "Find the volume of a Boeing-747",
"tool": "web_search",
"params": {
"query": "What is the internal volume of a Boeing-747 in cubic meters?"
},
"output_var": "boeing_volume"
}
Step #2: {
"description": "Find the volume of a standard golf ball",
"tool": "ask_search_agent",
"params": {
"query": "What is the volume of a standard golf ball in cubic centimeters?"
},
"output_var": "golf_ball_volume"
}
Step #3: {
"description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
"tool": "python_code",
"params": {
"code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
},
"output_var": "number_of_golf_balls"
}
Step #4: {
"description": "Provide the final answer",
"tool": "final_answer",
"params": {
"answer": "{number_of_golf_balls}"
},
"output_var": ""
}
------
The above examples were using tools that might not exist for you.
Your goal is to create a plan to solve the task."""
USER_PROMPT_PLAN_STRUCTURED = """
Here are your inputs:
Task:
```
{task}
```
Your plan can leverage any of these tools:
{tool_descriptions}
These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
List of facts that you know:
```
{answer_facts}
```
Now for the given task, create a plan taking into account the list of facts.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there. Output the plan only and nothing else."""
SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer. Each step should be structured as follows:
Step #n: {{
"description": <description of what the step does and its output>
"tool": <tool to use>,
"params": {{
<parameters to pass to the tool as a valid dict>
}}
"output_var": <output variable name>
}}
Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
Below are some examples:
Example 1:
------
Inputs:
---
Task:
How many encoder blocks were in the first attention-only ML architecture published?
[FACTS LIST]:
### 1. Facts given in the task
- The paper first introduced an attention-only ML architecture.
- The specific information required is the page number where the number of encoder blocks is stated.
- No local files are provided for access.
### 2. Facts to look up
- The title and authors of the paper that first introduced an attention-only ML architecture.
- Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
- The full text of the identified paper.
- Source: Online academic repositories (e.g., arXiv, journal websites)
- The specific page number in the paper where the number of encoder blocks is mentioned.
- Source: The content of the identified paper
### 3. Facts to derive
- By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
- Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
```
[STEP 1 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}}
[OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
**Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
[STEP 2 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}}
[OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
---
Output plan:
---
Step #1: {{
"description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of "encoder blocks"",
"tool": "inspect_file_as_text",
"params": {{
"file_path": "https://arxiv.org/pdf/1706.03762.pdf",
"question": "On which page is the number of encoder blocks mentioned?"
}},
"output_var": "page_number"
}}
Step #2: {{
"description": "Provide the final answer",
"tool": "final_answer",
"params": {{
"answer": "{{page_number}}"
}},
"output_var": ""
}}
------
Example 2:
------
Inputs:
---
Task:
How many golf balls fit into a Boeing-747?
[FACTS LIST]:
### 1. Facts given in the task
- The task requires calculating the number of golf balls that fit into a Boeing-747
### 2. Facts to look up
- The volume of a golf ball
- The volume of a Boeing-747
### 3. Facts to derive
- Once the volumes are known the final answer can be calculated
---
Output plan:
---
Step #1: {{
"description": "Find the volume of a Boeing-747",
"tool": "web_search",
"params": {{
"query": "What is the internal volume of a Boeing-747 in cubic meters?"
}},
"output_var": "boeing_volume"
}}
Step #2: {{
"description": "Find the volume of a standard golf ball",
"tool": "ask_search_agent",
"params": {{
"query": "What is the volume of a standard golf ball in cubic centimeters?"
}},
"output_var": "golf_ball_volume"
}}
Step #3: {{
"description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
"tool": "python_code",
"params": {{
"code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
}},
"output_var": "number_of_golf_balls"
}}
Step #4: {{
"description": "Provide the final answer",
"tool": "final_answer",
"params": {{
"answer": "{{number_of_golf_balls}}"
}},
"output_var": ""
}}
------
The above examples were using tools that might not exist for you.
Find below the record of what has been tried so far to solve it. Your goal is to create an updated plan to solve the task."""
USER_PROMPT_PLAN_UPDATE_STRUCTURED = """
Here are your inputs:
Task:
```
{task}
```
Your plan can leverage any of these tools:
{tool_descriptions}
These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
List of facts that you know:
```
{facts_update}
```
Now for the given task, create a plan taking into account the above inputs and list of facts.
Beware that you have {remaining_steps} steps remaining.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there. Output the plan only and nothing else."""
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given: PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
``` ```
{task} {task}
@ -775,15 +482,3 @@ Here is my new/updated plan of action to solve the task:
``` ```
{plan_update} {plan_update}
```""" ```"""
# Names of the planning flavours; each one is a key of the two tables below.
SUPPORTED_PLAN_TYPES = ["default", "structured"]

# System/user prompt pair used to write the first plan, keyed by plan type.
PROMPTS_FOR_INITIAL_PLAN = dict(
    default=dict(system=SYSTEM_PROMPT_PLAN, user=USER_PROMPT_PLAN),
    structured=dict(system=SYSTEM_PROMPT_PLAN_STRUCTURED, user=USER_PROMPT_PLAN_STRUCTURED),
)

# System/user prompt pair used to revise an existing plan, keyed by plan type.
PROMPTS_FOR_PLAN_UPDATE = dict(
    default=dict(system=SYSTEM_PROMPT_PLAN_UPDATE, user=USER_PROMPT_PLAN_UPDATE),
    structured=dict(system=SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED, user=USER_PROMPT_PLAN_UPDATE_STRUCTURED),
)

View File

@ -24,6 +24,7 @@ from typing import Any, Callable, Dict, List, Optional
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from .utils import truncate_content
class InterpreterError(ValueError): class InterpreterError(ValueError):
""" """
@ -895,9 +896,9 @@ def evaluate_python_code(
try: try:
for node in expression.body: for node in expression.body:
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT) state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
return result return result
except InterpreterError as e: except InterpreterError as e:
msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT) msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg) raise InterpreterError(msg)

View File

@ -28,6 +28,14 @@ def is_pygments_available():
from rich.console import Console from rich.console import Console
console = Console() console = Console()
# Default character budget applied by `truncate_content`.
LENGTH_TRUNCATE_REPORTS = 10000


def truncate_content(content: str, max_length: int = LENGTH_TRUNCATE_REPORTS) -> str:
    """Shorten ``content`` by cutting out its middle when it exceeds ``max_length``.

    Args:
        content: Text to shorten.
        max_length: Maximum number of characters of ``content`` to keep.
            Negative values are treated as 0.

    Returns:
        ``content`` unchanged when it is at most ``max_length`` characters long.
        Otherwise the first ``max_length // 2`` characters, a truncation notice,
        and the last ``max_length - max_length // 2`` characters.
    """
    max_length = max(max_length, 0)
    # Content that exactly fits the budget is returned as-is: truncating it
    # would only make it longer by inserting the notice.
    if len(content) <= max_length:
        return content

    head_length = max_length // 2
    tail_length = max_length - head_length
    # Slice from an explicit start index: a negative-index slice such as
    # `content[-0:]` would return the whole string when the tail is empty.
    tail_start = len(content) - tail_length
    return (
        content[:head_length]
        + "\n..._(Content was truncated because too long)_...\n---"
        + content[tail_start:]
    )
def parse_json_blob(json_blob: str) -> Dict[str, str]: def parse_json_blob(json_blob: str) -> Dict[str, str]:
try: try:

View File

@ -220,7 +220,7 @@ An agent, or rather the LLM that drives the agent, generates an output based on
```text ```text
You will be given a task to solve as best you can. You will be given a task to solve as best you can.
You have access to the following tools: You have access to the following tools:
<<tool_descriptions>> {{tool_descriptions}}
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
@ -236,7 +236,7 @@ Here are a few examples using notional tools:
{examples} {examples}
Above example were using notional tools that might not exist for you. You only have access to those tools: Above example were using notional tools that might not exist for you. You only have access to those tools:
<<tool_names>> {{tool_names}}
You also can perform computations in the python code you generate. You also can perform computations in the python code you generate.
Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```<end_code>' sequence. You MUST provide at least the 'Code:' sequence to move forward. Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```<end_code>' sequence. You MUST provide at least the 'Code:' sequence to move forward.
@ -251,7 +251,7 @@ Now Begin!
The system prompt includes: The system prompt includes:
- An *introduction* that explains how the agent should behave and what tools are. - An *introduction* that explains how the agent should behave and what tools are.
- A description of all the tools that is defined by a `<<tool_descriptions>>` token that is dynamically replaced at runtime with the tools defined/chosen by the user. - A description of all the tools that is defined by a `{{tool_descriptions}}` token that is dynamically replaced at runtime with the tools defined/chosen by the user.
- The tool description comes from the tool attributes, `name`, `description`, `inputs` and `output_type`, and a simple `jinja2` template that you can refine. - The tool description comes from the tool attributes, `name`, `description`, `inputs` and `output_type`, and a simple `jinja2` template that you can refine.
- The expected output format. - The expected output format.
@ -267,7 +267,7 @@ agent = ReactJsonAgent(tools=[PythonInterpreterTool()], system_prompt="{your_cus
``` ```
> [!WARNING] > [!WARNING]
> Please make sure to define the `<<tool_descriptions>>` string somewhere in the `template` so the agent is aware > Please make sure to define the `{{tool_descriptions}}` string somewhere in the `template` so the agent is aware
of the available tools. of the available tools.

View File

@ -1,9 +1,11 @@
from agents import load_tool, ReactCodeAgent, HfApiEngine from agents import load_tool, ReactCodeAgent, HfApiEngine
from agents.search import DuckDuckGoSearchTool
# Import tool from Hub # Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image", cache=False) image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
# Import tool from LangChain
from agents.search import DuckDuckGoSearchTool
search_tool = DuckDuckGoSearchTool() search_tool = DuckDuckGoSearchTool()
llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct") llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")

View File

@ -1,5 +1,5 @@
from agents.llm_engine import TransformersEngine from agents.llm_engine import TransformersEngine
from agents import CodeAgent, ReactJsonAgent from agents import CodeAgent, JsonAgent
import requests import requests
from datetime import datetime from datetime import datetime
@ -42,7 +42,7 @@ If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out. If the given question lacks the parameters required by the function, also point it out.
You have access to the following tools: You have access to the following tools:
<<tool_descriptions>> {{tool_descriptions}}
<<managed_agents_descriptions>> <<managed_agents_descriptions>>
@ -145,6 +145,6 @@ def process(self, text: str) -> Generator[str, None, None]:
yield response yield response
return return
agent = ReactJsonAgent(llm_engine = llm_engine, tools=[get_current_time, open_webbrowser, get_random_number_between, get_weather]) agent = JsonAgent(llm_engine = llm_engine, tools=[get_current_time, open_webbrowser, get_random_number_between, get_weather])
print("Agent initialized!") print("Agent initialized!")
agent.run("What's the weather like in London?") agent.run("What's the weather like in London?")

26
examples/oneshot.py Normal file
View File

@ -0,0 +1,26 @@
from agents import load_tool, CodeAgent, JsonAgent, HfApiEngine
from agents.prompts import ONESHOT_CODE_SYSTEM_PROMPT
from agents.search import DuckDuckGoSearchTool

# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image", cache=False)

# Search tool provided by the `agents.search` module
search_tool = DuckDuckGoSearchTool()

llm_engine = HfApiEngine("Qwen/Qwen2.5-Coder-32B-Instruct")

# Initialize the agent with both tools
agent = CodeAgent(
    tools=[image_generation_tool, search_tool],
    llm_engine=llm_engine,
    system_prompt=ONESHOT_CODE_SYSTEM_PROMPT,
    verbose=True,
)

# Run it!
result = agent.run("When was Llama 3 first released?", oneshot=True)
print(result)