diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index da791aa..46bfa51 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -34,7 +34,10 @@ from .local_python_executor import ( LocalPythonInterpreter, fix_final_answer_code, ) -from .models import MessageRole +from .models import ( + MessageRole, + ChatMessage, +) from .monitoring import Monitor from .prompts import ( CODE_SYSTEM_PROMPT, @@ -191,7 +194,7 @@ class MultiStepAgent: def __init__( self, tools: List[Tool], - model: Callable[[List[Dict[str, str]]], str], + model: Callable[[List[Dict[str, str]]], ChatMessage], system_prompt: Optional[str] = None, tool_description_template: Optional[str] = None, max_steps: int = 6, @@ -775,7 +778,7 @@ class ToolCallingAgent(MultiStepAgent): def __init__( self, tools: List[Tool], - model: Callable, + model: Callable[[List[Dict[str, str]]], ChatMessage], system_prompt: Optional[str] = None, planning_interval: Optional[int] = None, **kwargs, @@ -885,7 +888,7 @@ class CodeAgent(MultiStepAgent): def __init__( self, tools: List[Tool], - model: Callable, + model: Callable[[List[Dict[str, str]]], ChatMessage], system_prompt: Optional[str] = None, grammar: Optional[Dict[str, str]] = None, additional_authorized_imports: Optional[List[str]] = None,