From 34810986e048b295a3f518c883b045b2ec0045c0 Mon Sep 17 00:00:00 2001 From: Jan <5099251+jank@users.noreply.github.com> Date: Sat, 18 Jan 2025 18:27:48 +0100 Subject: [PATCH] refactor: update model type to ChatMessage in agent classes (#263) --- src/smolagents/agents.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index da791aa..46bfa51 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -34,7 +34,10 @@ from .local_python_executor import ( LocalPythonInterpreter, fix_final_answer_code, ) -from .models import MessageRole +from .models import ( + MessageRole, + ChatMessage, +) from .monitoring import Monitor from .prompts import ( CODE_SYSTEM_PROMPT, @@ -191,7 +194,7 @@ class MultiStepAgent: def __init__( self, tools: List[Tool], - model: Callable[[List[Dict[str, str]]], str], + model: Callable[[List[Dict[str, str]]], ChatMessage], system_prompt: Optional[str] = None, tool_description_template: Optional[str] = None, max_steps: int = 6, @@ -775,7 +778,7 @@ class ToolCallingAgent(MultiStepAgent): def __init__( self, tools: List[Tool], - model: Callable, + model: Callable[[List[Dict[str, str]]], ChatMessage], system_prompt: Optional[str] = None, planning_interval: Optional[int] = None, **kwargs, @@ -885,7 +888,7 @@ class CodeAgent(MultiStepAgent): def __init__( self, tools: List[Tool], - model: Callable, + model: Callable[[List[Dict[str, str]]], ChatMessage], system_prompt: Optional[str] = None, grammar: Optional[Dict[str, str]] = None, additional_authorized_imports: Optional[List[str]] = None,