Create PromptTemplates typed dict (#547)
This commit is contained in:
parent
a17f915f61
commit
02b2b7ebb9
|
@ -57,3 +57,11 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
|
||||||
> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
|
> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
|
||||||
|
|
||||||
[[autodoc]] GradioUI
|
[[autodoc]] GradioUI
|
||||||
|
|
||||||
|
## Prompts
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.PromptTemplates
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
||||||
|
|
|
@ -154,4 +154,12 @@ model = OpenAIServerModel(
|
||||||
api_base="https://api.openai.com/v1",
|
api_base="https://api.openai.com/v1",
|
||||||
api_key=os.environ["OPENAI_API_KEY"],
|
api_key=os.environ["OPENAI_API_KEY"],
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Prompts
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.PromptTemplates
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
||||||
|
|
|
@ -146,4 +146,12 @@ model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2, max_
|
||||||
print(model(messages))
|
print(model(messages))
|
||||||
```
|
```
|
||||||
|
|
||||||
[[autodoc]] LiteLLMModel
|
[[autodoc]] LiteLLMModel
|
||||||
|
|
||||||
|
## Prompts
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.PromptTemplates
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
||||||
|
|
|
@ -14,6 +14,9 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
|
||||||
|
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
|
@ -21,7 +24,7 @@ import textwrap
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from jinja2 import StrictUndefined, Template
|
from jinja2 import StrictUndefined, Template
|
||||||
|
@ -80,6 +83,69 @@ def populate_template(template: str, variables: Dict[str, Any]) -> str:
|
||||||
raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
|
raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class PlanningPromptTemplate(TypedDict):
|
||||||
|
"""
|
||||||
|
Prompt templates for the planning step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_facts (`str`): Initial facts prompt.
|
||||||
|
initial_plan (`str`): Initial plan prompt.
|
||||||
|
update_facts_pre_messages (`str`): Update facts pre-messages prompt.
|
||||||
|
update_facts_post_messages (`str`): Update facts post-messages prompt.
|
||||||
|
update_plan_pre_messages (`str`): Update plan pre-messages prompt.
|
||||||
|
update_plan_post_messages (`str`): Update plan post-messages prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
initial_facts: str
|
||||||
|
initial_plan: str
|
||||||
|
update_facts_pre_messages: str
|
||||||
|
update_facts_post_messages: str
|
||||||
|
update_plan_pre_messages: str
|
||||||
|
update_plan_post_messages: str
|
||||||
|
|
||||||
|
|
||||||
|
class ManagedAgentPromptTemplate(TypedDict):
|
||||||
|
"""
|
||||||
|
Prompt templates for the managed agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (`str`): Task prompt.
|
||||||
|
report (`str`): Report prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task: str
|
||||||
|
report: str
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplates(TypedDict):
|
||||||
|
"""
|
||||||
|
Prompt templates for the agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt (`str`): System prompt.
|
||||||
|
planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template.
|
||||||
|
managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template.
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt: str
|
||||||
|
planning: PlanningPromptTemplate
|
||||||
|
managed_agent: ManagedAgentPromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
EMPTY_PROMPT_TEMPLATES = PromptTemplates(
|
||||||
|
system_prompt="",
|
||||||
|
planning=PlanningPromptTemplate(
|
||||||
|
initial_facts="",
|
||||||
|
initial_plan="",
|
||||||
|
update_facts_pre_messages="",
|
||||||
|
update_facts_post_messages="",
|
||||||
|
update_plan_pre_messages="",
|
||||||
|
update_plan_post_messages="",
|
||||||
|
),
|
||||||
|
managed_agent=ManagedAgentPromptTemplate(task="", report=""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiStepAgent:
|
class MultiStepAgent:
|
||||||
"""
|
"""
|
||||||
Agent class that solves the given task step by step, using the ReAct framework:
|
Agent class that solves the given task step by step, using the ReAct framework:
|
||||||
|
@ -88,7 +154,7 @@ class MultiStepAgent:
|
||||||
Args:
|
Args:
|
||||||
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
||||||
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
||||||
prompt_templates (`dict`, *optional*): Prompt templates.
|
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
|
||||||
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
|
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
|
||||||
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
|
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
|
||||||
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
|
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
|
||||||
|
@ -107,7 +173,7 @@ class MultiStepAgent:
|
||||||
self,
|
self,
|
||||||
tools: List[Tool],
|
tools: List[Tool],
|
||||||
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
||||||
prompt_templates: Optional[dict] = None,
|
prompt_templates: Optional[PromptTemplates] = None,
|
||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
tool_parser: Optional[Callable] = None,
|
tool_parser: Optional[Callable] = None,
|
||||||
add_base_tools: bool = False,
|
add_base_tools: bool = False,
|
||||||
|
@ -125,7 +191,7 @@ class MultiStepAgent:
|
||||||
tool_parser = parse_json_tool_call
|
tool_parser = parse_json_tool_call
|
||||||
self.agent_name = self.__class__.__name__
|
self.agent_name = self.__class__.__name__
|
||||||
self.model = model
|
self.model = model
|
||||||
self.prompt_templates = prompt_templates or {}
|
self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.step_number: int = 0
|
self.step_number: int = 0
|
||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
|
@ -653,7 +719,7 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
Args:
|
Args:
|
||||||
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
||||||
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
||||||
prompt_templates (`dict`, *optional*): Prompt templates.
|
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
|
||||||
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
||||||
**kwargs: Additional keyword arguments.
|
**kwargs: Additional keyword arguments.
|
||||||
"""
|
"""
|
||||||
|
@ -662,7 +728,7 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
self,
|
self,
|
||||||
tools: List[Tool],
|
tools: List[Tool],
|
||||||
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
||||||
prompt_templates: Optional[dict] = None,
|
prompt_templates: Optional[PromptTemplates] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
@ -775,7 +841,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
Args:
|
Args:
|
||||||
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
||||||
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
||||||
prompt_templates (`dict`, *optional*): Prompt templates.
|
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
|
||||||
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
|
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
|
||||||
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
|
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
|
||||||
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
||||||
|
@ -789,7 +855,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
self,
|
self,
|
||||||
tools: List[Tool],
|
tools: List[Tool],
|
||||||
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
||||||
prompt_templates: Optional[dict] = None,
|
prompt_templates: Optional[PromptTemplates] = None,
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
|
@ -941,6 +1007,3 @@ class CodeAgent(MultiStepAgent):
|
||||||
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
|
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
|
||||||
memory_step.action_output = output
|
memory_step.action_output = output
|
||||||
return output if is_final_answer else None
|
return output if is_final_answer else None
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"]
|
|
||||||
|
|
Loading…
Reference in New Issue