Clean local python interpreter: propagate imports (#175)
This commit is contained in:
parent
a5a3448551
commit
c611dfc7e5
File diff suppressed because one or more lines are too long
|
@ -884,7 +884,6 @@ class CodeAgent(MultiStepAgent):
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
allow_all_imports: bool = False,
|
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
use_e2b_executor: bool = False,
|
use_e2b_executor: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -899,18 +898,29 @@ class CodeAgent(MultiStepAgent):
|
||||||
planning_interval=planning_interval,
|
planning_interval=planning_interval,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ( allow_all_imports and
|
|
||||||
( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
|
|
||||||
raise Exception(
|
|
||||||
f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
|
|
||||||
)
|
|
||||||
|
|
||||||
if allow_all_imports: additional_authorized_imports=['*']
|
|
||||||
|
|
||||||
self.additional_authorized_imports = (
|
self.additional_authorized_imports = (
|
||||||
additional_authorized_imports if additional_authorized_imports else []
|
additional_authorized_imports if additional_authorized_imports else []
|
||||||
)
|
)
|
||||||
|
self.authorized_imports = list(
|
||||||
|
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||||
|
)
|
||||||
|
if "{{authorized_imports}}" not in self.system_prompt:
|
||||||
|
raise AgentError(
|
||||||
|
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
||||||
|
)
|
||||||
|
self.system_prompt = self.system_prompt.replace(
|
||||||
|
"{{authorized_imports}}",
|
||||||
|
"You can import from any package you want."
|
||||||
|
if "*" in self.authorized_imports
|
||||||
|
else str(self.authorized_imports),
|
||||||
|
)
|
||||||
|
|
||||||
|
if "*" in self.additional_authorized_imports:
|
||||||
|
self.logger.log(
|
||||||
|
"Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
if use_e2b_executor and len(self.managed_agents) > 0:
|
if use_e2b_executor and len(self.managed_agents) > 0:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
|
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
|
||||||
|
@ -919,25 +929,15 @@ class CodeAgent(MultiStepAgent):
|
||||||
all_tools = {**self.tools, **self.managed_agents}
|
all_tools = {**self.tools, **self.managed_agents}
|
||||||
if use_e2b_executor:
|
if use_e2b_executor:
|
||||||
self.python_executor = E2BExecutor(
|
self.python_executor = E2BExecutor(
|
||||||
self.additional_authorized_imports, list(all_tools.values())
|
self.additional_authorized_imports,
|
||||||
|
list(all_tools.values()),
|
||||||
|
self.logger,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.python_executor = LocalPythonInterpreter(
|
self.python_executor = LocalPythonInterpreter(
|
||||||
self.additional_authorized_imports, all_tools
|
self.additional_authorized_imports,
|
||||||
|
all_tools,
|
||||||
)
|
)
|
||||||
if allow_all_imports:
|
|
||||||
self.authorized_imports = 'all imports without restriction'
|
|
||||||
else:
|
|
||||||
self.authorized_imports = list(
|
|
||||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
|
||||||
)
|
|
||||||
if "{{authorized_imports}}" not in self.system_prompt:
|
|
||||||
raise AgentError(
|
|
||||||
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
|
||||||
)
|
|
||||||
self.system_prompt = self.system_prompt.replace(
|
|
||||||
"{{authorized_imports}}", str(self.authorized_imports)
|
|
||||||
)
|
|
||||||
|
|
||||||
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -26,13 +26,13 @@ from PIL import Image
|
||||||
|
|
||||||
from .tool_validation import validate_tool_attributes
|
from .tool_validation import validate_tool_attributes
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
|
from .utils import BASE_BUILTIN_MODULES, instance_to_source
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class E2BExecutor:
|
class E2BExecutor:
|
||||||
def __init__(self, additional_imports: List[str], tools: List[Tool]):
|
def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
|
||||||
self.custom_tools = {}
|
self.custom_tools = {}
|
||||||
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
|
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
|
||||||
# TODO: validate installing agents package or not
|
# TODO: validate installing agents package or not
|
||||||
|
@ -42,6 +42,7 @@ class E2BExecutor:
|
||||||
# timeout=300
|
# timeout=300
|
||||||
# )
|
# )
|
||||||
# print("Installation of agents package finished.")
|
# print("Installation of agents package finished.")
|
||||||
|
self.logger = logger
|
||||||
additional_imports = additional_imports + ["pickle5"]
|
additional_imports = additional_imports + ["pickle5"]
|
||||||
if len(additional_imports) > 0:
|
if len(additional_imports) > 0:
|
||||||
execution = self.sbx.commands.run(
|
execution = self.sbx.commands.run(
|
||||||
|
@ -50,7 +51,7 @@ class E2BExecutor:
|
||||||
if execution.error:
|
if execution.error:
|
||||||
raise Exception(f"Error installing dependencies: {execution.error}")
|
raise Exception(f"Error installing dependencies: {execution.error}")
|
||||||
else:
|
else:
|
||||||
console.print(f"Installation of {additional_imports} succeeded!")
|
logger.log(f"Installation of {additional_imports} succeeded!", 0)
|
||||||
|
|
||||||
tool_codes = []
|
tool_codes = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
|
@ -74,7 +75,7 @@ class E2BExecutor:
|
||||||
tool_definition_code += "\n\n".join(tool_codes)
|
tool_definition_code += "\n\n".join(tool_codes)
|
||||||
|
|
||||||
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
|
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
|
||||||
console.print(tool_definition_execution.logs)
|
self.logger.log(tool_definition_execution.logs)
|
||||||
|
|
||||||
def run_code_raise_errors(self, code: str):
|
def run_code_raise_errors(self, code: str):
|
||||||
execution = self.sbx.run_code(
|
execution = self.sbx.run_code(
|
||||||
|
@ -109,7 +110,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
|
||||||
"""
|
"""
|
||||||
execution = self.run_code_raise_errors(remote_unloading_code)
|
execution = self.run_code_raise_errors(remote_unloading_code)
|
||||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||||
console.print(execution_logs)
|
self.logger.log(execution_logs, 1)
|
||||||
|
|
||||||
execution = self.run_code_raise_errors(code_action)
|
execution = self.run_code_raise_errors(code_action)
|
||||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||||
|
|
|
@ -85,7 +85,7 @@ def stream_to_gradio(
|
||||||
class GradioUI:
|
class GradioUI:
|
||||||
"""A one-line interface to launch your agent in Gradio"""
|
"""A one-line interface to launch your agent in Gradio"""
|
||||||
|
|
||||||
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None):
|
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.file_upload_folder = file_upload_folder
|
self.file_upload_folder = file_upload_folder
|
||||||
if self.file_upload_folder is not None:
|
if self.file_upload_folder is not None:
|
||||||
|
@ -100,7 +100,15 @@ class GradioUI:
|
||||||
yield messages
|
yield messages
|
||||||
yield messages
|
yield messages
|
||||||
|
|
||||||
def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
|
def upload_file(
|
||||||
|
self,
|
||||||
|
file,
|
||||||
|
allowed_file_types=[
|
||||||
|
"application/pdf",
|
||||||
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
|
"text/plain",
|
||||||
|
],
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Handle file uploads, default allowed types are pdf, docx, and .txt
|
Handle file uploads, default allowed types are pdf, docx, and .txt
|
||||||
"""
|
"""
|
||||||
|
@ -110,7 +118,6 @@ class GradioUI:
|
||||||
return "No file uploaded"
|
return "No file uploaded"
|
||||||
|
|
||||||
# Check if file is in allowed filetypes
|
# Check if file is in allowed filetypes
|
||||||
name = os.path.basename(file.name)
|
|
||||||
try:
|
try:
|
||||||
mime_type, _ = mimetypes.guess_type(file.name)
|
mime_type, _ = mimetypes.guess_type(file.name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -121,7 +128,9 @@ class GradioUI:
|
||||||
|
|
||||||
# Sanitize file name
|
# Sanitize file name
|
||||||
original_name = os.path.basename(file.name)
|
original_name = os.path.basename(file.name)
|
||||||
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
|
sanitized_name = re.sub(
|
||||||
|
r"[^\w\-.]", "_", original_name
|
||||||
|
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
|
||||||
|
|
||||||
type_to_ext = {}
|
type_to_ext = {}
|
||||||
for ext, t in mimetypes.types_map.items():
|
for ext, t in mimetypes.types_map.items():
|
||||||
|
@ -134,7 +143,9 @@ class GradioUI:
|
||||||
sanitized_name = "".join(sanitized_name)
|
sanitized_name = "".join(sanitized_name)
|
||||||
|
|
||||||
# Save the uploaded file to the specified folder
|
# Save the uploaded file to the specified folder
|
||||||
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
|
file_path = os.path.join(
|
||||||
|
self.file_upload_folder, os.path.basename(sanitized_name)
|
||||||
|
)
|
||||||
shutil.copy(file.name, file_path)
|
shutil.copy(file.name, file_path)
|
||||||
|
|
||||||
return f"File uploaded successfully to {self.file_upload_folder}"
|
return f"File uploaded successfully to {self.file_upload_folder}"
|
||||||
|
@ -155,9 +166,7 @@ class GradioUI:
|
||||||
upload_file = gr.File(label="Upload a file")
|
upload_file = gr.File(label="Upload a file")
|
||||||
upload_status = gr.Textbox(label="Upload Status", interactive=False)
|
upload_status = gr.Textbox(label="Upload Status", interactive=False)
|
||||||
|
|
||||||
upload_file.change(
|
upload_file.change(self.upload_file, [upload_file], [upload_status])
|
||||||
self.upload_file, [upload_file], [upload_status]
|
|
||||||
)
|
|
||||||
text_input = gr.Textbox(lines=1, label="Chat Message")
|
text_input = gr.Textbox(lines=1, label="Chat Message")
|
||||||
text_input.submit(
|
text_input.submit(
|
||||||
lambda s: (s, ""), [text_input], [stored_message, text_input]
|
lambda s: (s, ""), [text_input], [stored_message, text_input]
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -313,9 +313,11 @@ class AgentTests(unittest.TestCase):
|
||||||
assert isinstance(output, float)
|
assert isinstance(output, float)
|
||||||
assert output == 7.2904
|
assert output == 7.2904
|
||||||
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
||||||
assert agent.logs[3].tool_call == ToolCall(
|
assert agent.logs[3].tool_calls == [
|
||||||
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
|
ToolCall(
|
||||||
)
|
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
def test_additional_args_added_to_task(self):
|
def test_additional_args_added_to_task(self):
|
||||||
agent = CodeAgent(tools=[], model=fake_code_model)
|
agent = CodeAgent(tools=[], model=fake_code_model)
|
||||||
|
|
Loading…
Reference in New Issue