Fix missing python modules in CodeAgent system prompt (#226)

* fix modules in system prompt + test
This commit is contained in:
Edward Beeching 2025-01-17 11:59:30 +01:00 committed by GitHub
parent 11a738e53a
commit fabc59aa08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 23 deletions

View File

@ -895,6 +895,24 @@ class CodeAgent(MultiStepAgent):
):
if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
if "*" in self.additional_authorized_imports:
self.logger.log(
"Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
0,
)
super().__init__(
tools=tools,
model=model,
@ -903,29 +921,6 @@ class CodeAgent(MultiStepAgent):
planning_interval=planning_interval,
**kwargs,
)
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in self.system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}",
"You can import from any package you want."
if "*" in self.authorized_imports
else str(self.authorized_imports),
)
if "*" in self.additional_authorized_imports:
self.logger.log(
"Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
0,
)
if use_e2b_executor and len(self.managed_agents) > 0:
raise Exception(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
@ -944,6 +939,16 @@ class CodeAgent(MultiStepAgent):
all_tools,
)
def initialize_system_prompt(self):
super().initialize_system_prompt()
self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}",
"You can import from any package you want."
if "*" in self.authorized_imports
else str(self.authorized_imports),
)
return self.system_prompt
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.

View File

@ -35,6 +35,7 @@ from smolagents.models import (
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
from smolagents.utils import BASE_BUILTIN_MODULES
def get_new_path(suffix="") -> str:
@ -381,6 +382,12 @@ class AgentTests(unittest.TestCase):
assert tool.name in agent.system_prompt
assert tool.description in agent.system_prompt
def test_module_imports_get_baked_in_system_prompt(self):
agent = CodeAgent(tools=[], model=fake_code_model)
agent.run("Empty task")
for module in BASE_BUILTIN_MODULES:
assert module in agent.system_prompt
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = CodeAgent(tools=toolset_1, model=fake_code_model)