Simplify step logs

This commit is contained in:
Aymeric 2024-12-11 20:16:14 +01:00
parent 1606b9a80c
commit 0a0402d090
6 changed files with 146 additions and 104 deletions

View File

@ -79,6 +79,11 @@ class AgentGenerationError(AgentError):
pass pass
@dataclass
class ToolCall:
    """Record of a single tool invocation requested during an agent step."""

    # Name of the tool being invoked (e.g. "python_interpreter").
    tool_name: str
    # Arguments handed to the tool; typed as Any because callers pass either
    # parsed action input or a raw code string.
    tool_arguments: Any
class AgentStep: class AgentStep:
pass pass
@ -86,17 +91,16 @@ class AgentStep:
@dataclass @dataclass
class ActionStep(AgentStep): class ActionStep(AgentStep):
tool_call: Dict[str, str] | None = None
start_time: float | None = None
step_end_time: float | None = None
iteration: int | None = None
final_answer: Any = None
error: AgentError | None = None
step_duration: float | None = None
llm_output: str | None = None
observation: str | None = None
agent_memory: List[Dict[str, str]] | None = None agent_memory: List[Dict[str, str]] | None = None
rationale: str | None = None tool_call: ToolCall | None = None
start_time: float | None = None
end_time: float | None = None
iteration: int | None = None
error: AgentError | None = None
duration: float | None = None
llm_output: str | None = None
observations: str | None = None
action_output: Any = None
@dataclass @dataclass
@ -222,7 +226,6 @@ class BaseAgent:
self._toolbox.add_tool(FinalAnswerTool()) self._toolbox.add_tool(FinalAnswerTool())
self.system_prompt = self.initialize_system_prompt() self.system_prompt = self.initialize_system_prompt()
print("SYS0:", self.system_prompt)
self.prompt_messages = None self.prompt_messages = None
self.logs = [] self.logs = []
self.task = None self.task = None
@ -313,15 +316,15 @@ class BaseAgent:
} }
memory.append(tool_call_message) memory.append(tool_call_message)
if step_log.error is not None or step_log.observation is not None: if step_log.error is not None or step_log.observations is not None:
if step_log.error is not None: if step_log.error is not None:
message_content = ( message_content = (
f"[OUTPUT OF STEP {i}] -> Error:\n" f"[OUTPUT OF STEP {i}] -> Error:\n"
+ str(step_log.error) + str(step_log.error)
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
) )
elif step_log.observation is not None: elif step_log.observations is not None:
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}" message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}"
tool_response_message = { tool_response_message = {
"role": MessageRole.TOOL_RESPONSE, "role": MessageRole.TOOL_RESPONSE,
"content": message_content, "content": message_content,
@ -466,8 +469,8 @@ class ReactAgent(BaseAgent):
console.print(f"[bold red]{error_msg}") console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
def step(self, log_entry: ActionStep): def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""To be implemented in children classes""" """To be implemented in children classes. Should return either None if the step is not final."""
pass pass
def run( def run(
@ -521,8 +524,8 @@ class ReactAgent(BaseAgent):
if oneshot: if oneshot:
step_start_time = time.time() step_start_time = time.time()
step_log = ActionStep(start_time=step_start_time) step_log = ActionStep(start_time=step_start_time)
step_log.step_end_time = time.time() step_log.end_time = time.time()
step_log.step_duration = step_log.step_end_time - step_start_time step_log.duration = step_log.end_time - step_start_time
# Run the agent's step # Run the agent's step
result = self.step(step_log) result = self.step(step_log)
@ -551,14 +554,14 @@ class ReactAgent(BaseAgent):
task, is_first_step=(iteration == 0), iteration=iteration task, is_first_step=(iteration == 0), iteration=iteration
) )
console.rule("[bold]New step") console.rule("[bold]New step")
self.step(step_log)
if step_log.final_answer is not None: # Run one step!
final_answer = step_log.final_answer final_answer = self.step(step_log)
except AgentError as e: except AgentError as e:
step_log.error = e step_log.error = e
finally: finally:
step_log.step_end_time = time.time() step_log.end_time = time.time()
step_log.step_duration = step_log.step_end_time - step_start_time step_log.duration = step_log.end_time - step_start_time
self.logs.append(step_log) self.logs.append(step_log)
for callback in self.step_callbacks: for callback in self.step_callbacks:
callback(step_log) callback(step_log)
@ -570,9 +573,9 @@ class ReactAgent(BaseAgent):
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
self.logs.append(final_step_log) self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task) final_answer = self.provide_final_answer(task)
final_step_log.final_answer = final_answer final_step_log.action_output = final_answer
final_step_log.step_end_time = time.time() final_step_log.end_time = time.time()
final_step_log.step_duration = step_log.step_end_time - step_start_time final_step_log.duration = step_log.end_time - step_start_time
for callback in self.step_callbacks: for callback in self.step_callbacks:
callback(final_step_log) callback(final_step_log)
yield final_step_log yield final_step_log
@ -597,15 +600,16 @@ class ReactAgent(BaseAgent):
task, is_first_step=(iteration == 0), iteration=iteration task, is_first_step=(iteration == 0), iteration=iteration
) )
console.rule("[bold]New step") console.rule("[bold]New step")
self.step(step_log)
if step_log.final_answer is not None: # Run one step!
final_answer = step_log.final_answer final_answer = self.step(step_log)
except AgentError as e: except AgentError as e:
step_log.error = e step_log.error = e
finally: finally:
step_end_time = time.time() step_end_time = time.time()
step_log.step_end_time = step_end_time step_log.end_time = step_end_time
step_log.step_duration = step_end_time - step_start_time step_log.duration = step_end_time - step_start_time
self.logs.append(step_log) self.logs.append(step_log)
for callback in self.step_callbacks: for callback in self.step_callbacks:
callback(step_log) callback(step_log)
@ -616,8 +620,8 @@ class ReactAgent(BaseAgent):
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
self.logs.append(final_step_log) self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task) final_answer = self.provide_final_answer(task)
final_step_log.final_answer = final_answer final_step_log.action_output = final_answer
final_step_log.step_duration = 0 final_step_log.duration = 0
for callback in self.step_callbacks: for callback in self.step_callbacks:
callback(final_step_log) callback(final_step_log)
@ -777,10 +781,10 @@ class JsonAgent(ReactAgent):
**kwargs, **kwargs,
) )
def step(self, log_entry: ActionStep): def step(self, log_entry: ActionStep) -> Union[None, Any]:
""" """
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method. Returns None if the step is not final.
""" """
agent_memory = self.write_inner_memory_from_logs() agent_memory = self.write_inner_memory_from_logs()
@ -823,8 +827,7 @@ class JsonAgent(ReactAgent):
except Exception as e: except Exception as e:
raise AgentParsingError(f"Could not parse the given action: {e}.") raise AgentParsingError(f"Could not parse the given action: {e}.")
log_entry.rationale = rationale log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
log_entry.tool_call = {"tool_name": tool_name, "tool_arguments": arguments}
# Execute # Execute
console.rule("Agent thoughts:") console.rule("Agent thoughts:")
@ -835,15 +838,15 @@ class JsonAgent(ReactAgent):
if isinstance(arguments, dict): if isinstance(arguments, dict):
if "answer" in arguments: if "answer" in arguments:
answer = arguments["answer"] answer = arguments["answer"]
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
else: else:
answer = arguments answer = arguments
else: else:
answer = arguments answer = arguments
log_entry.final_answer = answer if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
log_entry.action_output = answer
return answer return answer
else: else:
if arguments is None: if arguments is None:
@ -861,8 +864,8 @@ class JsonAgent(ReactAgent):
updated_information = f"Stored '{observation_name}' in memory." updated_information = f"Stored '{observation_name}' in memory."
else: else:
updated_information = str(observation).strip() updated_information = str(observation).strip()
log_entry.observation = updated_information log_entry.observations = updated_information
return log_entry return None
class CodeAgent(ReactAgent): class CodeAgent(ReactAgent):
@ -906,16 +909,15 @@ class CodeAgent(ReactAgent):
self.authorized_imports = list( self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
) )
print("SYSS:", self.system_prompt)
self.system_prompt = self.system_prompt.replace( self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}", str(self.authorized_imports) "{{authorized_imports}}", str(self.authorized_imports)
) )
self.custom_tools = {} self.custom_tools = {}
def step(self, log_entry: ActionStep): def step(self, log_entry: ActionStep) -> Union[None, Any]:
""" """
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method. Returns None if the step is not final.
""" """
agent_memory = self.write_inner_memory_from_logs() agent_memory = self.write_inner_memory_from_logs()
@ -967,11 +969,7 @@ class CodeAgent(ReactAgent):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
log_entry.rationale = rationale log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action)
log_entry.tool_call = {
"tool_name": "code interpreter",
"tool_arguments": code_action,
}
# Execute # Execute
if self.verbose: if self.verbose:
@ -988,7 +986,7 @@ class CodeAgent(ReactAgent):
} }
if self.managed_agents is not None: if self.managed_agents is not None:
static_tools = {**static_tools, **self.managed_agents} static_tools = {**static_tools, **self.managed_agents}
result = self.python_evaluator( output = self.python_evaluator(
code_action, code_action,
static_tools=static_tools, static_tools=static_tools,
custom_tools=self.custom_tools, custom_tools=self.custom_tools,
@ -998,13 +996,13 @@ class CodeAgent(ReactAgent):
console.print("Print outputs:") console.print("Print outputs:")
console.print(self.state["print_outputs"]) console.print(self.state["print_outputs"])
observation = "Print outputs:\n" + self.state["print_outputs"] observation = "Print outputs:\n" + self.state["print_outputs"]
if result is not None: if output is not None:
console.rule("Last output from code snippet:", align="left") console.rule("Last output from code snippet:", align="left")
console.print(str(result)) console.print(str(output))
observation += "Last output from code snippet:\n" + truncate_content( observation += "Last output from code snippet:\n" + truncate_content(
str(result) str(output)
) )
log_entry.observation = observation log_entry.observations = observation
except Exception as e: except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}" error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e): if "'dict' object has no attribute 'read'" in str(e):
@ -1013,9 +1011,10 @@ class CodeAgent(ReactAgent):
for line in code_action.split("\n"): for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer": if line[: len("final_answer")] == "final_answer":
console.print("Final answer:") console.print("Final answer:")
console.print(f"[bold]{result}") console.print(f"[bold]{output}")
log_entry.final_answer = result log_entry.action_output = output
return result return output
return None
class ManagedAgent: class ManagedAgent:

View File

@ -126,7 +126,9 @@ def get_remote_tools(logger, organization="huggingface-tools"):
class PythonInterpreterTool(Tool): class PythonInterpreterTool(Tool):
name = "python_interpreter" name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations." description = "This is a tool that evaluates python code. It can be used to perform calculations."
inputs = {
"code": {"type": "string", "description": "The python code to run in interpreter"}
}
output_type = "string" output_type = "string"
def __init__(self, *args, authorized_imports=None, **kwargs): def __init__(self, *args, authorized_imports=None, **kwargs):
@ -147,7 +149,7 @@ class PythonInterpreterTool(Tool):
} }
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def forward(self, code): def forward(self, code: str) -> str:
output = str( output = str(
evaluate_python_code( evaluate_python_code(
code, code,

View File

@ -22,20 +22,20 @@ import gradio as gr
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps""" """Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep): if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.rationale) yield gr.ChatMessage(role="assistant", content=step_log.llm_output)
if step_log.tool_call is not None: if step_log.tool_call is not None:
used_code = step_log.tool_call["tool_name"] == "code interpreter" used_code = step_log.tool_call.tool_name == "code interpreter"
content = step_log.tool_call["tool_arguments"] content = step_log.tool_call.tool_arguments
if used_code: if used_code:
content = f"```py\n{content}\n```" content = f"```py\n{content}\n```"
yield gr.ChatMessage( yield gr.ChatMessage(
role="assistant", role="assistant",
metadata={"title": f"🛠️ Used tool {step_log.tool_call['tool_name']}"}, metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"},
content=str(content), content=str(content),
) )
if step_log.observation is not None: if step_log.observations is not None:
yield gr.ChatMessage( yield gr.ChatMessage(
role="assistant", content=f"```\n{step_log.observation}\n```" role="assistant", content=f"```\n{step_log.observations}\n```"
) )
if step_log.error is not None: if step_log.error is not None:
yield gr.ChatMessage( yield gr.ChatMessage(

View File

@ -29,7 +29,7 @@ class Monitor:
self.total_output_token_count = 0 self.total_output_token_count = 0
def update_metrics(self, step_log): def update_metrics(self, step_log):
step_duration = step_log.step_duration step_duration = step_log.duration
self.step_durations.append(step_duration) self.step_durations.append(step_duration)
console.print(f"Step {len(self.step_durations)}:") console.print(f"Step {len(self.step_durations)}:")
console.print(f"- Time taken: {step_duration:.2f} seconds") console.print(f"- Time taken: {step_duration:.2f} seconds")

View File

@ -150,7 +150,7 @@ class Tool:
name: str name: str
description: str description: str
inputs: Dict[str, Dict[str, Union[str, type]]] inputs: Dict[str, Dict[str, Union[str, type]]]
output_type: type output_type: str
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.is_initialized = False self.is_initialized = False

View File

@ -16,9 +16,10 @@ import os
import tempfile import tempfile
import unittest import unittest
import uuid import uuid
import pytest import pytest
from pathlib import Path
from agents.agent_types import AgentText from agents.agent_types import AgentText
from agents.agents import ( from agents.agents import (
AgentMaxIterationsError, AgentMaxIterationsError,
@ -26,16 +27,18 @@ from agents.agents import (
CodeAgent, CodeAgent,
JsonAgent, JsonAgent,
Toolbox, Toolbox,
ToolCall
) )
from agents.tools import tool
from agents.default_tools import PythonInterpreterTool from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir
def get_new_path(suffix="") -> str: def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()
return os.path.join(directory, str(uuid.uuid4()) + suffix) return os.path.join(directory, str(uuid.uuid4()) + suffix)
def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str: def fake_json_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages) prompt = str(messages)
if "special_marker" not in prompt: if "special_marker" not in prompt:
@ -57,8 +60,29 @@ Action:
} }
""" """
def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str: if "special_marker" not in prompt:
return """
Thought: I should generate an image. special_marker
Action:
{
"action": "fake_image_generation_tool",
"action_input": {"prompt": "An image of a cat"}
}
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Action:
{
"action": "final_answer",
"action_input": "image.png"
}
"""
def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages) prompt = str(messages)
if "special_marker" not in prompt: if "special_marker" not in prompt:
return """ return """
@ -78,7 +102,7 @@ final_answer(7.2904)
""" """
def fake_react_code_llm_error(messages, stop_sequences=None) -> str: def fake_code_llm_error(messages, stop_sequences=None) -> str:
prompt = str(messages) prompt = str(messages)
if "special_marker" not in prompt: if "special_marker" not in prompt:
return """ return """
@ -98,7 +122,7 @@ final_answer("got an error")
""" """
def fake_react_code_functiondef(messages, stop_sequences=None) -> str: def fake_code_functiondef(messages, stop_sequences=None) -> str:
prompt = str(messages) prompt = str(messages)
if "special_marker" not in prompt: if "special_marker" not in prompt:
return """ return """
@ -146,27 +170,23 @@ print(result)
class AgentTests(unittest.TestCase): class AgentTests(unittest.TestCase):
def test_fake_code_agent(self): def test_fake_oneshot_code_agent(self):
agent = CodeAgent( agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
) )
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True)
assert isinstance(output, str) assert isinstance(output, str)
assert output == "7.2904" assert output == "7.2904"
def test_fake_react_json_agent(self): def test_fake_json_agent(self):
agent = JsonAgent( agent = JsonAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm tools=[PythonInterpreterTool()], llm_engine=fake_json_llm
) )
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str) assert isinstance(output, str)
assert output == "7.2904" assert output == "7.2904"
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[2].observation == "7.2904" assert agent.logs[2].observations == "7.2904"
assert (
agent.logs[2].rationale.strip()
== "Thought: I should multiply 2 by 3.6452. special_marker"
)
assert ( assert (
agent.logs[3].llm_output agent.logs[3].llm_output
== """ == """
@ -179,22 +199,43 @@ Action:
""" """
) )
def test_fake_react_code_agent(self): def test_json_agent_handles_image_tool_outputs(self):
from PIL import Image
@tool
def fake_image_generation_tool(prompt: str) -> Image.Image:
"""Tool that generates an image.
Args:
prompt: The prompt
"""
return Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
)
agent = JsonAgent(
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
)
output = agent.run("Make me an image.")
assert isinstance(output, Image.Image)
assert isinstance(agent.state["image.png"], Image.Image)
def test_fake_code_agent(self):
agent = CodeAgent( agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
) )
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float) assert isinstance(output, float)
assert output == 7.2904 assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_call == { assert agent.logs[3].tool_call == ToolCall(
"tool_arguments": "final_answer(7.2904)", tool_name="python_interpreter",
"tool_name": "code interpreter", tool_arguments="final_answer(7.2904)",
} )
def test_react_code_agent_code_errors_show_offending_lines(self): def test_code_agent_code_errors_show_offending_lines(self):
agent = CodeAgent( agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_error
) )
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText) assert isinstance(output, AgentText)
@ -202,9 +243,9 @@ Action:
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self): def test_setup_agent_with_empty_toolbox(self):
JsonAgent(llm_engine=fake_react_json_llm, tools=[]) JsonAgent(llm_engine=fake_json_llm, tools=[])
def test_react_fails_max_iterations(self): def test_fails_max_iterations(self):
agent = CodeAgent( agent = CodeAgent(
tools=[PythonInterpreterTool()], tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_no_return, # use this callable because it never ends llm_engine=fake_code_llm_no_return, # use this callable because it never ends
@ -216,19 +257,19 @@ Action:
def test_init_agent_with_different_toolsets(self): def test_init_agent_with_different_toolsets(self):
toolset_1 = [] toolset_1 = []
agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) agent = CodeAgent(tools=toolset_1, llm_engine=fake_code_llm)
assert ( assert (
len(agent.toolbox.tools) == 1 len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default ) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) agent = CodeAgent(tools=toolset_2, llm_engine=fake_code_llm)
assert ( assert (
len(agent.toolbox.tools) == 2 len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2) toolset_3 = Toolbox(toolset_2)
agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) agent = CodeAgent(tools=toolset_3, llm_engine=fake_code_llm)
assert ( assert (
len(agent.toolbox.tools) == 2 len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one ) # same as previous one, where toolset_3 is an instantiation of previous one
@ -236,12 +277,12 @@ Action:
# check that add_base_tools will not interfere with existing tools # check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e: with pytest.raises(KeyError) as e:
agent = JsonAgent( agent = JsonAgent(
tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True
) )
assert "already exists in the toolbox" in str(e) assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents # check that python_interpreter base tool does not get added to code agents
agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True)
assert ( assert (
len(agent.toolbox.tools) == 2 len(agent.toolbox.tools) == 2
) # added final_answer tool + search ) # added final_answer tool + search
@ -249,7 +290,7 @@ Action:
def test_function_persistence_across_steps(self): def test_function_persistence_across_steps(self):
agent = CodeAgent( agent = CodeAgent(
tools=[], tools=[],
llm_engine=fake_react_code_functiondef, llm_engine=fake_code_functiondef,
max_iterations=2, max_iterations=2,
additional_authorized_imports=["numpy"], additional_authorized_imports=["numpy"],
) )
@ -257,17 +298,17 @@ Action:
assert res[0] == 0.5 assert res[0] == 0.5
def test_init_managed_agent(self): def test_init_managed_agent(self):
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
assert managed_agent.name == "managed_agent" assert managed_agent.name == "managed_agent"
assert managed_agent.description == "Empty" assert managed_agent.description == "Empty"
def test_agent_description_gets_correctly_inserted_in_system_prompt(self): def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
manager_agent = CodeAgent( manager_agent = CodeAgent(
tools=[], tools=[],
llm_engine=fake_react_code_functiondef, llm_engine=fake_code_functiondef,
managed_agents=[managed_agent], managed_agents=[managed_agent],
) )
assert "You can also give requests to team members." not in agent.system_prompt assert "You can also give requests to team members." not in agent.system_prompt