diff --git a/src/agents/agents.py b/src/agents/agents.py index 2a67665..4003e40 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -18,6 +18,10 @@ import time from typing import Any, Callable, Dict, List, Optional, Union, Tuple from dataclasses import dataclass from rich.syntax import Syntax +from rich.console import Group +from rich.panel import Panel +from rich.rule import Rule +from rich.text import Text from transformers.utils import is_torch_available @@ -523,8 +527,7 @@ class ReactAgent(BaseAgent): else: self.logs.append(system_prompt_step) - console.rule("[bold]New task", characters="=") - console.print(self.task) + console.print(Group(Rule("[bold]New task", characters="="), Text(self.task))) self.logs.append(TaskStep(task=task)) if oneshot: @@ -693,8 +696,7 @@ Now begin!""", self.logs.append( PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) ) - console.rule("[orange]Initial plan") - console.print(final_plan_redaction) + console.print(Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction)) else: # update plan agent_memory = self.write_inner_memory_from_logs( summary_mode=False @@ -750,8 +752,8 @@ Now begin!""", self.logs.append( PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) ) - console.rule("[orange]Updated plan") - console.print(final_plan_redaction) + console.print(Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction)) + class JsonAgent(ReactAgent): @@ -800,11 +802,10 @@ class JsonAgent(ReactAgent): log_entry.agent_memory = agent_memory.copy() if self.verbose: - console.rule( - "[italic]Calling LLM engine with this last message:", align="left" - ) - console.print(self.prompt_messages[-1]) - console.rule() + console.print(Group( + Rule("[italic]Calling LLM engine with this last message:", align="left", style="orange"), + Text(str(self.prompt_messages[-1])) + )) try: additional_args = ( @@ -820,8 +821,10 @@ class JsonAgent(ReactAgent): raise AgentGenerationError(f"Error in generating llm_engine output: {e}.") if self.verbose: - console.rule("[italic]Output message of the LLM:") - console.print(llm_output) + console.print(Group( + Rule("[italic]Output message of the LLM:", align="left", style="orange"), + Text(llm_output) + )) # Parse rationale, action = self.extract_action( @@ -836,10 +839,8 @@ class JsonAgent(ReactAgent): log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments) # Execute - console.rule("Agent thoughts:") - console.print(rationale) - console.rule() - console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") + console.print(Rule("Agent thoughts:", align="left"), Text(rationale)) + console.print(Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))) if tool_name == "final_answer": if isinstance(arguments, dict): if "answer" in arguments: @@ -933,11 +934,10 @@ class CodeAgent(ReactAgent): log_entry.agent_memory = agent_memory.copy() if self.verbose: - console.rule( - "[italic]Calling LLM engine with these last messages:", align="left" - ) - console.print(self.prompt_messages[-2:]) - console.rule() + console.print(Group( + Rule("[italic]Calling LLM engine with these last messages:", align="left", style="orange"), + Text(str(self.prompt_messages[-2:])) + )) try: additional_args = ( @@ -953,10 +953,10 @@ class CodeAgent(ReactAgent): raise AgentGenerationError(f"Error in generating llm_engine output: {e}.") if self.verbose: - console.rule("[italic]Output message of the LLM:") - console.print( - Syntax(llm_output, lexer="markdown", background_color="default") - ) + console.print(Group( + Rule("[italic]Output message of the LLM:", align="left", style="orange"), + Syntax(llm_output, lexer="markdown", theme="github-dark") + )) # Parse try: @@ -981,11 +981,14 @@ class CodeAgent(ReactAgent): # Execute if self.verbose: - console.rule("[italic]Agent thoughts") - console.print(rationale) - console.rule("[bold]Agent is executing the code below:", align="left") - console.print(Syntax(code_action, lexer="python", background_color="default")) - console.rule("", align="left") + console.print(Group( + Rule("[italic]Agent thoughts", align="left"), + Text(rationale) + )) + + console.print(Panel( + Syntax(code_action, lexer="python", theme="github-dark"), title="[bold]Agent is executing the code below:", title_align="left") + ) try: static_tools = { @@ -1001,12 +1004,14 @@ class CodeAgent(ReactAgent): state=self.state, authorized_imports=self.authorized_imports, ) - console.print("Print outputs:") - console.print(self.state["print_outputs"]) + if len(self.state["print_outputs"]) > 0: + console.print(Group(Text("Print outputs:", style="bold"), Text(self.state["print_outputs"]))) observation = "Print outputs:\n" + self.state["print_outputs"] if output is not None: - console.rule("Last output from code snippet:", align="left") - console.print(str(output)) + truncated_output = truncate_content( + str(output) + ) + console.print(Group(Text("Last output from code snippet:", style="bold"), Text(truncated_output))) observation += "Last output from code snippet:\n" + truncate_content( str(output) ) @@ -1018,8 +1023,7 @@ class CodeAgent(ReactAgent): raise AgentExecutionError(error_msg) for line in code_action.split("\n"): if line[: len("final_answer")] == "final_answer": - console.print("Final answer:") - console.print(f"[bold]{output}") + console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green"))) log_entry.action_output = output return output return None diff --git a/src/agents/monitoring.py b/src/agents/monitoring.py index 2f57268..a89cb5d 100644 --- a/src/agents/monitoring.py +++ b/src/agents/monitoring.py @@ -15,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from .utils import console - +from rich.text import Text +from rich.console import Group class Monitor: def __init__(self, tracked_llm_engine): @@ -31,8 +32,10 @@ class Monitor: def update_metrics(self, step_log): step_duration = step_log.duration self.step_durations.append(step_duration) - console.print(f"Step {len(self.step_durations)}:") - console.print(f"- Time taken: {step_duration:.2f} seconds") + console_outputs = [ + Text(f"Step {len(self.step_durations)}:", style="bold"), + Text(f"- Time taken: {step_duration:.2f} seconds") + ] if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: self.total_input_token_count += ( @@ -41,8 +44,11 @@ class Monitor: self.total_output_token_count += ( self.tracked_llm_engine.last_output_token_count ) - console.print(f"- Input tokens: {self.total_input_token_count:,}") - console.print(f"- Output tokens: {self.total_output_token_count:,}") + console_outputs += [ + Text(f"- Input tokens: {self.total_input_token_count:,}"), + Text(f"- Output tokens: {self.total_output_token_count:,}") + ] + console.print(Group(*console_outputs)) __all__ = ["Monitor"]