Add visualization method to display the agent' structure as a tree 🌳 (#470)

This commit is contained in:
Aymeric Roucher 2025-02-03 11:07:05 +01:00 committed by GitHub
parent 39908439d2
commit 42f95d8ee1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 3 deletions

View File

@ -225,6 +225,10 @@ class MultiStepAgent:
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
return messages
def visualize(self):
"""Creates a rich tree visualization of the agent's structure."""
self.logger.visualize_agent_tree(self)
def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]:
"""
Parse action from the LLM output

View File

@ -241,8 +241,6 @@ class Model:
def __init__(self, **kwargs):
self.last_input_token_count = None
self.last_output_token_count = None
# Set default values for common parameters
kwargs.setdefault("max_tokens", 4096)
self.kwargs = kwargs
def _prepare_completion_kwargs(
@ -643,15 +641,19 @@ class LiteLLMModel(Model):
The base URL of the OpenAI-compatible API server.
api_key (`str`, *optional*):
The API key to use for authentication.
custom_role_conversions (`dict[str, str]`, *optional*):
Custom role conversion mapping to convert message roles in others.
Useful for specific models that do not support specific message roles like "system".
**kwargs:
Additional keyword arguments to pass to the OpenAI API.
"""
def __init__(
self,
model_id="anthropic/claude-3-5-sonnet-20240620",
model_id: str = "anthropic/claude-3-5-sonnet-20240620",
api_base=None,
api_key=None,
custom_role_conversions: Optional[Dict[str, str]] = None,
**kwargs,
):
try:
@ -667,6 +669,7 @@ class LiteLLMModel(Model):
litellm.add_function_to_prompt = True
self.api_base = api_base
self.api_key = api_key
self.custom_role_conversions = custom_role_conversions
def __call__(
self,
@ -687,6 +690,7 @@ class LiteLLMModel(Model):
api_base=self.api_base,
api_key=self.api_key,
convert_images_to_image_urls=True,
custom_role_conversions=self.custom_role_conversions,
**kwargs,
)

View File

@ -23,7 +23,9 @@ from rich.console import Console, Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text
from rich.tree import Tree
class Monitor:
@ -162,5 +164,42 @@ class AgentLogger:
)
)
def visualize_agent_tree(self, agent):
def create_tools_section(tools_dict):
table = Table(show_header=True, header_style="bold")
table.add_column("Name", style="blue")
table.add_column("Description")
table.add_column("Arguments")
for name, tool in tools_dict.items():
args = [
f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
for arg_name, info in getattr(tool, "inputs", {}).items()
]
table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))
return Group(Text("🛠️ Tools", style="bold italic blue"), table)
def build_agent_tree(parent_tree, agent_obj):
"""Recursively builds the agent tree."""
if agent_obj.tools:
parent_tree.add(create_tools_section(agent_obj.tools))
if agent_obj.managed_agents:
agents_branch = parent_tree.add("[bold italic blue]🤖 Managed agents")
for name, managed_agent in agent_obj.managed_agents.items():
agent_node_text = f"[bold {YELLOW_HEX}]{name} - {managed_agent.agent.__class__.__name__}"
agent_tree = agents_branch.add(agent_node_text)
if hasattr(managed_agent, "description"):
agent_tree.add(
f"[bold italic blue]📝 Description:[/bold italic blue] {managed_agent.description}"
)
if hasattr(managed_agent, "agent"):
build_agent_tree(agent_tree, managed_agent.agent)
main_tree = Tree(f"[bold {YELLOW_HEX}]{agent.__class__.__name__}")
build_agent_tree(main_tree, agent)
self.console.print(main_tree)
__all__ = ["AgentLogger", "Monitor"]