Improve code logging

Aymeric 2024-12-09 22:34:56 +01:00
parent 154d1e938e
commit 43a3f46835
2 changed files with 30 additions and 28 deletions

View File

@@ -17,6 +17,7 @@
import time
from typing import Any, Callable, Dict, List, Optional, Union
from dataclasses import dataclass
from rich.syntax import Syntax
from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
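Note: the newly imported rich.syntax.Syntax wraps a string of source code so that console.print renders it with syntax highlighting instead of as plain text. A minimal standalone sketch of the pattern (the greet snippet is just a placeholder, not code from this repository):

from rich.console import Console
from rich.syntax import Syntax

console = Console()

# Printing a Syntax object instead of the raw string gives highlighted output;
# background_color="default" keeps the terminal's own background.
snippet = "def greet(name):\n    return 'Hello, ' + name"
console.print(Syntax(snippet, lexer="python", background_color="default"))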
@@ -353,14 +354,6 @@ class BaseAgent:
raise AgentExecutionError(error_msg)
def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
if self.verbose:
console.rule("Agent thoughts")
console.print(rationale)
console.rule("Agent is executing the code below:", align="left")
console.print(code_action)
console.rule("", align="left")
def run(self, **kwargs):
"""To be implemented in the child class"""
raise NotImplementedError
@@ -447,20 +440,23 @@ class ReactAgent(BaseAgent):
else:
self.logs.append(TaskStep(task=task))
if oneshot:
step_start_time = time.time()
step_log = ActionStep(start_time=step_start_time)
step_log.step_end_time = time.time()
step_log.step_duration = step_log.step_end_time - step_start_time
with console.status(
"Agent is running...", spinner="aesthetic"
):
if oneshot:
step_start_time = time.time()
step_log = ActionStep(start_time=step_start_time)
step_log.step_end_time = time.time()
step_log.step_duration = step_log.step_end_time - step_start_time
# Run the agent's step
result = self.step(step_log)
return result
# Run the agent's step
result = self.step(step_log)
return result
if stream:
return self.stream_run(task)
else:
return self.direct_run(task)
if stream:
return self.stream_run(task)
else:
return self.direct_run(task)
def stream_run(self, task: str):
"""
@@ -687,7 +683,7 @@ class JsonAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.rule("Calling LLM engine with this last message:", align="left")
console.rule("[italic]Calling LLM engine with this last message:", align="left")
console.print(self.prompt[-1])
console.rule()
@@ -806,7 +802,7 @@ class CodeAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.rule("Calling LLM engine with these last messages:", align="left")
console.rule("[italic]Calling LLM engine with these last messages:", align="left")
console.print(self.prompt[-2:])
console.rule()
@@ -819,8 +815,8 @@ class CodeAgent(ReactAgent):
raise AgentGenerationError(f"Error in generating llm output: {e}.")
if self.verbose:
console.rule("Output message of the LLM:")
console.print(llm_output)
console.rule("[italic]Output message of the LLM:")
console.print(Syntax(llm_output, lexer='markdown', background_color='default'))
log_entry.llm_output = llm_output
# Parse
@@ -840,7 +836,13 @@ class CodeAgent(ReactAgent):
log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action}
# Execute
self.log_rationale_code_action(rationale, code_action)
if self.verbose:
console.rule("[italic]Agent thoughts")
console.print(rationale)
console.rule("[bold]Agent is executing the code below:", align="left")
console.print(Syntax(code_action, lexer='python', background_color='default'))
console.rule("", align="left")
try:
static_tools = {
**BASE_PYTHON_TOOLS.copy(),
@@ -859,7 +861,7 @@ class CodeAgent(ReactAgent):
console.print(self.state["print_outputs"])
observation = "Print outputs:\n" + self.state["print_outputs"]
if result is not None:
console.print("Last output from code snippet:")
console.rule("Last output from code snippet:", align="left")
console.print(str(result))
observation += "Last output from code snippet:\n" + truncate_content(str(result))
log_entry.observation = observation
@@ -875,7 +877,7 @@ class CodeAgent(ReactAgent):
log_entry.final_answer = result
return result
class ManagedAgent:
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
self.agent = agent
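Note: taken together, the verbose path now frames each step with console.rule headers (which accept rich markup such as [italic] and [bold]) and prints the model's code through Syntax rather than as raw text. A standalone sketch of that logging pattern, using placeholder rationale and code_action strings in place of real agent output:

from rich.console import Console
from rich.syntax import Syntax

console = Console()

# Placeholder values standing in for what the LLM actually produced.
rationale = "I should look up the release announcement first."
code_action = "answer = search('Llama 3 release date')\nprint(answer)"

console.rule("[italic]Agent thoughts")
console.print(rationale)
console.rule("[bold]Agent is executing the code below:", align="left")
console.print(Syntax(code_action, lexer="python", background_color="default"))
console.rule("", align="left")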

View File

@@ -20,7 +20,7 @@ agent = CodeAgent(
# Run it!
result = agent.run(
"When was Llama 3 first released?", oneshot=True
"When was Llama 3 first released?"
)
print(result)