diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 7e80161..eddced1 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -225,6 +225,10 @@ class MultiStepAgent: messages.extend(memory_step.to_messages(summary_mode=summary_mode)) return messages + def visualize(self): + """Creates a rich tree visualization of the agent's structure.""" + self.logger.visualize_agent_tree(self) + def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]: """ Parse action from the LLM output diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 2dab05a..d2fdd47 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -241,8 +241,6 @@ class Model: def __init__(self, **kwargs): self.last_input_token_count = None self.last_output_token_count = None - # Set default values for common parameters - kwargs.setdefault("max_tokens", 4096) self.kwargs = kwargs def _prepare_completion_kwargs( @@ -643,15 +641,19 @@ class LiteLLMModel(Model): The base URL of the OpenAI-compatible API server. api_key (`str`, *optional*): The API key to use for authentication. + custom_role_conversions (`dict[str, str]`, *optional*): + Custom role conversion mapping to convert message roles in others. + Useful for specific models that do not support specific message roles like "system". **kwargs: Additional keyword arguments to pass to the OpenAI API. """ def __init__( self, - model_id="anthropic/claude-3-5-sonnet-20240620", + model_id: str = "anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None, + custom_role_conversions: Optional[Dict[str, str]] = None, **kwargs, ): try: @@ -667,6 +669,7 @@ class LiteLLMModel(Model): litellm.add_function_to_prompt = True self.api_base = api_base self.api_key = api_key + self.custom_role_conversions = custom_role_conversions def __call__( self, @@ -687,6 +690,7 @@ class LiteLLMModel(Model): api_base=self.api_base, api_key=self.api_key, convert_images_to_image_urls=True, + custom_role_conversions=self.custom_role_conversions, **kwargs, ) diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py index 53c433f..33ae2ab 100644 --- a/src/smolagents/monitoring.py +++ b/src/smolagents/monitoring.py @@ -23,7 +23,9 @@ from rich.console import Console, Group from rich.panel import Panel from rich.rule import Rule from rich.syntax import Syntax +from rich.table import Table from rich.text import Text +from rich.tree import Tree class Monitor: @@ -162,5 +164,42 @@ class AgentLogger: ) ) + def visualize_agent_tree(self, agent): + def create_tools_section(tools_dict): + table = Table(show_header=True, header_style="bold") + table.add_column("Name", style="blue") + table.add_column("Description") + table.add_column("Arguments") + + for name, tool in tools_dict.items(): + args = [ + f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}" + for arg_name, info in getattr(tool, "inputs", {}).items() + ] + table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args)) + + return Group(Text("🛠️ Tools", style="bold italic blue"), table) + + def build_agent_tree(parent_tree, agent_obj): + """Recursively builds the agent tree.""" + if agent_obj.tools: + parent_tree.add(create_tools_section(agent_obj.tools)) + + if agent_obj.managed_agents: + agents_branch = parent_tree.add("[bold italic blue]🤖 Managed agents") + for name, managed_agent in agent_obj.managed_agents.items(): + agent_node_text = f"[bold {YELLOW_HEX}]{name} - {managed_agent.agent.__class__.__name__}" + agent_tree = agents_branch.add(agent_node_text) + if hasattr(managed_agent, "description"): + agent_tree.add( + f"[bold italic blue]📝 Description:[/bold italic blue] {managed_agent.description}" + ) + if hasattr(managed_agent, "agent"): + build_agent_tree(agent_tree, managed_agent.agent) + + main_tree = Tree(f"[bold {YELLOW_HEX}]{agent.__class__.__name__}") + build_agent_tree(main_tree, agent) + self.console.print(main_tree) + __all__ = ["AgentLogger", "Monitor"]