Clean local python interpreter: propagate imports (#175)
This commit is contained in:
parent
a5a3448551
commit
c611dfc7e5
File diff suppressed because one or more lines are too long
|
@ -884,7 +884,6 @@ class CodeAgent(MultiStepAgent):
|
|||
system_prompt: Optional[str] = None,
|
||||
grammar: Optional[Dict[str, str]] = None,
|
||||
additional_authorized_imports: Optional[List[str]] = None,
|
||||
allow_all_imports: bool = False,
|
||||
planning_interval: Optional[int] = None,
|
||||
use_e2b_executor: bool = False,
|
||||
**kwargs,
|
||||
|
@ -899,18 +898,29 @@ class CodeAgent(MultiStepAgent):
|
|||
planning_interval=planning_interval,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if ( allow_all_imports and
|
||||
( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
|
||||
raise Exception(
|
||||
f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
|
||||
)
|
||||
|
||||
if allow_all_imports: additional_authorized_imports=['*']
|
||||
|
||||
self.additional_authorized_imports = (
|
||||
additional_authorized_imports if additional_authorized_imports else []
|
||||
)
|
||||
self.authorized_imports = list(
|
||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||
)
|
||||
if "{{authorized_imports}}" not in self.system_prompt:
|
||||
raise AgentError(
|
||||
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
||||
)
|
||||
self.system_prompt = self.system_prompt.replace(
|
||||
"{{authorized_imports}}",
|
||||
"You can import from any package you want."
|
||||
if "*" in self.authorized_imports
|
||||
else str(self.authorized_imports),
|
||||
)
|
||||
|
||||
if "*" in self.additional_authorized_imports:
|
||||
self.logger.log(
|
||||
"Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
|
||||
0,
|
||||
)
|
||||
|
||||
if use_e2b_executor and len(self.managed_agents) > 0:
|
||||
raise Exception(
|
||||
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
|
||||
|
@ -919,25 +929,15 @@ class CodeAgent(MultiStepAgent):
|
|||
all_tools = {**self.tools, **self.managed_agents}
|
||||
if use_e2b_executor:
|
||||
self.python_executor = E2BExecutor(
|
||||
self.additional_authorized_imports, list(all_tools.values())
|
||||
self.additional_authorized_imports,
|
||||
list(all_tools.values()),
|
||||
self.logger,
|
||||
)
|
||||
else:
|
||||
self.python_executor = LocalPythonInterpreter(
|
||||
self.additional_authorized_imports, all_tools
|
||||
self.additional_authorized_imports,
|
||||
all_tools,
|
||||
)
|
||||
if allow_all_imports:
|
||||
self.authorized_imports = 'all imports without restriction'
|
||||
else:
|
||||
self.authorized_imports = list(
|
||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||
)
|
||||
if "{{authorized_imports}}" not in self.system_prompt:
|
||||
raise AgentError(
|
||||
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
||||
)
|
||||
self.system_prompt = self.system_prompt.replace(
|
||||
"{{authorized_imports}}", str(self.authorized_imports)
|
||||
)
|
||||
|
||||
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||
"""
|
||||
|
|
|
@ -26,13 +26,13 @@ from PIL import Image
|
|||
|
||||
from .tool_validation import validate_tool_attributes
|
||||
from .tools import Tool
|
||||
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
|
||||
from .utils import BASE_BUILTIN_MODULES, instance_to_source
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class E2BExecutor:
|
||||
def __init__(self, additional_imports: List[str], tools: List[Tool]):
|
||||
def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
|
||||
self.custom_tools = {}
|
||||
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
|
||||
# TODO: validate installing agents package or not
|
||||
|
@ -42,6 +42,7 @@ class E2BExecutor:
|
|||
# timeout=300
|
||||
# )
|
||||
# print("Installation of agents package finished.")
|
||||
self.logger = logger
|
||||
additional_imports = additional_imports + ["pickle5"]
|
||||
if len(additional_imports) > 0:
|
||||
execution = self.sbx.commands.run(
|
||||
|
@ -50,7 +51,7 @@ class E2BExecutor:
|
|||
if execution.error:
|
||||
raise Exception(f"Error installing dependencies: {execution.error}")
|
||||
else:
|
||||
console.print(f"Installation of {additional_imports} succeeded!")
|
||||
logger.log(f"Installation of {additional_imports} succeeded!", 0)
|
||||
|
||||
tool_codes = []
|
||||
for tool in tools:
|
||||
|
@ -74,7 +75,7 @@ class E2BExecutor:
|
|||
tool_definition_code += "\n\n".join(tool_codes)
|
||||
|
||||
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
|
||||
console.print(tool_definition_execution.logs)
|
||||
self.logger.log(tool_definition_execution.logs)
|
||||
|
||||
def run_code_raise_errors(self, code: str):
|
||||
execution = self.sbx.run_code(
|
||||
|
@ -109,7 +110,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
|
|||
"""
|
||||
execution = self.run_code_raise_errors(remote_unloading_code)
|
||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||
console.print(execution_logs)
|
||||
self.logger.log(execution_logs, 1)
|
||||
|
||||
execution = self.run_code_raise_errors(code_action)
|
||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||
|
|
|
@ -85,7 +85,7 @@ def stream_to_gradio(
|
|||
class GradioUI:
|
||||
"""A one-line interface to launch your agent in Gradio"""
|
||||
|
||||
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None):
|
||||
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
|
||||
self.agent = agent
|
||||
self.file_upload_folder = file_upload_folder
|
||||
if self.file_upload_folder is not None:
|
||||
|
@ -100,7 +100,15 @@ class GradioUI:
|
|||
yield messages
|
||||
yield messages
|
||||
|
||||
def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
|
||||
def upload_file(
|
||||
self,
|
||||
file,
|
||||
allowed_file_types=[
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"text/plain",
|
||||
],
|
||||
):
|
||||
"""
|
||||
Handle file uploads, default allowed types are pdf, docx, and .txt
|
||||
"""
|
||||
|
@ -110,18 +118,19 @@ class GradioUI:
|
|||
return "No file uploaded"
|
||||
|
||||
# Check if file is in allowed filetypes
|
||||
name = os.path.basename(file.name)
|
||||
try:
|
||||
mime_type, _ = mimetypes.guess_type(file.name)
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
if mime_type not in allowed_file_types:
|
||||
return "File type disallowed"
|
||||
|
||||
|
||||
# Sanitize file name
|
||||
original_name = os.path.basename(file.name)
|
||||
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
|
||||
sanitized_name = re.sub(
|
||||
r"[^\w\-.]", "_", original_name
|
||||
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
|
||||
|
||||
type_to_ext = {}
|
||||
for ext, t in mimetypes.types_map.items():
|
||||
|
@ -134,7 +143,9 @@ class GradioUI:
|
|||
sanitized_name = "".join(sanitized_name)
|
||||
|
||||
# Save the uploaded file to the specified folder
|
||||
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
|
||||
file_path = os.path.join(
|
||||
self.file_upload_folder, os.path.basename(sanitized_name)
|
||||
)
|
||||
shutil.copy(file.name, file_path)
|
||||
|
||||
return f"File uploaded successfully to {self.file_upload_folder}"
|
||||
|
@ -155,9 +166,7 @@ class GradioUI:
|
|||
upload_file = gr.File(label="Upload a file")
|
||||
upload_status = gr.Textbox(label="Upload Status", interactive=False)
|
||||
|
||||
upload_file.change(
|
||||
self.upload_file, [upload_file], [upload_status]
|
||||
)
|
||||
upload_file.change(self.upload_file, [upload_file], [upload_status])
|
||||
text_input = gr.Textbox(lines=1, label="Chat Message")
|
||||
text_input.submit(
|
||||
lambda s: (s, ""), [text_input], [stored_message, text_input]
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -361,7 +361,7 @@ class TransformersModel(Model):
|
|||
)
|
||||
prompt_tensor = prompt_tensor.to(self.model.device)
|
||||
count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
|
||||
|
||||
|
||||
out = self.model.generate(
|
||||
**prompt_tensor,
|
||||
max_new_tokens=max_tokens,
|
||||
|
|
|
@ -313,9 +313,11 @@ class AgentTests(unittest.TestCase):
|
|||
assert isinstance(output, float)
|
||||
assert output == 7.2904
|
||||
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
||||
assert agent.logs[3].tool_call == ToolCall(
|
||||
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
|
||||
)
|
||||
assert agent.logs[3].tool_calls == [
|
||||
ToolCall(
|
||||
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
|
||||
)
|
||||
]
|
||||
|
||||
def test_additional_args_added_to_task(self):
|
||||
agent = CodeAgent(tools=[], model=fake_code_model)
|
||||
|
|
Loading…
Reference in New Issue