Clean local python interpreter: propagate imports (#175)

This commit is contained in:
Aymeric Roucher 2025-01-13 17:23:03 +01:00 committed by GitHub
parent a5a3448551
commit c611dfc7e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 763 additions and 245 deletions

File diff suppressed because one or more lines are too long

View File

@ -884,7 +884,6 @@ class CodeAgent(MultiStepAgent):
system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
allow_all_imports: bool = False,
planning_interval: Optional[int] = None,
use_e2b_executor: bool = False,
**kwargs,
@ -899,18 +898,29 @@ class CodeAgent(MultiStepAgent):
planning_interval=planning_interval,
**kwargs,
)
if ( allow_all_imports and
( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
raise Exception(
f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
)
if allow_all_imports: additional_authorized_imports=['*']
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in self.system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}",
"You can import from any package you want."
if "*" in self.authorized_imports
else str(self.authorized_imports),
)
if "*" in self.additional_authorized_imports:
self.logger.log(
"Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
0,
)
if use_e2b_executor and len(self.managed_agents) > 0:
raise Exception(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
@ -919,25 +929,15 @@ class CodeAgent(MultiStepAgent):
all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
self.additional_authorized_imports, list(all_tools.values())
self.additional_authorized_imports,
list(all_tools.values()),
self.logger,
)
else:
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports, all_tools
self.additional_authorized_imports,
all_tools,
)
if allow_all_imports:
self.authorized_imports = 'all imports without restriction'
else:
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in self.system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}", str(self.authorized_imports)
)
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""

View File

@ -26,13 +26,13 @@ from PIL import Image
from .tool_validation import validate_tool_attributes
from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
from .utils import BASE_BUILTIN_MODULES, instance_to_source
load_dotenv()
class E2BExecutor:
def __init__(self, additional_imports: List[str], tools: List[Tool]):
def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
self.custom_tools = {}
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
# TODO: validate installing agents package or not
@ -42,6 +42,7 @@ class E2BExecutor:
# timeout=300
# )
# print("Installation of agents package finished.")
self.logger = logger
additional_imports = additional_imports + ["pickle5"]
if len(additional_imports) > 0:
execution = self.sbx.commands.run(
@ -50,7 +51,7 @@ class E2BExecutor:
if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}")
else:
console.print(f"Installation of {additional_imports} succeeded!")
logger.log(f"Installation of {additional_imports} succeeded!", 0)
tool_codes = []
for tool in tools:
@ -74,7 +75,7 @@ class E2BExecutor:
tool_definition_code += "\n\n".join(tool_codes)
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
console.print(tool_definition_execution.logs)
self.logger.log(tool_definition_execution.logs)
def run_code_raise_errors(self, code: str):
execution = self.sbx.run_code(
@ -109,7 +110,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
"""
execution = self.run_code_raise_errors(remote_unloading_code)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
console.print(execution_logs)
self.logger.log(execution_logs, 1)
execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])

View File

@ -85,7 +85,7 @@ def stream_to_gradio(
class GradioUI:
"""A one-line interface to launch your agent in Gradio"""
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None):
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
self.agent = agent
self.file_upload_folder = file_upload_folder
if self.file_upload_folder is not None:
@ -100,7 +100,15 @@ class GradioUI:
yield messages
yield messages
def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
def upload_file(
self,
file,
allowed_file_types=[
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
],
):
"""
Handle file uploads, default allowed types are pdf, docx, and .txt
"""
@ -110,7 +118,6 @@ class GradioUI:
return "No file uploaded"
# Check if file is in allowed filetypes
name = os.path.basename(file.name)
try:
mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e:
@ -121,7 +128,9 @@ class GradioUI:
# Sanitize file name
original_name = os.path.basename(file.name)
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
sanitized_name = re.sub(
r"[^\w\-.]", "_", original_name
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
type_to_ext = {}
for ext, t in mimetypes.types_map.items():
@ -134,7 +143,9 @@ class GradioUI:
sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
file_path = os.path.join(
self.file_upload_folder, os.path.basename(sanitized_name)
)
shutil.copy(file.name, file_path)
return f"File uploaded successfully to {self.file_upload_folder}"
@ -155,9 +166,7 @@ class GradioUI:
upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox(label="Upload Status", interactive=False)
upload_file.change(
self.upload_file, [upload_file], [upload_status]
)
upload_file.change(self.upload_file, [upload_file], [upload_status])
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(
lambda s: (s, ""), [text_input], [stored_message, text_input]

File diff suppressed because it is too large Load Diff

View File

@ -313,9 +313,11 @@ class AgentTests(unittest.TestCase):
assert isinstance(output, float)
assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_call == ToolCall(
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
)
assert agent.logs[3].tool_calls == [
ToolCall(
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
)
]
def test_additional_args_added_to_task(self):
agent = CodeAgent(tools=[], model=fake_code_model)