Simplify step logs

This commit is contained in:
Aymeric 2024-12-11 20:16:14 +01:00
parent 1606b9a80c
commit 0a0402d090
6 changed files with 146 additions and 104 deletions

View File

@ -79,6 +79,11 @@ class AgentGenerationError(AgentError):
pass
@dataclass
class ToolCall():
tool_name: str
tool_arguments: Any
class AgentStep:
pass
@ -86,17 +91,16 @@ class AgentStep:
@dataclass
class ActionStep(AgentStep):
tool_call: Dict[str, str] | None = None
start_time: float | None = None
step_end_time: float | None = None
iteration: int | None = None
final_answer: Any = None
error: AgentError | None = None
step_duration: float | None = None
llm_output: str | None = None
observation: str | None = None
agent_memory: List[Dict[str, str]] | None = None
rationale: str | None = None
tool_call: ToolCall | None = None
start_time: float | None = None
end_time: float | None = None
iteration: int | None = None
error: AgentError | None = None
duration: float | None = None
llm_output: str | None = None
observations: str | None = None
action_output: Any = None
@dataclass
@ -222,7 +226,6 @@ class BaseAgent:
self._toolbox.add_tool(FinalAnswerTool())
self.system_prompt = self.initialize_system_prompt()
print("SYS0:", self.system_prompt)
self.prompt_messages = None
self.logs = []
self.task = None
@ -313,15 +316,15 @@ class BaseAgent:
}
memory.append(tool_call_message)
if step_log.error is not None or step_log.observation is not None:
if step_log.error is not None or step_log.observations is not None:
if step_log.error is not None:
message_content = (
f"[OUTPUT OF STEP {i}] -> Error:\n"
+ str(step_log.error)
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
elif step_log.observation is not None:
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}"
elif step_log.observations is not None:
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}"
tool_response_message = {
"role": MessageRole.TOOL_RESPONSE,
"content": message_content,
@ -466,8 +469,8 @@ class ReactAgent(BaseAgent):
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
def step(self, log_entry: ActionStep):
"""To be implemented in children classes"""
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""To be implemented in children classes. Should return either None if the step is not final."""
pass
def run(
@ -521,8 +524,8 @@ class ReactAgent(BaseAgent):
if oneshot:
step_start_time = time.time()
step_log = ActionStep(start_time=step_start_time)
step_log.step_end_time = time.time()
step_log.step_duration = step_log.step_end_time - step_start_time
step_log.end_time = time.time()
step_log.duration = step_log.end_time - step_start_time
# Run the agent's step
result = self.step(step_log)
@ -551,14 +554,14 @@ class ReactAgent(BaseAgent):
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step")
self.step(step_log)
if step_log.final_answer is not None:
final_answer = step_log.final_answer
# Run one step!
final_answer = self.step(step_log)
except AgentError as e:
step_log.error = e
finally:
step_log.step_end_time = time.time()
step_log.step_duration = step_log.step_end_time - step_start_time
step_log.end_time = time.time()
step_log.duration = step_log.end_time - step_start_time
self.logs.append(step_log)
for callback in self.step_callbacks:
callback(step_log)
@ -570,9 +573,9 @@ class ReactAgent(BaseAgent):
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
final_step_log.final_answer = final_answer
final_step_log.step_end_time = time.time()
final_step_log.step_duration = step_log.step_end_time - step_start_time
final_step_log.action_output = final_answer
final_step_log.end_time = time.time()
final_step_log.duration = step_log.end_time - step_start_time
for callback in self.step_callbacks:
callback(final_step_log)
yield final_step_log
@ -597,15 +600,16 @@ class ReactAgent(BaseAgent):
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step")
self.step(step_log)
if step_log.final_answer is not None:
final_answer = step_log.final_answer
# Run one step!
final_answer = self.step(step_log)
except AgentError as e:
step_log.error = e
finally:
step_end_time = time.time()
step_log.step_end_time = step_end_time
step_log.step_duration = step_end_time - step_start_time
step_log.end_time = step_end_time
step_log.duration = step_end_time - step_start_time
self.logs.append(step_log)
for callback in self.step_callbacks:
callback(step_log)
@ -616,8 +620,8 @@ class ReactAgent(BaseAgent):
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
final_step_log.final_answer = final_answer
final_step_log.step_duration = 0
final_step_log.action_output = final_answer
final_step_log.duration = 0
for callback in self.step_callbacks:
callback(final_step_log)
@ -777,10 +781,10 @@ class JsonAgent(ReactAgent):
**kwargs,
)
def step(self, log_entry: ActionStep):
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
Returns None if the step is not final.
"""
agent_memory = self.write_inner_memory_from_logs()
@ -823,8 +827,7 @@ class JsonAgent(ReactAgent):
except Exception as e:
raise AgentParsingError(f"Could not parse the given action: {e}.")
log_entry.rationale = rationale
log_entry.tool_call = {"tool_name": tool_name, "tool_arguments": arguments}
log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
# Execute
console.rule("Agent thoughts:")
@ -835,15 +838,15 @@ class JsonAgent(ReactAgent):
if isinstance(arguments, dict):
if "answer" in arguments:
answer = arguments["answer"]
else:
answer = arguments
else:
answer = arguments
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
else:
answer = arguments
else:
answer = arguments
log_entry.final_answer = answer
log_entry.action_output = answer
return answer
else:
if arguments is None:
@ -861,8 +864,8 @@ class JsonAgent(ReactAgent):
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
log_entry.observation = updated_information
return log_entry
log_entry.observations = updated_information
return None
class CodeAgent(ReactAgent):
@ -906,16 +909,15 @@ class CodeAgent(ReactAgent):
self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
)
print("SYSS:", self.system_prompt)
self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}", str(self.authorized_imports)
)
self.custom_tools = {}
def step(self, log_entry: ActionStep):
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
Returns None if the step is not final.
"""
agent_memory = self.write_inner_memory_from_logs()
@ -967,11 +969,7 @@ class CodeAgent(ReactAgent):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg)
log_entry.rationale = rationale
log_entry.tool_call = {
"tool_name": "code interpreter",
"tool_arguments": code_action,
}
log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action)
# Execute
if self.verbose:
@ -988,7 +986,7 @@ class CodeAgent(ReactAgent):
}
if self.managed_agents is not None:
static_tools = {**static_tools, **self.managed_agents}
result = self.python_evaluator(
output = self.python_evaluator(
code_action,
static_tools=static_tools,
custom_tools=self.custom_tools,
@ -998,13 +996,13 @@ class CodeAgent(ReactAgent):
console.print("Print outputs:")
console.print(self.state["print_outputs"])
observation = "Print outputs:\n" + self.state["print_outputs"]
if result is not None:
if output is not None:
console.rule("Last output from code snippet:", align="left")
console.print(str(result))
console.print(str(output))
observation += "Last output from code snippet:\n" + truncate_content(
str(result)
str(output)
)
log_entry.observation = observation
log_entry.observations = observation
except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
@ -1013,9 +1011,10 @@ class CodeAgent(ReactAgent):
for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer":
console.print("Final answer:")
console.print(f"[bold]{result}")
log_entry.final_answer = result
return result
console.print(f"[bold]{output}")
log_entry.action_output = output
return output
return None
class ManagedAgent:

View File

@ -126,7 +126,9 @@ def get_remote_tools(logger, organization="huggingface-tools"):
class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."
inputs = {
"code": {"type": "string", "description": "The python code to run in interpreter"}
}
output_type = "string"
def __init__(self, *args, authorized_imports=None, **kwargs):
@ -147,7 +149,7 @@ class PythonInterpreterTool(Tool):
}
super().__init__(*args, **kwargs)
def forward(self, code):
def forward(self, code: str) -> str:
output = str(
evaluate_python_code(
code,

View File

@ -22,20 +22,20 @@ import gradio as gr
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.rationale)
yield gr.ChatMessage(role="assistant", content=step_log.llm_output)
if step_log.tool_call is not None:
used_code = step_log.tool_call["tool_name"] == "code interpreter"
content = step_log.tool_call["tool_arguments"]
used_code = step_log.tool_call.tool_name == "code interpreter"
content = step_log.tool_call.tool_arguments
if used_code:
content = f"```py\n{content}\n```"
yield gr.ChatMessage(
role="assistant",
metadata={"title": f"🛠️ Used tool {step_log.tool_call['tool_name']}"},
metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"},
content=str(content),
)
if step_log.observation is not None:
if step_log.observations is not None:
yield gr.ChatMessage(
role="assistant", content=f"```\n{step_log.observation}\n```"
role="assistant", content=f"```\n{step_log.observations}\n```"
)
if step_log.error is not None:
yield gr.ChatMessage(

View File

@ -29,7 +29,7 @@ class Monitor:
self.total_output_token_count = 0
def update_metrics(self, step_log):
step_duration = step_log.step_duration
step_duration = step_log.duration
self.step_durations.append(step_duration)
console.print(f"Step {len(self.step_durations)}:")
console.print(f"- Time taken: {step_duration:.2f} seconds")

View File

@ -150,7 +150,7 @@ class Tool:
name: str
description: str
inputs: Dict[str, Dict[str, Union[str, type]]]
output_type: type
output_type: str
def __init__(self, *args, **kwargs):
self.is_initialized = False

View File

@ -16,9 +16,10 @@ import os
import tempfile
import unittest
import uuid
import pytest
from pathlib import Path
from agents.agent_types import AgentText
from agents.agents import (
AgentMaxIterationsError,
@ -26,16 +27,18 @@ from agents.agents import (
CodeAgent,
JsonAgent,
Toolbox,
ToolCall
)
from agents.tools import tool
from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()
return os.path.join(directory, str(uuid.uuid4()) + suffix)
def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str:
def fake_json_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
@ -57,8 +60,29 @@ Action:
}
"""
def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str:
if "special_marker" not in prompt:
return """
Thought: I should generate an image. special_marker
Action:
{
"action": "fake_image_generation_tool",
"action_input": {"prompt": "An image of a cat"}
}
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Action:
{
"action": "final_answer",
"action_input": "image.png"
}
"""
def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
@ -78,7 +102,7 @@ final_answer(7.2904)
"""
def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
def fake_code_llm_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
@ -98,7 +122,7 @@ final_answer("got an error")
"""
def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
def fake_code_functiondef(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
@ -146,27 +170,23 @@ print(result)
class AgentTests(unittest.TestCase):
def test_fake_code_agent(self):
def test_fake_oneshot_code_agent(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
)
output = agent.run("What is 2 multiplied by 3.6452?")
output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True)
assert isinstance(output, str)
assert output == "7.2904"
def test_fake_react_json_agent(self):
def test_fake_json_agent(self):
agent = JsonAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm
tools=[PythonInterpreterTool()], llm_engine=fake_json_llm
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[2].observation == "7.2904"
assert (
agent.logs[2].rationale.strip()
== "Thought: I should multiply 2 by 3.6452. special_marker"
)
assert agent.logs[2].observations == "7.2904"
assert (
agent.logs[3].llm_output
== """
@ -179,22 +199,43 @@ Action:
"""
)
def test_fake_react_code_agent(self):
def test_json_agent_handles_image_tool_outputs(self):
from PIL import Image
@tool
def fake_image_generation_tool(prompt: str) -> Image.Image:
"""Tool that generates an image.
Args:
prompt: The prompt
"""
return Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
)
agent = JsonAgent(
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
)
output = agent.run("Make me an image.")
assert isinstance(output, Image.Image)
assert isinstance(agent.state["image.png"], Image.Image)
def test_fake_code_agent(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float)
assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_call == {
"tool_arguments": "final_answer(7.2904)",
"tool_name": "code interpreter",
}
assert agent.logs[3].tool_call == ToolCall(
tool_name="python_interpreter",
tool_arguments="final_answer(7.2904)",
)
def test_react_code_agent_code_errors_show_offending_lines(self):
def test_code_agent_code_errors_show_offending_lines(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_error
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
@ -202,9 +243,9 @@ Action:
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self):
JsonAgent(llm_engine=fake_react_json_llm, tools=[])
JsonAgent(llm_engine=fake_json_llm, tools=[])
def test_react_fails_max_iterations(self):
def test_fails_max_iterations(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
@ -216,19 +257,19 @@ Action:
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
agent = CodeAgent(tools=toolset_1, llm_engine=fake_code_llm)
assert (
len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
agent = CodeAgent(tools=toolset_2, llm_engine=fake_code_llm)
assert (
len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2)
agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
agent = CodeAgent(tools=toolset_3, llm_engine=fake_code_llm)
assert (
len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one
@ -236,12 +277,12 @@ Action:
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
agent = JsonAgent(
tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True
tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True
)
assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents
agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True)
assert (
len(agent.toolbox.tools) == 2
) # added final_answer tool + search
@ -249,7 +290,7 @@ Action:
def test_function_persistence_across_steps(self):
agent = CodeAgent(
tools=[],
llm_engine=fake_react_code_functiondef,
llm_engine=fake_code_functiondef,
max_iterations=2,
additional_authorized_imports=["numpy"],
)
@ -257,17 +298,17 @@ Action:
assert res[0] == 0.5
def test_init_managed_agent(self):
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
assert managed_agent.name == "managed_agent"
assert managed_agent.description == "Empty"
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
manager_agent = CodeAgent(
tools=[],
llm_engine=fake_react_code_functiondef,
llm_engine=fake_code_functiondef,
managed_agents=[managed_agent],
)
assert "You can also give requests to team members." not in agent.system_prompt