Clean local python interpreter: propagate imports (#175)

This commit is contained in:
Aymeric Roucher 2025-01-13 17:23:03 +01:00 committed by GitHub
parent a5a3448551
commit c611dfc7e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 763 additions and 245 deletions

File diff suppressed because one or more lines are too long

View File

@ -884,7 +884,6 @@ class CodeAgent(MultiStepAgent):
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
allow_all_imports: bool = False,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
use_e2b_executor: bool = False, use_e2b_executor: bool = False,
**kwargs, **kwargs,
@ -899,35 +898,9 @@ class CodeAgent(MultiStepAgent):
planning_interval=planning_interval, planning_interval=planning_interval,
**kwargs, **kwargs,
) )
if ( allow_all_imports and
( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
raise Exception(
f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
)
if allow_all_imports: additional_authorized_imports=['*']
self.additional_authorized_imports = ( self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else [] additional_authorized_imports if additional_authorized_imports else []
) )
if use_e2b_executor and len(self.managed_agents) > 0:
raise Exception(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
)
all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
self.additional_authorized_imports, list(all_tools.values())
)
else:
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports, all_tools
)
if allow_all_imports:
self.authorized_imports = 'all imports without restriction'
else:
self.authorized_imports = list( self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
) )
@ -936,7 +909,34 @@ class CodeAgent(MultiStepAgent):
"Tag '{{authorized_imports}}' should be provided in the prompt." "Tag '{{authorized_imports}}' should be provided in the prompt."
) )
self.system_prompt = self.system_prompt.replace( self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}", str(self.authorized_imports) "{{authorized_imports}}",
"You can import from any package you want."
if "*" in self.authorized_imports
else str(self.authorized_imports),
)
if "*" in self.additional_authorized_imports:
self.logger.log(
"Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
0,
)
if use_e2b_executor and len(self.managed_agents) > 0:
raise Exception(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
)
all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
self.additional_authorized_imports,
list(all_tools.values()),
self.logger,
)
else:
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports,
all_tools,
) )
def step(self, log_entry: ActionStep) -> Union[None, Any]: def step(self, log_entry: ActionStep) -> Union[None, Any]:

View File

@ -26,13 +26,13 @@ from PIL import Image
from .tool_validation import validate_tool_attributes from .tool_validation import validate_tool_attributes
from .tools import Tool from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source from .utils import BASE_BUILTIN_MODULES, instance_to_source
load_dotenv() load_dotenv()
class E2BExecutor: class E2BExecutor:
def __init__(self, additional_imports: List[str], tools: List[Tool]): def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
self.custom_tools = {} self.custom_tools = {}
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j") self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
# TODO: validate installing agents package or not # TODO: validate installing agents package or not
@ -42,6 +42,7 @@ class E2BExecutor:
# timeout=300 # timeout=300
# ) # )
# print("Installation of agents package finished.") # print("Installation of agents package finished.")
self.logger = logger
additional_imports = additional_imports + ["pickle5"] additional_imports = additional_imports + ["pickle5"]
if len(additional_imports) > 0: if len(additional_imports) > 0:
execution = self.sbx.commands.run( execution = self.sbx.commands.run(
@ -50,7 +51,7 @@ class E2BExecutor:
if execution.error: if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}") raise Exception(f"Error installing dependencies: {execution.error}")
else: else:
console.print(f"Installation of {additional_imports} succeeded!") logger.log(f"Installation of {additional_imports} succeeded!", 0)
tool_codes = [] tool_codes = []
for tool in tools: for tool in tools:
@ -74,7 +75,7 @@ class E2BExecutor:
tool_definition_code += "\n\n".join(tool_codes) tool_definition_code += "\n\n".join(tool_codes)
tool_definition_execution = self.run_code_raise_errors(tool_definition_code) tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
console.print(tool_definition_execution.logs) self.logger.log(tool_definition_execution.logs)
def run_code_raise_errors(self, code: str): def run_code_raise_errors(self, code: str):
execution = self.sbx.run_code( execution = self.sbx.run_code(
@ -109,7 +110,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
""" """
execution = self.run_code_raise_errors(remote_unloading_code) execution = self.run_code_raise_errors(remote_unloading_code)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
console.print(execution_logs) self.logger.log(execution_logs, 1)
execution = self.run_code_raise_errors(code_action) execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) execution_logs = "\n".join([str(log) for log in execution.logs.stdout])

View File

@ -100,7 +100,15 @@ class GradioUI:
yield messages yield messages
yield messages yield messages
def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]): def upload_file(
self,
file,
allowed_file_types=[
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
],
):
""" """
Handle file uploads, default allowed types are pdf, docx, and .txt Handle file uploads, default allowed types are pdf, docx, and .txt
""" """
@ -110,7 +118,6 @@ class GradioUI:
return "No file uploaded" return "No file uploaded"
# Check if file is in allowed filetypes # Check if file is in allowed filetypes
name = os.path.basename(file.name)
try: try:
mime_type, _ = mimetypes.guess_type(file.name) mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e: except Exception as e:
@ -121,7 +128,9 @@ class GradioUI:
# Sanitize file name # Sanitize file name
original_name = os.path.basename(file.name) original_name = os.path.basename(file.name)
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores sanitized_name = re.sub(
r"[^\w\-.]", "_", original_name
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
type_to_ext = {} type_to_ext = {}
for ext, t in mimetypes.types_map.items(): for ext, t in mimetypes.types_map.items():
@ -134,7 +143,9 @@ class GradioUI:
sanitized_name = "".join(sanitized_name) sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder # Save the uploaded file to the specified folder
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name)) file_path = os.path.join(
self.file_upload_folder, os.path.basename(sanitized_name)
)
shutil.copy(file.name, file_path) shutil.copy(file.name, file_path)
return f"File uploaded successfully to {self.file_upload_folder}" return f"File uploaded successfully to {self.file_upload_folder}"
@ -155,9 +166,7 @@ class GradioUI:
upload_file = gr.File(label="Upload a file") upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox(label="Upload Status", interactive=False) upload_status = gr.Textbox(label="Upload Status", interactive=False)
upload_file.change( upload_file.change(self.upload_file, [upload_file], [upload_status])
self.upload_file, [upload_file], [upload_status]
)
text_input = gr.Textbox(lines=1, label="Chat Message") text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit( text_input.submit(
lambda s: (s, ""), [text_input], [stored_message, text_input] lambda s: (s, ""), [text_input], [stored_message, text_input]

File diff suppressed because it is too large Load Diff

View File

@ -313,9 +313,11 @@ class AgentTests(unittest.TestCase):
assert isinstance(output, float) assert isinstance(output, float)
assert output == 7.2904 assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_call == ToolCall( assert agent.logs[3].tool_calls == [
ToolCall(
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3" name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
) )
]
def test_additional_args_added_to_task(self): def test_additional_args_added_to_task(self):
agent = CodeAgent(tools=[], model=fake_code_model) agent = CodeAgent(tools=[], model=fake_code_model)