Create PromptTemplates typed dict (#547)

This commit is contained in:
Albert Villanova del Moral 2025-02-12 08:45:56 +01:00 committed by GitHub
parent a17f915f61
commit 02b2b7ebb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 13 deletions

View File

@ -57,3 +57,11 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case. > You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
[[autodoc]] GradioUI [[autodoc]] GradioUI
## Prompts
[[autodoc]] smolagents.agents.PromptTemplates
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate

View File

@ -154,4 +154,12 @@ model = OpenAIServerModel(
api_base="https://api.openai.com/v1", api_base="https://api.openai.com/v1",
api_key=os.environ["OPENAI_API_KEY"], api_key=os.environ["OPENAI_API_KEY"],
) )
``` ```
## Prompts
[[autodoc]] smolagents.agents.PromptTemplates
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate

View File

@ -146,4 +146,12 @@ model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2, max_
print(model(messages)) print(model(messages))
``` ```
[[autodoc]] LiteLLMModel [[autodoc]] LiteLLMModel
## Prompts
[[autodoc]] smolagents.agents.PromptTemplates
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate

View File

@ -14,6 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
import importlib.resources import importlib.resources
import inspect import inspect
import re import re
@ -21,7 +24,7 @@ import textwrap
import time import time
from collections import deque from collections import deque
from logging import getLogger from logging import getLogger
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union
import yaml import yaml
from jinja2 import StrictUndefined, Template from jinja2 import StrictUndefined, Template
@ -80,6 +83,69 @@ def populate_template(template: str, variables: Dict[str, Any]) -> str:
raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}") raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
class PlanningPromptTemplate(TypedDict):
    """
    Prompt templates used for the agent's planning step.

    Each value is a prompt template string (rendered with Jinja elsewhere in this
    module — see `populate_template`).

    Args:
        initial_facts (`str`): Prompt for the initial facts survey.
        initial_plan (`str`): Prompt for drafting the initial plan.
        update_facts_pre_messages (`str`): Facts-update prompt placed before the message history.
        update_facts_post_messages (`str`): Facts-update prompt placed after the message history.
        update_plan_pre_messages (`str`): Plan-update prompt placed before the message history.
        update_plan_post_messages (`str`): Plan-update prompt placed after the message history.
    """

    initial_facts: str
    initial_plan: str
    update_facts_pre_messages: str
    update_facts_post_messages: str
    update_plan_pre_messages: str
    update_plan_post_messages: str
class ManagedAgentPromptTemplate(TypedDict):
    """
    Prompt templates used when this agent is run as a managed (sub-)agent.

    Args:
        task (`str`): Prompt template that phrases the task handed to the managed agent.
        report (`str`): Prompt template that phrases the managed agent's final report.
    """

    task: str
    report: str
class PromptTemplates(TypedDict):
    """
    Full set of prompt templates for an agent: the system prompt plus the nested
    planning and managed-agent template groups.

    Args:
        system_prompt (`str`): System prompt.
        planning ([`~agents.PlanningPromptTemplate`]): Prompt templates for the planning step.
        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Prompt templates used when run as a managed agent.
    """

    system_prompt: str
    planning: PlanningPromptTemplate
    managed_agent: ManagedAgentPromptTemplate
# Default prompt-template set with every template blank. Used as the fallback
# when an agent is constructed without explicit `prompt_templates`.
# NOTE: TypedDict constructor calls and annotated dict literals produce the
# same plain dict at runtime; the annotation keeps static type checking.
EMPTY_PROMPT_TEMPLATES: PromptTemplates = {
    "system_prompt": "",
    "planning": {
        "initial_facts": "",
        "initial_plan": "",
        "update_facts_pre_messages": "",
        "update_facts_post_messages": "",
        "update_plan_pre_messages": "",
        "update_plan_post_messages": "",
    },
    "managed_agent": {
        "task": "",
        "report": "",
    },
}
class MultiStepAgent: class MultiStepAgent:
""" """
Agent class that solves the given task step by step, using the ReAct framework: Agent class that solves the given task step by step, using the ReAct framework:
@ -88,7 +154,7 @@ class MultiStepAgent:
Args: Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use. tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates (`dict`, *optional*): Prompt templates. prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task. max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output. tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools. add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
@ -107,7 +173,7 @@ class MultiStepAgent:
self, self,
tools: List[Tool], tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage], model: Callable[[List[Dict[str, str]]], ChatMessage],
prompt_templates: Optional[dict] = None, prompt_templates: Optional[PromptTemplates] = None,
max_steps: int = 6, max_steps: int = 6,
tool_parser: Optional[Callable] = None, tool_parser: Optional[Callable] = None,
add_base_tools: bool = False, add_base_tools: bool = False,
@ -125,7 +191,7 @@ class MultiStepAgent:
tool_parser = parse_json_tool_call tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__ self.agent_name = self.__class__.__name__
self.model = model self.model = model
self.prompt_templates = prompt_templates or {} self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
self.max_steps = max_steps self.max_steps = max_steps
self.step_number: int = 0 self.step_number: int = 0
self.tool_parser = tool_parser self.tool_parser = tool_parser
@ -653,7 +719,7 @@ class ToolCallingAgent(MultiStepAgent):
Args: Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use. tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates (`dict`, *optional*): Prompt templates. prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
**kwargs: Additional keyword arguments. **kwargs: Additional keyword arguments.
""" """
@ -662,7 +728,7 @@ class ToolCallingAgent(MultiStepAgent):
self, self,
tools: List[Tool], tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage], model: Callable[[List[Dict[str, str]]], ChatMessage],
prompt_templates: Optional[dict] = None, prompt_templates: Optional[PromptTemplates] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
@ -775,7 +841,7 @@ class CodeAgent(MultiStepAgent):
Args: Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use. tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates (`dict`, *optional*): Prompt templates. prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output. grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent. additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
@ -789,7 +855,7 @@ class CodeAgent(MultiStepAgent):
self, self,
tools: List[Tool], tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage], model: Callable[[List[Dict[str, str]]], ChatMessage],
prompt_templates: Optional[dict] = None, prompt_templates: Optional[PromptTemplates] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
@ -941,6 +1007,3 @@ class CodeAgent(MultiStepAgent):
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
memory_step.action_output = output memory_step.action_output = output
return output if is_final_answer else None return output if is_final_answer else None
__all__ = ["MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"]