Simplify step logs
This commit is contained in:
parent
1606b9a80c
commit
0a0402d090
119
agents/agents.py
119
agents/agents.py
|
@ -79,6 +79,11 @@ class AgentGenerationError(AgentError):
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCall():
|
||||||
|
tool_name: str
|
||||||
|
tool_arguments: Any
|
||||||
|
|
||||||
|
|
||||||
class AgentStep:
|
class AgentStep:
|
||||||
pass
|
pass
|
||||||
|
@ -86,17 +91,16 @@ class AgentStep:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActionStep(AgentStep):
|
class ActionStep(AgentStep):
|
||||||
tool_call: Dict[str, str] | None = None
|
|
||||||
start_time: float | None = None
|
|
||||||
step_end_time: float | None = None
|
|
||||||
iteration: int | None = None
|
|
||||||
final_answer: Any = None
|
|
||||||
error: AgentError | None = None
|
|
||||||
step_duration: float | None = None
|
|
||||||
llm_output: str | None = None
|
|
||||||
observation: str | None = None
|
|
||||||
agent_memory: List[Dict[str, str]] | None = None
|
agent_memory: List[Dict[str, str]] | None = None
|
||||||
rationale: str | None = None
|
tool_call: ToolCall | None = None
|
||||||
|
start_time: float | None = None
|
||||||
|
end_time: float | None = None
|
||||||
|
iteration: int | None = None
|
||||||
|
error: AgentError | None = None
|
||||||
|
duration: float | None = None
|
||||||
|
llm_output: str | None = None
|
||||||
|
observations: str | None = None
|
||||||
|
action_output: Any = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -222,7 +226,6 @@ class BaseAgent:
|
||||||
self._toolbox.add_tool(FinalAnswerTool())
|
self._toolbox.add_tool(FinalAnswerTool())
|
||||||
|
|
||||||
self.system_prompt = self.initialize_system_prompt()
|
self.system_prompt = self.initialize_system_prompt()
|
||||||
print("SYS0:", self.system_prompt)
|
|
||||||
self.prompt_messages = None
|
self.prompt_messages = None
|
||||||
self.logs = []
|
self.logs = []
|
||||||
self.task = None
|
self.task = None
|
||||||
|
@ -313,15 +316,15 @@ class BaseAgent:
|
||||||
}
|
}
|
||||||
memory.append(tool_call_message)
|
memory.append(tool_call_message)
|
||||||
|
|
||||||
if step_log.error is not None or step_log.observation is not None:
|
if step_log.error is not None or step_log.observations is not None:
|
||||||
if step_log.error is not None:
|
if step_log.error is not None:
|
||||||
message_content = (
|
message_content = (
|
||||||
f"[OUTPUT OF STEP {i}] -> Error:\n"
|
f"[OUTPUT OF STEP {i}] -> Error:\n"
|
||||||
+ str(step_log.error)
|
+ str(step_log.error)
|
||||||
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||||
)
|
)
|
||||||
elif step_log.observation is not None:
|
elif step_log.observations is not None:
|
||||||
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}"
|
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}"
|
||||||
tool_response_message = {
|
tool_response_message = {
|
||||||
"role": MessageRole.TOOL_RESPONSE,
|
"role": MessageRole.TOOL_RESPONSE,
|
||||||
"content": message_content,
|
"content": message_content,
|
||||||
|
@ -466,8 +469,8 @@ class ReactAgent(BaseAgent):
|
||||||
console.print(f"[bold red]{error_msg}")
|
console.print(f"[bold red]{error_msg}")
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
def step(self, log_entry: ActionStep):
|
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||||
"""To be implemented in children classes"""
|
"""To be implemented in children classes. Should return either None if the step is not final."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
|
@ -521,8 +524,8 @@ class ReactAgent(BaseAgent):
|
||||||
if oneshot:
|
if oneshot:
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(start_time=step_start_time)
|
step_log = ActionStep(start_time=step_start_time)
|
||||||
step_log.step_end_time = time.time()
|
step_log.end_time = time.time()
|
||||||
step_log.step_duration = step_log.step_end_time - step_start_time
|
step_log.duration = step_log.end_time - step_start_time
|
||||||
|
|
||||||
# Run the agent's step
|
# Run the agent's step
|
||||||
result = self.step(step_log)
|
result = self.step(step_log)
|
||||||
|
@ -551,14 +554,14 @@ class ReactAgent(BaseAgent):
|
||||||
task, is_first_step=(iteration == 0), iteration=iteration
|
task, is_first_step=(iteration == 0), iteration=iteration
|
||||||
)
|
)
|
||||||
console.rule("[bold]New step")
|
console.rule("[bold]New step")
|
||||||
self.step(step_log)
|
|
||||||
if step_log.final_answer is not None:
|
# Run one step!
|
||||||
final_answer = step_log.final_answer
|
final_answer = self.step(step_log)
|
||||||
except AgentError as e:
|
except AgentError as e:
|
||||||
step_log.error = e
|
step_log.error = e
|
||||||
finally:
|
finally:
|
||||||
step_log.step_end_time = time.time()
|
step_log.end_time = time.time()
|
||||||
step_log.step_duration = step_log.step_end_time - step_start_time
|
step_log.duration = step_log.end_time - step_start_time
|
||||||
self.logs.append(step_log)
|
self.logs.append(step_log)
|
||||||
for callback in self.step_callbacks:
|
for callback in self.step_callbacks:
|
||||||
callback(step_log)
|
callback(step_log)
|
||||||
|
@ -570,9 +573,9 @@ class ReactAgent(BaseAgent):
|
||||||
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
final_answer = self.provide_final_answer(task)
|
final_answer = self.provide_final_answer(task)
|
||||||
final_step_log.final_answer = final_answer
|
final_step_log.action_output = final_answer
|
||||||
final_step_log.step_end_time = time.time()
|
final_step_log.end_time = time.time()
|
||||||
final_step_log.step_duration = step_log.step_end_time - step_start_time
|
final_step_log.duration = step_log.end_time - step_start_time
|
||||||
for callback in self.step_callbacks:
|
for callback in self.step_callbacks:
|
||||||
callback(final_step_log)
|
callback(final_step_log)
|
||||||
yield final_step_log
|
yield final_step_log
|
||||||
|
@ -597,15 +600,16 @@ class ReactAgent(BaseAgent):
|
||||||
task, is_first_step=(iteration == 0), iteration=iteration
|
task, is_first_step=(iteration == 0), iteration=iteration
|
||||||
)
|
)
|
||||||
console.rule("[bold]New step")
|
console.rule("[bold]New step")
|
||||||
self.step(step_log)
|
|
||||||
if step_log.final_answer is not None:
|
# Run one step!
|
||||||
final_answer = step_log.final_answer
|
final_answer = self.step(step_log)
|
||||||
|
|
||||||
except AgentError as e:
|
except AgentError as e:
|
||||||
step_log.error = e
|
step_log.error = e
|
||||||
finally:
|
finally:
|
||||||
step_end_time = time.time()
|
step_end_time = time.time()
|
||||||
step_log.step_end_time = step_end_time
|
step_log.end_time = step_end_time
|
||||||
step_log.step_duration = step_end_time - step_start_time
|
step_log.duration = step_end_time - step_start_time
|
||||||
self.logs.append(step_log)
|
self.logs.append(step_log)
|
||||||
for callback in self.step_callbacks:
|
for callback in self.step_callbacks:
|
||||||
callback(step_log)
|
callback(step_log)
|
||||||
|
@ -616,8 +620,8 @@ class ReactAgent(BaseAgent):
|
||||||
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
final_answer = self.provide_final_answer(task)
|
final_answer = self.provide_final_answer(task)
|
||||||
final_step_log.final_answer = final_answer
|
final_step_log.action_output = final_answer
|
||||||
final_step_log.step_duration = 0
|
final_step_log.duration = 0
|
||||||
for callback in self.step_callbacks:
|
for callback in self.step_callbacks:
|
||||||
callback(final_step_log)
|
callback(final_step_log)
|
||||||
|
|
||||||
|
@ -777,10 +781,10 @@ class JsonAgent(ReactAgent):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def step(self, log_entry: ActionStep):
|
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||||
"""
|
"""
|
||||||
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
||||||
The errors are raised here, they are caught and logged in the run() method.
|
Returns None if the step is not final.
|
||||||
"""
|
"""
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
|
|
||||||
|
@ -823,8 +827,7 @@ class JsonAgent(ReactAgent):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentParsingError(f"Could not parse the given action: {e}.")
|
raise AgentParsingError(f"Could not parse the given action: {e}.")
|
||||||
|
|
||||||
log_entry.rationale = rationale
|
log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
|
||||||
log_entry.tool_call = {"tool_name": tool_name, "tool_arguments": arguments}
|
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
console.rule("Agent thoughts:")
|
console.rule("Agent thoughts:")
|
||||||
|
@ -835,15 +838,15 @@ class JsonAgent(ReactAgent):
|
||||||
if isinstance(arguments, dict):
|
if isinstance(arguments, dict):
|
||||||
if "answer" in arguments:
|
if "answer" in arguments:
|
||||||
answer = arguments["answer"]
|
answer = arguments["answer"]
|
||||||
if (
|
|
||||||
isinstance(answer, str) and answer in self.state.keys()
|
|
||||||
): # if the answer is a state variable, return the value
|
|
||||||
answer = self.state[answer]
|
|
||||||
else:
|
else:
|
||||||
answer = arguments
|
answer = arguments
|
||||||
else:
|
else:
|
||||||
answer = arguments
|
answer = arguments
|
||||||
log_entry.final_answer = answer
|
if (
|
||||||
|
isinstance(answer, str) and answer in self.state.keys()
|
||||||
|
): # if the answer is a state variable, return the value
|
||||||
|
answer = self.state[answer]
|
||||||
|
log_entry.action_output = answer
|
||||||
return answer
|
return answer
|
||||||
else:
|
else:
|
||||||
if arguments is None:
|
if arguments is None:
|
||||||
|
@ -861,8 +864,8 @@ class JsonAgent(ReactAgent):
|
||||||
updated_information = f"Stored '{observation_name}' in memory."
|
updated_information = f"Stored '{observation_name}' in memory."
|
||||||
else:
|
else:
|
||||||
updated_information = str(observation).strip()
|
updated_information = str(observation).strip()
|
||||||
log_entry.observation = updated_information
|
log_entry.observations = updated_information
|
||||||
return log_entry
|
return None
|
||||||
|
|
||||||
|
|
||||||
class CodeAgent(ReactAgent):
|
class CodeAgent(ReactAgent):
|
||||||
|
@ -906,16 +909,15 @@ class CodeAgent(ReactAgent):
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(
|
||||||
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
|
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
|
||||||
)
|
)
|
||||||
print("SYSS:", self.system_prompt)
|
|
||||||
self.system_prompt = self.system_prompt.replace(
|
self.system_prompt = self.system_prompt.replace(
|
||||||
"{{authorized_imports}}", str(self.authorized_imports)
|
"{{authorized_imports}}", str(self.authorized_imports)
|
||||||
)
|
)
|
||||||
self.custom_tools = {}
|
self.custom_tools = {}
|
||||||
|
|
||||||
def step(self, log_entry: ActionStep):
|
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||||
"""
|
"""
|
||||||
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
||||||
The errors are raised here, they are caught and logged in the run() method.
|
Returns None if the step is not final.
|
||||||
"""
|
"""
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
|
|
||||||
|
@ -967,11 +969,7 @@ class CodeAgent(ReactAgent):
|
||||||
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
||||||
raise AgentParsingError(error_msg)
|
raise AgentParsingError(error_msg)
|
||||||
|
|
||||||
log_entry.rationale = rationale
|
log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action)
|
||||||
log_entry.tool_call = {
|
|
||||||
"tool_name": "code interpreter",
|
|
||||||
"tool_arguments": code_action,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -988,7 +986,7 @@ class CodeAgent(ReactAgent):
|
||||||
}
|
}
|
||||||
if self.managed_agents is not None:
|
if self.managed_agents is not None:
|
||||||
static_tools = {**static_tools, **self.managed_agents}
|
static_tools = {**static_tools, **self.managed_agents}
|
||||||
result = self.python_evaluator(
|
output = self.python_evaluator(
|
||||||
code_action,
|
code_action,
|
||||||
static_tools=static_tools,
|
static_tools=static_tools,
|
||||||
custom_tools=self.custom_tools,
|
custom_tools=self.custom_tools,
|
||||||
|
@ -998,13 +996,13 @@ class CodeAgent(ReactAgent):
|
||||||
console.print("Print outputs:")
|
console.print("Print outputs:")
|
||||||
console.print(self.state["print_outputs"])
|
console.print(self.state["print_outputs"])
|
||||||
observation = "Print outputs:\n" + self.state["print_outputs"]
|
observation = "Print outputs:\n" + self.state["print_outputs"]
|
||||||
if result is not None:
|
if output is not None:
|
||||||
console.rule("Last output from code snippet:", align="left")
|
console.rule("Last output from code snippet:", align="left")
|
||||||
console.print(str(result))
|
console.print(str(output))
|
||||||
observation += "Last output from code snippet:\n" + truncate_content(
|
observation += "Last output from code snippet:\n" + truncate_content(
|
||||||
str(result)
|
str(output)
|
||||||
)
|
)
|
||||||
log_entry.observation = observation
|
log_entry.observations = observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
||||||
if "'dict' object has no attribute 'read'" in str(e):
|
if "'dict' object has no attribute 'read'" in str(e):
|
||||||
|
@ -1013,9 +1011,10 @@ class CodeAgent(ReactAgent):
|
||||||
for line in code_action.split("\n"):
|
for line in code_action.split("\n"):
|
||||||
if line[: len("final_answer")] == "final_answer":
|
if line[: len("final_answer")] == "final_answer":
|
||||||
console.print("Final answer:")
|
console.print("Final answer:")
|
||||||
console.print(f"[bold]{result}")
|
console.print(f"[bold]{output}")
|
||||||
log_entry.final_answer = result
|
log_entry.action_output = output
|
||||||
return result
|
return output
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ManagedAgent:
|
class ManagedAgent:
|
||||||
|
|
|
@ -126,7 +126,9 @@ def get_remote_tools(logger, organization="huggingface-tools"):
|
||||||
class PythonInterpreterTool(Tool):
|
class PythonInterpreterTool(Tool):
|
||||||
name = "python_interpreter"
|
name = "python_interpreter"
|
||||||
description = "This is a tool that evaluates python code. It can be used to perform calculations."
|
description = "This is a tool that evaluates python code. It can be used to perform calculations."
|
||||||
|
inputs = {
|
||||||
|
"code": {"type": "string", "description": "The python code to run in interpreter"}
|
||||||
|
}
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def __init__(self, *args, authorized_imports=None, **kwargs):
|
def __init__(self, *args, authorized_imports=None, **kwargs):
|
||||||
|
@ -147,7 +149,7 @@ class PythonInterpreterTool(Tool):
|
||||||
}
|
}
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def forward(self, code):
|
def forward(self, code: str) -> str:
|
||||||
output = str(
|
output = str(
|
||||||
evaluate_python_code(
|
evaluate_python_code(
|
||||||
code,
|
code,
|
||||||
|
|
|
@ -22,20 +22,20 @@ import gradio as gr
|
||||||
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||||
"""Extract ChatMessage objects from agent steps"""
|
"""Extract ChatMessage objects from agent steps"""
|
||||||
if isinstance(step_log, ActionStep):
|
if isinstance(step_log, ActionStep):
|
||||||
yield gr.ChatMessage(role="assistant", content=step_log.rationale)
|
yield gr.ChatMessage(role="assistant", content=step_log.llm_output)
|
||||||
if step_log.tool_call is not None:
|
if step_log.tool_call is not None:
|
||||||
used_code = step_log.tool_call["tool_name"] == "code interpreter"
|
used_code = step_log.tool_call.tool_name == "code interpreter"
|
||||||
content = step_log.tool_call["tool_arguments"]
|
content = step_log.tool_call.tool_arguments
|
||||||
if used_code:
|
if used_code:
|
||||||
content = f"```py\n{content}\n```"
|
content = f"```py\n{content}\n```"
|
||||||
yield gr.ChatMessage(
|
yield gr.ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
metadata={"title": f"🛠️ Used tool {step_log.tool_call['tool_name']}"},
|
metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"},
|
||||||
content=str(content),
|
content=str(content),
|
||||||
)
|
)
|
||||||
if step_log.observation is not None:
|
if step_log.observations is not None:
|
||||||
yield gr.ChatMessage(
|
yield gr.ChatMessage(
|
||||||
role="assistant", content=f"```\n{step_log.observation}\n```"
|
role="assistant", content=f"```\n{step_log.observations}\n```"
|
||||||
)
|
)
|
||||||
if step_log.error is not None:
|
if step_log.error is not None:
|
||||||
yield gr.ChatMessage(
|
yield gr.ChatMessage(
|
||||||
|
|
|
@ -29,7 +29,7 @@ class Monitor:
|
||||||
self.total_output_token_count = 0
|
self.total_output_token_count = 0
|
||||||
|
|
||||||
def update_metrics(self, step_log):
|
def update_metrics(self, step_log):
|
||||||
step_duration = step_log.step_duration
|
step_duration = step_log.duration
|
||||||
self.step_durations.append(step_duration)
|
self.step_durations.append(step_duration)
|
||||||
console.print(f"Step {len(self.step_durations)}:")
|
console.print(f"Step {len(self.step_durations)}:")
|
||||||
console.print(f"- Time taken: {step_duration:.2f} seconds")
|
console.print(f"- Time taken: {step_duration:.2f} seconds")
|
||||||
|
|
|
@ -150,7 +150,7 @@ class Tool:
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
inputs: Dict[str, Dict[str, Union[str, type]]]
|
inputs: Dict[str, Dict[str, Union[str, type]]]
|
||||||
output_type: type
|
output_type: str
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.is_initialized = False
|
self.is_initialized = False
|
||||||
|
|
|
@ -16,9 +16,10 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from agents.agent_types import AgentText
|
from agents.agent_types import AgentText
|
||||||
from agents.agents import (
|
from agents.agents import (
|
||||||
AgentMaxIterationsError,
|
AgentMaxIterationsError,
|
||||||
|
@ -26,16 +27,18 @@ from agents.agents import (
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
JsonAgent,
|
JsonAgent,
|
||||||
Toolbox,
|
Toolbox,
|
||||||
|
ToolCall
|
||||||
)
|
)
|
||||||
|
from agents.tools import tool
|
||||||
from agents.default_tools import PythonInterpreterTool
|
from agents.default_tools import PythonInterpreterTool
|
||||||
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
def get_new_path(suffix="") -> str:
|
def get_new_path(suffix="") -> str:
|
||||||
directory = tempfile.mkdtemp()
|
directory = tempfile.mkdtemp()
|
||||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||||
|
|
||||||
|
|
||||||
def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str:
|
def fake_json_llm(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
|
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
|
@ -57,8 +60,29 @@ Action:
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
|
prompt = str(messages)
|
||||||
|
|
||||||
def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str:
|
if "special_marker" not in prompt:
|
||||||
|
return """
|
||||||
|
Thought: I should generate an image. special_marker
|
||||||
|
Action:
|
||||||
|
{
|
||||||
|
"action": "fake_image_generation_tool",
|
||||||
|
"action_input": {"prompt": "An image of a cat"}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
else: # We're at step 2
|
||||||
|
return """
|
||||||
|
Thought: I can now answer the initial question
|
||||||
|
Action:
|
||||||
|
{
|
||||||
|
"action": "final_answer",
|
||||||
|
"action_input": "image.png"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return """
|
return """
|
||||||
|
@ -78,7 +102,7 @@ final_answer(7.2904)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
|
def fake_code_llm_error(messages, stop_sequences=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return """
|
return """
|
||||||
|
@ -98,7 +122,7 @@ final_answer("got an error")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
|
def fake_code_functiondef(messages, stop_sequences=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return """
|
return """
|
||||||
|
@ -146,27 +170,23 @@ print(result)
|
||||||
|
|
||||||
|
|
||||||
class AgentTests(unittest.TestCase):
|
class AgentTests(unittest.TestCase):
|
||||||
def test_fake_code_agent(self):
|
def test_fake_oneshot_code_agent(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
|
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
|
||||||
)
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True)
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert output == "7.2904"
|
assert output == "7.2904"
|
||||||
|
|
||||||
def test_fake_react_json_agent(self):
|
def test_fake_json_agent(self):
|
||||||
agent = JsonAgent(
|
agent = JsonAgent(
|
||||||
tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm
|
tools=[PythonInterpreterTool()], llm_engine=fake_json_llm
|
||||||
)
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert output == "7.2904"
|
assert output == "7.2904"
|
||||||
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
||||||
assert agent.logs[2].observation == "7.2904"
|
assert agent.logs[2].observations == "7.2904"
|
||||||
assert (
|
|
||||||
agent.logs[2].rationale.strip()
|
|
||||||
== "Thought: I should multiply 2 by 3.6452. special_marker"
|
|
||||||
)
|
|
||||||
assert (
|
assert (
|
||||||
agent.logs[3].llm_output
|
agent.logs[3].llm_output
|
||||||
== """
|
== """
|
||||||
|
@ -179,22 +199,43 @@ Action:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fake_react_code_agent(self):
|
def test_json_agent_handles_image_tool_outputs(self):
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def fake_image_generation_tool(prompt: str) -> Image.Image:
|
||||||
|
"""Tool that generates an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt
|
||||||
|
"""
|
||||||
|
return Image.open(
|
||||||
|
Path(get_tests_dir("fixtures")) / "000000039769.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = JsonAgent(
|
||||||
|
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
|
||||||
|
)
|
||||||
|
output = agent.run("Make me an image.")
|
||||||
|
assert isinstance(output, Image.Image)
|
||||||
|
assert isinstance(agent.state["image.png"], Image.Image)
|
||||||
|
|
||||||
|
def test_fake_code_agent(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm
|
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
|
||||||
)
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, float)
|
assert isinstance(output, float)
|
||||||
assert output == 7.2904
|
assert output == 7.2904
|
||||||
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
||||||
assert agent.logs[3].tool_call == {
|
assert agent.logs[3].tool_call == ToolCall(
|
||||||
"tool_arguments": "final_answer(7.2904)",
|
tool_name="python_interpreter",
|
||||||
"tool_name": "code interpreter",
|
tool_arguments="final_answer(7.2904)",
|
||||||
}
|
)
|
||||||
|
|
||||||
def test_react_code_agent_code_errors_show_offending_lines(self):
|
def test_code_agent_code_errors_show_offending_lines(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error
|
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_error
|
||||||
)
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, AgentText)
|
assert isinstance(output, AgentText)
|
||||||
|
@ -202,9 +243,9 @@ Action:
|
||||||
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
|
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
|
||||||
|
|
||||||
def test_setup_agent_with_empty_toolbox(self):
|
def test_setup_agent_with_empty_toolbox(self):
|
||||||
JsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
JsonAgent(llm_engine=fake_json_llm, tools=[])
|
||||||
|
|
||||||
def test_react_fails_max_iterations(self):
|
def test_fails_max_iterations(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[PythonInterpreterTool()],
|
tools=[PythonInterpreterTool()],
|
||||||
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
|
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
|
||||||
|
@ -216,19 +257,19 @@ Action:
|
||||||
|
|
||||||
def test_init_agent_with_different_toolsets(self):
|
def test_init_agent_with_different_toolsets(self):
|
||||||
toolset_1 = []
|
toolset_1 = []
|
||||||
agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
agent = CodeAgent(tools=toolset_1, llm_engine=fake_code_llm)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 1
|
len(agent.toolbox.tools) == 1
|
||||||
) # when no tools are provided, only the final_answer tool is added by default
|
) # when no tools are provided, only the final_answer tool is added by default
|
||||||
|
|
||||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||||
agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
agent = CodeAgent(tools=toolset_2, llm_engine=fake_code_llm)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 2
|
len(agent.toolbox.tools) == 2
|
||||||
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
||||||
|
|
||||||
toolset_3 = Toolbox(toolset_2)
|
toolset_3 = Toolbox(toolset_2)
|
||||||
agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
agent = CodeAgent(tools=toolset_3, llm_engine=fake_code_llm)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 2
|
len(agent.toolbox.tools) == 2
|
||||||
) # same as previous one, where toolset_3 is an instantiation of previous one
|
) # same as previous one, where toolset_3 is an instantiation of previous one
|
||||||
|
@ -236,12 +277,12 @@ Action:
|
||||||
# check that add_base_tools will not interfere with existing tools
|
# check that add_base_tools will not interfere with existing tools
|
||||||
with pytest.raises(KeyError) as e:
|
with pytest.raises(KeyError) as e:
|
||||||
agent = JsonAgent(
|
agent = JsonAgent(
|
||||||
tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True
|
tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True
|
||||||
)
|
)
|
||||||
assert "already exists in the toolbox" in str(e)
|
assert "already exists in the toolbox" in str(e)
|
||||||
|
|
||||||
# check that python_interpreter base tool does not get added to code agents
|
# check that python_interpreter base tool does not get added to code agents
|
||||||
agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 2
|
len(agent.toolbox.tools) == 2
|
||||||
) # added final_answer tool + search
|
) # added final_answer tool + search
|
||||||
|
@ -249,7 +290,7 @@ Action:
|
||||||
def test_function_persistence_across_steps(self):
|
def test_function_persistence_across_steps(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
llm_engine=fake_react_code_functiondef,
|
llm_engine=fake_code_functiondef,
|
||||||
max_iterations=2,
|
max_iterations=2,
|
||||||
additional_authorized_imports=["numpy"],
|
additional_authorized_imports=["numpy"],
|
||||||
)
|
)
|
||||||
|
@ -257,17 +298,17 @@ Action:
|
||||||
assert res[0] == 0.5
|
assert res[0] == 0.5
|
||||||
|
|
||||||
def test_init_managed_agent(self):
|
def test_init_managed_agent(self):
|
||||||
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
|
||||||
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||||
assert managed_agent.name == "managed_agent"
|
assert managed_agent.name == "managed_agent"
|
||||||
assert managed_agent.description == "Empty"
|
assert managed_agent.description == "Empty"
|
||||||
|
|
||||||
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
|
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
|
||||||
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
|
||||||
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||||
manager_agent = CodeAgent(
|
manager_agent = CodeAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
llm_engine=fake_react_code_functiondef,
|
llm_engine=fake_code_functiondef,
|
||||||
managed_agents=[managed_agent],
|
managed_agents=[managed_agent],
|
||||||
)
|
)
|
||||||
assert "You can also give requests to team members." not in agent.system_prompt
|
assert "You can also give requests to team members." not in agent.system_prompt
|
||||||
|
|
Loading…
Reference in New Issue