Create PromptTemplates typed dict (#547)

This commit is contained in:
Albert Villanova del Moral 2025-02-12 08:45:56 +01:00 committed by GitHub
parent a17f915f61
commit 02b2b7ebb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 13 deletions

View File

@@ -57,3 +57,11 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
[[autodoc]] GradioUI
## Prompts
[[autodoc]] smolagents.agents.PromptTemplates
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate

View File

@@ -155,3 +155,11 @@ model = OpenAIServerModel(
api_key=os.environ["OPENAI_API_KEY"],
)
```
## Prompts
[[autodoc]] smolagents.agents.PromptTemplates
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate

View File

@@ -147,3 +147,11 @@ print(model(messages))
```
[[autodoc]] LiteLLMModel
## Prompts
[[autodoc]] smolagents.agents.PromptTemplates
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate

View File

@@ -14,6 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
import importlib.resources
import inspect
import re
@@ -21,7 +24,7 @@ import textwrap
import time
from collections import deque
from logging import getLogger
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union
import yaml
from jinja2 import StrictUndefined, Template
@@ -80,6 +83,69 @@ def populate_template(template: str, variables: Dict[str, Any]) -> str:
raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
class PlanningPromptTemplate(TypedDict):
    """
    Prompt templates for the planning step.

    Holds one template string per phase of planning: the initial
    facts/plan pair, plus the pre/post-message templates used when the
    facts and the plan are updated mid-run.

    Args:
        initial_facts (`str`): Initial facts prompt.
        initial_plan (`str`): Initial plan prompt.
        update_facts_pre_messages (`str`): Update facts pre-messages prompt.
        update_facts_post_messages (`str`): Update facts post-messages prompt.
        update_plan_pre_messages (`str`): Update plan pre-messages prompt.
        update_plan_post_messages (`str`): Update plan post-messages prompt.
    """

    initial_facts: str
    initial_plan: str
    update_facts_pre_messages: str
    update_facts_post_messages: str
    update_plan_pre_messages: str
    update_plan_post_messages: str
class ManagedAgentPromptTemplate(TypedDict):
    """
    Prompt templates for the managed agent.

    Args:
        task (`str`): Task prompt, used to hand a task to a managed agent.
        report (`str`): Report prompt, used for the managed agent's answer back to its manager.
    """

    task: str
    report: str
class PromptTemplates(TypedDict):
    """
    Prompt templates for the agent.

    Top-level container grouping the system prompt together with the
    nested planning and managed-agent template dicts.

    Args:
        system_prompt (`str`): System prompt.
        planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template.
        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template.
    """

    system_prompt: str
    planning: PlanningPromptTemplate
    managed_agent: ManagedAgentPromptTemplate
# Fallback templates used when an agent is constructed without explicit
# prompt_templates: every entry is an empty string.
# NOTE(review): this is a single shared module-level dict; callers should not
# mutate it in place, or the change would leak across agent instances — confirm
# downstream usage treats it as read-only.
EMPTY_PROMPT_TEMPLATES: PromptTemplates = {
    "system_prompt": "",
    "planning": {
        "initial_facts": "",
        "initial_plan": "",
        "update_facts_pre_messages": "",
        "update_facts_post_messages": "",
        "update_plan_pre_messages": "",
        "update_plan_post_messages": "",
    },
    "managed_agent": {"task": "", "report": ""},
}
class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
@@ -88,7 +154,7 @@ class MultiStepAgent:
Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates (`dict`, *optional*): Prompt templates.
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
@@ -107,7 +173,7 @@ class MultiStepAgent:
self,
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
prompt_templates: Optional[dict] = None,
prompt_templates: Optional[PromptTemplates] = None,
max_steps: int = 6,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
@@ -125,7 +191,7 @@ class MultiStepAgent:
tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__
self.model = model
self.prompt_templates = prompt_templates or {}
self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
self.max_steps = max_steps
self.step_number: int = 0
self.tool_parser = tool_parser
@@ -653,7 +719,7 @@ class ToolCallingAgent(MultiStepAgent):
Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates (`dict`, *optional*): Prompt templates.
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
**kwargs: Additional keyword arguments.
"""
@@ -662,7 +728,7 @@ class ToolCallingAgent(MultiStepAgent):
self,
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
prompt_templates: Optional[dict] = None,
prompt_templates: Optional[PromptTemplates] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
@@ -775,7 +841,7 @@ class CodeAgent(MultiStepAgent):
Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates (`dict`, *optional*): Prompt templates.
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
@@ -789,7 +855,7 @@ class CodeAgent(MultiStepAgent):
self,
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
prompt_templates: Optional[dict] = None,
prompt_templates: Optional[PromptTemplates] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
@@ -941,6 +1007,3 @@ class CodeAgent(MultiStepAgent):
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
memory_step.action_output = output
return output if is_final_answer else None
__all__ = ["MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"]