Fix additional args sent to e2b server

This commit is contained in:
Aymeric 2024-12-26 17:59:15 +01:00
parent 1abaf69b67
commit f8b9cb34f9
8 changed files with 73 additions and 58 deletions

View File

@ -62,7 +62,7 @@ and [reaches higher performance on difficult benchmarks](https://huggingface.co/
Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime:
- a secure python interpreter to run code more safely in your environment - a secure python interpreter to run code more safely in your environment
- a sandboxed environment. - a sandboxed environment using [E2B](https://e2b.dev/).
## How lightweight is it? ## How lightweight is it?

View File

@ -33,7 +33,7 @@ from huggingface_hub import login, InferenceClient
login("<YOUR_HUGGINGFACEHUB_API_TOKEN>") login("<YOUR_HUGGINGFACEHUB_API_TOKEN>")
model_id = "Qwen/Qwen2.5-72B-Instruct" model_id = "meta-llama/Llama-3.3-70B-Instruct"
client = InferenceClient(model=model_id) client = InferenceClient(model=model_id)
@ -71,12 +71,19 @@ agent.run(
Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`, they will be baked into the prompt as text. Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`, they will be baked into the prompt as text.
You can use this to indicate the path to local or remote files for the model to use: You can use this to pass files in various formats:
```py ```py
agent = CodeAgent(tools=[], model=model, add_base_tools=True) from smolagents import CodeAgent, HfApiModel
agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3") model_id = "meta-llama/Llama-3.3-70B-Instruct"
agent = CodeAgent(tools=[], model=HfApiModel(model_id=model_id), add_base_tools=True)
agent.run(
"Why does Mike not know many people in New York?",
additional_args={"mp3_sound_file_url":'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3'}
)
``` ```
It's important to explain as clearly as possible the task you want to perform. It's important to explain as clearly as possible the task you want to perform.

View File

@ -27,12 +27,11 @@ LAUNCH_GRADIO = False
get_cat_image = GetCatImageTool() get_cat_image = GetCatImageTool()
agent = CodeAgent( agent = CodeAgent(
tools = [get_cat_image, VisitWebpageTool()], tools = [get_cat_image, VisitWebpageTool()],
model=HfApiModel(), model=HfApiModel(),
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search", additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
use_e2b_executor=False use_e2b_executor=True
) )
if LAUNCH_GRADIO: if LAUNCH_GRADIO:
@ -41,6 +40,5 @@ if LAUNCH_GRADIO:
GradioUI(agent).launch() GradioUI(agent).launch()
else: else:
agent.run( agent.run(
"Return me an image of Lincoln's preferred pet", "Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()}
additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/" ) # Asking to directly return the image from state tests that additional_args are properly sent to server.
)

View File

@ -188,6 +188,7 @@ class MultiStepAgent:
self.tool_parser = tool_parser self.tool_parser = tool_parser
self.grammar = grammar self.grammar = grammar
self.planning_interval = planning_interval self.planning_interval = planning_interval
self.state = {}
self.managed_agents = {} self.managed_agents = {}
if managed_agents is not None: if managed_agents is not None:
@ -370,8 +371,7 @@ class MultiStepAgent:
return self.model(self.input_messages) return self.model(self.input_messages)
except Exception as e: except Exception as e:
error_msg = f"Error in generating final LLM output:\n{e}" error_msg = f"Error in generating final LLM output:\n{e}"
console.print(f"[bold red]{error_msg}[/bold red]") raise AgentGenerationError(error_msg)
return error_msg
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
""" """
@ -385,7 +385,6 @@ class MultiStepAgent:
available_tools = {**self.toolbox.tools, **self.managed_agents} available_tools = {**self.toolbox.tools, **self.managed_agents}
if tool_name not in available_tools: if tool_name not in available_tools:
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
try: try:
@ -398,7 +397,6 @@ class MultiStepAgent:
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
else: else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
return observation return observation
except Exception as e: except Exception as e:
@ -410,14 +408,12 @@ class MultiStepAgent:
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}" f"As a reminder, this tool's description is the following:\n{tool_description}"
) )
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents: elif tool_name in self.managed_agents:
error_msg = ( error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
) )
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
def step(self, log_entry: ActionStep) -> Union[None, Any]: def step(self, log_entry: ActionStep) -> Union[None, Any]:
@ -430,7 +426,7 @@ class MultiStepAgent:
stream: bool = False, stream: bool = False,
reset: bool = True, reset: bool = True,
single_step: bool = False, single_step: bool = False,
**kwargs, additional_args: Optional[Dict] = None,
): ):
""" """
Runs the agent for the given task. Runs the agent for the given task.
@ -440,6 +436,7 @@ class MultiStepAgent:
stream (`bool`): Whether to run in a streaming way. stream (`bool`): Whether to run in a streaming way.
reset (`bool`): Whether to reset the conversation or keep it going from previous run. reset (`bool`): Whether to reset the conversation or keep it going from previous run.
single_step (`bool`): Should the agent run in one shot or multi-step fashion? single_step (`bool`): Should the agent run in one shot or multi-step fashion?
additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
Example: Example:
```py ```py
@ -449,11 +446,11 @@ class MultiStepAgent:
``` ```
""" """
self.task = task self.task = task
if len(kwargs) > 0: if additional_args is not None:
self.task += ( self.state.update(additional_args)
f"\nYou have been provided with these initial arguments: {str(kwargs)}." self.task += f"""
) You have been provided with these additional arguments, that you can access as variables in your python code using the keys:
self.state = kwargs.copy() {str(additional_args)}."""
self.initialize_system_prompt() self.initialize_system_prompt()
system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
@ -468,14 +465,9 @@ class MultiStepAgent:
else: else:
self.logs.append(system_prompt_step) self.logs.append(system_prompt_step)
# console.print(
# Group(
# Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task)
# )
# )
console.print( console.print(
Panel( Panel(
f"\n[bold]{task.strip()}\n", f"\n[bold]{self.task.strip()}\n",
title="[bold]New run", title="[bold]New run",
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}", subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}",
border_style=YELLOW_HEX, border_style=YELLOW_HEX,
@ -891,17 +883,6 @@ class CodeAgent(MultiStepAgent):
console.print_exception() console.print_exception()
raise AgentGenerationError(f"Error in generating model output:\n{e}") raise AgentGenerationError(f"Error in generating model output:\n{e}")
# from rich.live import Live
# from rich.markdown import Markdown
# import time
# with Live(console=console, vertical_overflow="visible") as live:
# message = ""
# for i in range(100):
# time.sleep(0.02)
# message += str(i)
# live.update(Markdown(message))
if self.verbose: if self.verbose:
console.print( console.print(
Group( Group(
@ -946,6 +927,7 @@ class CodeAgent(MultiStepAgent):
try: try:
output, execution_logs = self.python_executor( output, execution_logs = self.python_executor(
code_action, code_action,
self.state,
) )
execution_outputs_console = [] execution_outputs_console = []
if len(execution_logs) > 0: if len(execution_logs) > 0:

View File

@ -295,7 +295,7 @@ class SpeechToTextTool(PipelineTool):
pre_processor_class = WhisperProcessor pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration model_class = WhisperForConditionalGeneration
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}}
output_type = "string" output_type = "string"
def encode(self, audio): def encode(self, audio):

View File

@ -17,13 +17,14 @@
from dotenv import load_dotenv from dotenv import load_dotenv
import textwrap import textwrap
import base64 import base64
import pickle
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from e2b_code_interpreter import Sandbox from e2b_code_interpreter import Sandbox
from typing import List, Tuple, Any from typing import List, Tuple, Any
from .tool_validation import validate_tool_attributes from .tool_validation import validate_tool_attributes
from .utils import instance_to_source, BASE_BUILTIN_MODULES from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
from .tools import Tool from .tools import Tool
load_dotenv() load_dotenv()
@ -40,6 +41,7 @@ class E2BExecutor:
# timeout=300 # timeout=300
# ) # )
# print("Installation of agents package finished.") # print("Installation of agents package finished.")
additional_imports = additional_imports + ["pickle5"]
if len(additional_imports) > 0: if len(additional_imports) > 0:
execution = self.sbx.commands.run( execution = self.sbx.commands.run(
"pip install " + " ".join(additional_imports) "pip install " + " ".join(additional_imports)
@ -47,7 +49,7 @@ class E2BExecutor:
if execution.error: if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}") raise Exception(f"Error installing dependencies: {execution.error}")
else: else:
print("Installation succeeded!") console.print(f"Installation of {additional_imports} succeeded!")
tool_codes = [] tool_codes = []
for tool in tools: for tool in tools:
@ -71,21 +73,44 @@ class E2BExecutor:
tool_definition_code += "\n\n".join(tool_codes) tool_definition_code += "\n\n".join(tool_codes)
tool_definition_execution = self.run_code_raise_errors(tool_definition_code) tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
print(tool_definition_execution.logs) console.print(tool_definition_execution.logs)
def run_code_raise_errors(self, code: str): def run_code_raise_errors(self, code: str):
execution = self.sbx.run_code( execution = self.sbx.run_code(
code, code,
) )
if execution.error: if execution.error:
logs = "Executing code yielded an error:" execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
logs = execution_logs
logs += "Executing code yielded an error:"
logs += execution.error.name logs += execution.error.name
logs += execution.error.value logs += execution.error.value
logs += execution.error.traceback logs += execution.error.traceback
raise ValueError(logs) raise ValueError(logs)
return execution return execution
def __call__(self, code_action: str) -> Tuple[Any, Any]: def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
if len(additional_args) > 0:
# Pickle additional_args to server
import tempfile
with tempfile.NamedTemporaryFile() as f:
pickle.dump(additional_args, f)
f.flush()
with open(f.name, "rb") as file:
self.sbx.files.write("/home/state.pkl", file)
remote_unloading_code = """import pickle
import os
print("File path", os.path.getsize('/home/state.pkl'))
with open('/home/state.pkl', 'rb') as f:
pickle_dict = pickle.load(f)
locals().update({key: value for key, value in pickle_dict.items()})
"""
execution = self.run_code_raise_errors(remote_unloading_code)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
console.print(execution_logs)
execution = self.run_code_raise_errors(code_action) execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results: if not execution.results:

View File

@ -1058,7 +1058,8 @@ class LocalPythonInterpreter:
} }
# TODO: assert self.authorized imports are all installed locally # TODO: assert self.authorized imports are all installed locally
def __call__(self, code_action: str) -> Tuple[Any, str]: def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]:
self.state.update(additional_variables)
output = evaluate_python_code( output = evaluate_python_code(
code_action, code_action,
static_tools=self.static_tools, static_tools=self.static_tools,

View File

@ -201,20 +201,21 @@ class Tool:
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
# Validate forward function signature # Validate forward function signature, except for PipelineTool
signature = inspect.signature(self.forward) if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True):
signature = inspect.signature(self.forward)
if not set(signature.parameters.keys()) == set(self.inputs.keys()): if not set(signature.parameters.keys()) == set(self.inputs.keys()):
raise Exception( raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
) )
json_schema = _convert_type_hints_to_json_schema(self.forward) json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items(): for key, value in self.inputs.items():
if "nullable" in value: if "nullable" in value:
assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
if key in json_schema and "nullable" in json_schema[key]: if key in json_schema and "nullable" in json_schema[key]:
assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.") return NotImplementedError("Write this method in your subclass of `Tool`.")
@ -1074,6 +1075,7 @@ class PipelineTool(Tool):
name = "pipeline" name = "pipeline"
inputs = {"prompt": str} inputs = {"prompt": str}
output_type = str output_type = str
is_pipeline_tool = True
def __init__( def __init__(
self, self,