Improve logging

This commit is contained in:
Aymeric 2024-12-17 19:31:56 +01:00
parent 06066437fd
commit 1c4ac9cb57
2 changed files with 52 additions and 42 deletions

View File

@ -18,6 +18,10 @@ import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from dataclasses import dataclass
from rich.syntax import Syntax
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text
from transformers.utils import is_torch_available
@ -523,8 +527,7 @@ class ReactAgent(BaseAgent):
else:
self.logs.append(system_prompt_step)
console.rule("[bold]New task", characters="=")
console.print(self.task)
console.print(Group(Rule("[bold]New task", characters="="), Text(self.task)))
self.logs.append(TaskStep(task=task))
if oneshot:
@ -693,8 +696,7 @@ Now begin!""",
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.rule("[orange]Initial plan")
console.print(final_plan_redaction)
console.print(Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction))
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
summary_mode=False
@ -750,8 +752,8 @@ Now begin!""",
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.rule("[orange]Updated plan")
console.print(final_plan_redaction)
console.print(Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction))
class JsonAgent(ReactAgent):
@ -800,11 +802,10 @@ class JsonAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.rule(
"[italic]Calling LLM engine with this last message:", align="left"
)
console.print(self.prompt_messages[-1])
console.rule()
console.print(Group(
Rule("[italic]Calling LLM engine with this last message:", align="left", style="orange"),
Text(str(self.prompt_messages[-1]))
))
try:
additional_args = (
@ -820,8 +821,10 @@ class JsonAgent(ReactAgent):
raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
if self.verbose:
console.rule("[italic]Output message of the LLM:")
console.print(llm_output)
console.print(Group(
Rule("[italic]Output message of the LLM:", align="left", style="orange"),
Text(llm_output)
))
# Parse
rationale, action = self.extract_action(
@ -836,10 +839,8 @@ class JsonAgent(ReactAgent):
log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments)
# Execute
console.rule("Agent thoughts:")
console.print(rationale)
console.rule()
console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
console.print(Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}")))
if tool_name == "final_answer":
if isinstance(arguments, dict):
if "answer" in arguments:
@ -933,11 +934,10 @@ class CodeAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.rule(
"[italic]Calling LLM engine with these last messages:", align="left"
)
console.print(self.prompt_messages[-2:])
console.rule()
console.print(Group(
Rule("[italic]Calling LLM engine with these last messages:", align="left", style="orange"),
Text(str(self.prompt_messages[-2:]))
))
try:
additional_args = (
@ -953,10 +953,10 @@ class CodeAgent(ReactAgent):
raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
if self.verbose:
console.rule("[italic]Output message of the LLM:")
console.print(
Syntax(llm_output, lexer="markdown", background_color="default")
)
console.print(Group(
Rule("[italic]Output message of the LLM:", align="left", style="orange"),
Syntax(llm_output, lexer="markdown", theme="github-dark")
))
# Parse
try:
@ -981,11 +981,14 @@ class CodeAgent(ReactAgent):
# Execute
if self.verbose:
console.rule("[italic]Agent thoughts")
console.print(rationale)
console.rule("[bold]Agent is executing the code below:", align="left")
console.print(Syntax(code_action, lexer="python", background_color="default"))
console.rule("", align="left")
console.print(Group(
Rule("[italic]Agent thoughts", align="left"),
Text(rationale)
))
console.print(Panel(
Syntax(code_action, lexer="python", theme="github-dark"), title="[bold]Agent is executing the code below:", title_align="left")
)
try:
static_tools = {
@ -1001,12 +1004,14 @@ class CodeAgent(ReactAgent):
state=self.state,
authorized_imports=self.authorized_imports,
)
console.print("Print outputs:")
console.print(self.state["print_outputs"])
if len(self.state["print_outputs"]) > 0:
console.print(Group(Text("Print outputs:", style="bold"), Text(self.state["print_outputs"])))
observation = "Print outputs:\n" + self.state["print_outputs"]
if output is not None:
console.rule("Last output from code snippet:", align="left")
console.print(str(output))
truncated_output = truncate_content(
str(output)
)
console.print(Group(Text("Last output from code snippet:", style="bold"), Text(truncated_output)))
observation += "Last output from code snippet:\n" + truncate_content(
str(output)
)
@ -1018,8 +1023,7 @@ class CodeAgent(ReactAgent):
raise AgentExecutionError(error_msg)
for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer":
console.print("Final answer:")
console.print(f"[bold]{output}")
console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green")))
log_entry.action_output = output
return output
return None

View File

@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import console
from rich.text import Text
from rich.console import Group
class Monitor:
def __init__(self, tracked_llm_engine):
@ -31,8 +32,10 @@ class Monitor:
def update_metrics(self, step_log):
step_duration = step_log.duration
self.step_durations.append(step_duration)
console.print(f"Step {len(self.step_durations)}:")
console.print(f"- Time taken: {step_duration:.2f} seconds")
console_outputs = [
Text(f"Step {len(self.step_durations)}:", style="bold"),
Text(f"- Time taken: {step_duration:.2f} seconds")
]
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += (
@ -41,8 +44,11 @@ class Monitor:
self.total_output_token_count += (
self.tracked_llm_engine.last_output_token_count
)
console.print(f"- Input tokens: {self.total_input_token_count:,}")
console.print(f"- Output tokens: {self.total_output_token_count:,}")
console_outputs += [
Text(f"- Input tokens: {self.total_input_token_count:,}"),
Text(f"- Output tokens: {self.total_output_token_count:,}")
]
console.print(Group(*console_outputs))
__all__ = ["Monitor"]