From 0a0402d090c87fb816b445c7f4be2900fdc9811f Mon Sep 17 00:00:00 2001 From: Aymeric Date: Wed, 11 Dec 2024 20:16:14 +0100 Subject: [PATCH] Simplify step logs --- agents/agents.py | 119 ++++++++++++++++++++-------------------- agents/default_tools.py | 6 +- agents/gradio_ui.py | 12 ++-- agents/monitoring.py | 2 +- agents/tools.py | 2 +- tests/test_agents.py | 109 ++++++++++++++++++++++++------------ 6 files changed, 146 insertions(+), 104 deletions(-) diff --git a/agents/agents.py b/agents/agents.py index 9652257..f99e041 100644 --- a/agents/agents.py +++ b/agents/agents.py @@ -79,6 +79,11 @@ class AgentGenerationError(AgentError): pass +@dataclass +class ToolCall(): + tool_name: str + tool_arguments: Any + class AgentStep: pass @@ -86,17 +91,16 @@ class AgentStep: @dataclass class ActionStep(AgentStep): - tool_call: Dict[str, str] | None = None - start_time: float | None = None - step_end_time: float | None = None - iteration: int | None = None - final_answer: Any = None - error: AgentError | None = None - step_duration: float | None = None - llm_output: str | None = None - observation: str | None = None agent_memory: List[Dict[str, str]] | None = None - rationale: str | None = None + tool_call: ToolCall | None = None + start_time: float | None = None + end_time: float | None = None + iteration: int | None = None + error: AgentError | None = None + duration: float | None = None + llm_output: str | None = None + observations: str | None = None + action_output: Any = None @dataclass @@ -222,7 +226,6 @@ class BaseAgent: self._toolbox.add_tool(FinalAnswerTool()) self.system_prompt = self.initialize_system_prompt() - print("SYS0:", self.system_prompt) self.prompt_messages = None self.logs = [] self.task = None @@ -313,15 +316,15 @@ class BaseAgent: } memory.append(tool_call_message) - if step_log.error is not None or step_log.observation is not None: + if step_log.error is not None or step_log.observations is not None: if step_log.error is not None: message_content = ( 
f"[OUTPUT OF STEP {i}] -> Error:\n" + str(step_log.error) + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" ) - elif step_log.observation is not None: - message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}" + elif step_log.observations is not None: + message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}" tool_response_message = { "role": MessageRole.TOOL_RESPONSE, "content": message_content, @@ -466,8 +469,8 @@ class ReactAgent(BaseAgent): console.print(f"[bold red]{error_msg}") raise AgentExecutionError(error_msg) - def step(self, log_entry: ActionStep): - """To be implemented in children classes""" + def step(self, log_entry: ActionStep) -> Union[None, Any]: + """To be implemented in children classes. Should return None if the step is not final, or the final answer otherwise.""" pass def run( @@ -521,8 +524,8 @@ class ReactAgent(BaseAgent): if oneshot: step_start_time = time.time() step_log = ActionStep(start_time=step_start_time) - step_log.step_end_time = time.time() - step_log.step_duration = step_log.step_end_time - step_start_time + step_log.end_time = time.time() + step_log.duration = step_log.end_time - step_start_time # Run the agent's step result = self.step(step_log) @@ -551,14 +554,14 @@ class ReactAgent(BaseAgent): task, is_first_step=(iteration == 0), iteration=iteration ) console.rule("[bold]New step") - self.step(step_log) - if step_log.final_answer is not None: - final_answer = step_log.final_answer + + # Run one step! 
+ final_answer = self.step(step_log) except AgentError as e: step_log.error = e finally: - step_log.step_end_time = time.time() - step_log.step_duration = step_log.step_end_time - step_start_time + step_log.end_time = time.time() + step_log.duration = step_log.end_time - step_start_time self.logs.append(step_log) for callback in self.step_callbacks: callback(step_log) @@ -570,9 +573,9 @@ class ReactAgent(BaseAgent): final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) self.logs.append(final_step_log) final_answer = self.provide_final_answer(task) - final_step_log.final_answer = final_answer - final_step_log.step_end_time = time.time() - final_step_log.step_duration = step_log.step_end_time - step_start_time + final_step_log.action_output = final_answer + final_step_log.end_time = time.time() + final_step_log.duration = step_log.end_time - step_start_time for callback in self.step_callbacks: callback(final_step_log) yield final_step_log @@ -597,15 +600,16 @@ class ReactAgent(BaseAgent): task, is_first_step=(iteration == 0), iteration=iteration ) console.rule("[bold]New step") - self.step(step_log) - if step_log.final_answer is not None: - final_answer = step_log.final_answer + + # Run one step! 
+ final_answer = self.step(step_log) + except AgentError as e: step_log.error = e finally: step_end_time = time.time() - step_log.step_end_time = step_end_time - step_log.step_duration = step_end_time - step_start_time + step_log.end_time = step_end_time + step_log.duration = step_end_time - step_start_time self.logs.append(step_log) for callback in self.step_callbacks: callback(step_log) @@ -616,8 +620,8 @@ class ReactAgent(BaseAgent): final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) self.logs.append(final_step_log) final_answer = self.provide_final_answer(task) - final_step_log.final_answer = final_answer - final_step_log.step_duration = 0 + final_step_log.action_output = final_answer + final_step_log.duration = 0 for callback in self.step_callbacks: callback(final_step_log) @@ -777,10 +781,10 @@ class JsonAgent(ReactAgent): **kwargs, ) - def step(self, log_entry: ActionStep): + def step(self, log_entry: ActionStep) -> Union[None, Any]: """ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. - The errors are raised here, they are caught and logged in the run() method. + Returns None if the step is not final. 
""" agent_memory = self.write_inner_memory_from_logs() @@ -823,8 +827,7 @@ class JsonAgent(ReactAgent): except Exception as e: raise AgentParsingError(f"Could not parse the given action: {e}.") - log_entry.rationale = rationale - log_entry.tool_call = {"tool_name": tool_name, "tool_arguments": arguments} + log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments) # Execute console.rule("Agent thoughts:") @@ -835,15 +838,15 @@ class JsonAgent(ReactAgent): if isinstance(arguments, dict): if "answer" in arguments: answer = arguments["answer"] - if ( - isinstance(answer, str) and answer in self.state.keys() - ): # if the answer is a state variable, return the value - answer = self.state[answer] else: answer = arguments else: answer = arguments - log_entry.final_answer = answer + if ( + isinstance(answer, str) and answer in self.state.keys() + ): # if the answer is a state variable, return the value + answer = self.state[answer] + log_entry.action_output = answer return answer else: if arguments is None: @@ -861,8 +864,8 @@ class JsonAgent(ReactAgent): updated_information = f"Stored '{observation_name}' in memory." else: updated_information = str(observation).strip() - log_entry.observation = updated_information - return log_entry + log_entry.observations = updated_information + return None class CodeAgent(ReactAgent): @@ -906,16 +909,15 @@ class CodeAgent(ReactAgent): self.authorized_imports = list( set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) ) - print("SYSS:", self.system_prompt) self.system_prompt = self.system_prompt.replace( "{{authorized_imports}}", str(self.authorized_imports) ) self.custom_tools = {} - def step(self, log_entry: ActionStep): + def step(self, log_entry: ActionStep) -> Union[None, Any]: """ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. - The errors are raised here, they are caught and logged in the run() method. + Returns None if the step is not final. 
""" agent_memory = self.write_inner_memory_from_logs() @@ -967,11 +969,7 @@ class CodeAgent(ReactAgent): error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" raise AgentParsingError(error_msg) - log_entry.rationale = rationale - log_entry.tool_call = { - "tool_name": "code interpreter", - "tool_arguments": code_action, - } + log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action) # Execute if self.verbose: @@ -988,7 +986,7 @@ class CodeAgent(ReactAgent): } if self.managed_agents is not None: static_tools = {**static_tools, **self.managed_agents} - result = self.python_evaluator( + output = self.python_evaluator( code_action, static_tools=static_tools, custom_tools=self.custom_tools, @@ -998,13 +996,13 @@ class CodeAgent(ReactAgent): console.print("Print outputs:") console.print(self.state["print_outputs"]) observation = "Print outputs:\n" + self.state["print_outputs"] - if result is not None: + if output is not None: console.rule("Last output from code snippet:", align="left") - console.print(str(result)) + console.print(str(output)) observation += "Last output from code snippet:\n" + truncate_content( - str(result) + str(output) ) - log_entry.observation = observation + log_entry.observations = observation except Exception as e: error_msg = f"Code execution failed due to the following error:\n{str(e)}" if "'dict' object has no attribute 'read'" in str(e): @@ -1013,9 +1011,10 @@ class CodeAgent(ReactAgent): for line in code_action.split("\n"): if line[: len("final_answer")] == "final_answer": console.print("Final answer:") - console.print(f"[bold]{result}") - log_entry.final_answer = result - return result + console.print(f"[bold]{output}") + log_entry.action_output = output + return output + return None class ManagedAgent: diff --git a/agents/default_tools.py b/agents/default_tools.py index 379f6a2..1c06e3f 100644 --- a/agents/default_tools.py +++ b/agents/default_tools.py @@ -126,7 +126,9 @@ def 
get_remote_tools(logger, organization="huggingface-tools"): class PythonInterpreterTool(Tool): name = "python_interpreter" description = "This is a tool that evaluates python code. It can be used to perform calculations." - + inputs = { + "code": {"type": "string", "description": "The python code to run in interpreter"} + } output_type = "string" def __init__(self, *args, authorized_imports=None, **kwargs): @@ -147,7 +149,7 @@ class PythonInterpreterTool(Tool): } super().__init__(*args, **kwargs) - def forward(self, code): + def forward(self, code: str) -> str: output = str( evaluate_python_code( code, diff --git a/agents/gradio_ui.py b/agents/gradio_ui.py index 00b192c..b7ac7ab 100644 --- a/agents/gradio_ui.py +++ b/agents/gradio_ui.py @@ -22,20 +22,20 @@ import gradio as gr def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): """Extract ChatMessage objects from agent steps""" if isinstance(step_log, ActionStep): - yield gr.ChatMessage(role="assistant", content=step_log.rationale) + yield gr.ChatMessage(role="assistant", content=step_log.llm_output) if step_log.tool_call is not None: - used_code = step_log.tool_call["tool_name"] == "code interpreter" - content = step_log.tool_call["tool_arguments"] + used_code = step_log.tool_call.tool_name == "code interpreter" + content = step_log.tool_call.tool_arguments if used_code: content = f"```py\n{content}\n```" yield gr.ChatMessage( role="assistant", - metadata={"title": f"🛠️ Used tool {step_log.tool_call['tool_name']}"}, + metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"}, content=str(content), ) - if step_log.observation is not None: + if step_log.observations is not None: yield gr.ChatMessage( - role="assistant", content=f"```\n{step_log.observation}\n```" + role="assistant", content=f"```\n{step_log.observations}\n```" ) if step_log.error is not None: yield gr.ChatMessage( diff --git a/agents/monitoring.py b/agents/monitoring.py index 58059db..9c53e44 100644 --- 
a/agents/monitoring.py +++ b/agents/monitoring.py @@ -29,7 +29,7 @@ class Monitor: self.total_output_token_count = 0 def update_metrics(self, step_log): - step_duration = step_log.step_duration + step_duration = step_log.duration self.step_durations.append(step_duration) console.print(f"Step {len(self.step_durations)}:") console.print(f"- Time taken: {step_duration:.2f} seconds") diff --git a/agents/tools.py b/agents/tools.py index 2b29f9c..13b0cd7 100644 --- a/agents/tools.py +++ b/agents/tools.py @@ -150,7 +150,7 @@ class Tool: name: str description: str inputs: Dict[str, Dict[str, Union[str, type]]] - output_type: type + output_type: str def __init__(self, *args, **kwargs): self.is_initialized = False diff --git a/tests/test_agents.py b/tests/test_agents.py index eaa0f56..18bc999 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -16,9 +16,10 @@ import os import tempfile import unittest import uuid - import pytest +from pathlib import Path + from agents.agent_types import AgentText from agents.agents import ( AgentMaxIterationsError, @@ -26,16 +27,18 @@ from agents.agents import ( CodeAgent, JsonAgent, Toolbox, + ToolCall ) +from agents.tools import tool from agents.default_tools import PythonInterpreterTool - +from transformers.testing_utils import get_tests_dir def get_new_path(suffix="") -> str: directory = tempfile.mkdtemp() return os.path.join(directory, str(uuid.uuid4()) + suffix) -def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str: +def fake_json_llm(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) if "special_marker" not in prompt: @@ -57,8 +60,29 @@ Action: } """ +def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str: + prompt = str(messages) -def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str: + if "special_marker" not in prompt: + return """ +Thought: I should generate an image. 
special_marker +Action: +{ + "action": "fake_image_generation_tool", + "action_input": {"prompt": "An image of a cat"} +} +""" + else: # We're at step 2 + return """ +Thought: I can now answer the initial question +Action: +{ + "action": "final_answer", + "action_input": "image.png" +} +""" + +def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) if "special_marker" not in prompt: return """ @@ -78,7 +102,7 @@ final_answer(7.2904) """ -def fake_react_code_llm_error(messages, stop_sequences=None) -> str: +def fake_code_llm_error(messages, stop_sequences=None) -> str: prompt = str(messages) if "special_marker" not in prompt: return """ @@ -98,7 +122,7 @@ final_answer("got an error") """ -def fake_react_code_functiondef(messages, stop_sequences=None) -> str: +def fake_code_functiondef(messages, stop_sequences=None) -> str: prompt = str(messages) if "special_marker" not in prompt: return """ @@ -146,27 +170,23 @@ print(result) class AgentTests(unittest.TestCase): - def test_fake_code_agent(self): + def test_fake_oneshot_code_agent(self): agent = CodeAgent( tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot ) - output = agent.run("What is 2 multiplied by 3.6452?") + output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True) assert isinstance(output, str) assert output == "7.2904" - def test_fake_react_json_agent(self): + def test_fake_json_agent(self): agent = JsonAgent( - tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm + tools=[PythonInterpreterTool()], llm_engine=fake_json_llm ) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, str) assert output == "7.2904" assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" - assert agent.logs[2].observation == "7.2904" - assert ( - agent.logs[2].rationale.strip() - == "Thought: I should multiply 2 by 3.6452. 
special_marker" - ) + assert agent.logs[2].observations == "7.2904" assert ( agent.logs[3].llm_output == """ @@ -179,22 +199,43 @@ Action: """ ) - def test_fake_react_code_agent(self): + def test_json_agent_handles_image_tool_outputs(self): + from PIL import Image + + @tool + def fake_image_generation_tool(prompt: str) -> Image.Image: + """Tool that generates an image. + + Args: + prompt: The prompt + """ + return Image.open( + Path(get_tests_dir("fixtures")) / "000000039769.png" + ) + + agent = JsonAgent( + tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image + ) + output = agent.run("Make me an image.") + assert isinstance(output, Image.Image) + assert isinstance(agent.state["image.png"], Image.Image) + + def test_fake_code_agent(self): agent = CodeAgent( - tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm + tools=[PythonInterpreterTool()], llm_engine=fake_code_llm ) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, float) assert output == 7.2904 assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" 
- assert agent.logs[3].tool_call == { - "tool_arguments": "final_answer(7.2904)", - "tool_name": "code interpreter", - } + assert agent.logs[3].tool_call == ToolCall( + tool_name="python_interpreter", + tool_arguments="final_answer(7.2904)", + ) - def test_react_code_agent_code_errors_show_offending_lines(self): + def test_code_agent_code_errors_show_offending_lines(self): agent = CodeAgent( - tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error + tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_error ) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, AgentText) @@ -202,9 +243,9 @@ Action: assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) def test_setup_agent_with_empty_toolbox(self): - JsonAgent(llm_engine=fake_react_json_llm, tools=[]) + JsonAgent(llm_engine=fake_json_llm, tools=[]) - def test_react_fails_max_iterations(self): + def test_fails_max_iterations(self): agent = CodeAgent( tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_no_return, # use this callable because it never ends @@ -216,19 +257,19 @@ Action: def test_init_agent_with_different_toolsets(self): toolset_1 = [] - agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) + agent = CodeAgent(tools=toolset_1, llm_engine=fake_code_llm) assert ( len(agent.toolbox.tools) == 1 ) # when no tools are provided, only the final_answer tool is added by default toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] - agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) + agent = CodeAgent(tools=toolset_2, llm_engine=fake_code_llm) assert ( len(agent.toolbox.tools) == 2 ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer toolset_3 = Toolbox(toolset_2) - agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) + agent = CodeAgent(tools=toolset_3, llm_engine=fake_code_llm) assert ( len(agent.toolbox.tools) == 2 ) # same as previous 
one, where toolset_3 is an instantiation of previous one @@ -236,12 +277,12 @@ Action: # check that add_base_tools will not interfere with existing tools with pytest.raises(KeyError) as e: agent = JsonAgent( - tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True + tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True ) assert "already exists in the toolbox" in str(e) # check that python_interpreter base tool does not get added to code agents - agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) + agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True) assert ( len(agent.toolbox.tools) == 2 ) # added final_answer tool + search @@ -249,7 +290,7 @@ Action: def test_function_persistence_across_steps(self): agent = CodeAgent( tools=[], - llm_engine=fake_react_code_functiondef, + llm_engine=fake_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"], ) @@ -257,17 +298,17 @@ Action: assert res[0] == 0.5 def test_init_managed_agent(self): - agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) + agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef) managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") assert managed_agent.name == "managed_agent" assert managed_agent.description == "Empty" def test_agent_description_gets_correctly_inserted_in_system_prompt(self): - agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) + agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef) managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") manager_agent = CodeAgent( tools=[], - llm_engine=fake_react_code_functiondef, + llm_engine=fake_code_functiondef, managed_agents=[managed_agent], ) assert "You can also give requests to team members." not in agent.system_prompt