Fix additional args sent to e2b server

This commit is contained in:
Aymeric 2024-12-26 17:59:15 +01:00
parent 1abaf69b67
commit f8b9cb34f9
8 changed files with 73 additions and 58 deletions

View File

@ -62,7 +62,7 @@ and [reaches higher performance on difficult benchmarks](https://huggingface.co/
Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime:
- a secure python interpreter to run code more safely in your environment - a secure python interpreter to run code more safely in your environment
- a sandboxed environment. - a sandboxed environment using [E2B](https://e2b.dev/).
## How lightweight is it? ## How lightweight is it?

View File

@ -33,7 +33,7 @@ from huggingface_hub import login, InferenceClient
login("<YOUR_HUGGINGFACEHUB_API_TOKEN>") login("<YOUR_HUGGINGFACEHUB_API_TOKEN>")
model_id = "Qwen/Qwen2.5-72B-Instruct" model_id = "meta-llama/Llama-3.3-70B-Instruct"
client = InferenceClient(model=model_id) client = InferenceClient(model=model_id)
@ -71,12 +71,19 @@ agent.run(
Note that we used an additional `additional_detail` argument: you can additional kwargs to `agent.run()`, they will be baked into the prompt as text. Note that we used an additional `additional_detail` argument: you can additional kwargs to `agent.run()`, they will be baked into the prompt as text.
You can use this to indicate the path to local or remote files for the model to use: You can use this to pass files in various formats:
```py ```py
agent = CodeAgent(tools=[], model=model, add_base_tools=True) from smolagents import CodeAgent, HfApiModel
agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3") model_id = "meta-llama/Llama-3.3-70B-Instruct"
agent = CodeAgent(tools=[], model=HfApiModel(model_id=model_id), add_base_tools=True)
agent.run(
"Why does Mike not know many people in New York?",
additional_args={"mp3_sound_file_url":'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3'}
)
``` ```
It's important to explain as clearly as possible the task you want to perform. It's important to explain as clearly as possible the task you want to perform.

View File

@ -27,12 +27,11 @@ LAUNCH_GRADIO = False
get_cat_image = GetCatImageTool() get_cat_image = GetCatImageTool()
agent = CodeAgent( agent = CodeAgent(
tools = [get_cat_image, VisitWebpageTool()], tools = [get_cat_image, VisitWebpageTool()],
model=HfApiModel(), model=HfApiModel(),
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search", additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
use_e2b_executor=False use_e2b_executor=True
) )
if LAUNCH_GRADIO: if LAUNCH_GRADIO:
@ -41,6 +40,5 @@ if LAUNCH_GRADIO:
GradioUI(agent).launch() GradioUI(agent).launch()
else: else:
agent.run( agent.run(
"Return me an image of Lincoln's preferred pet", "Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()}
additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/" ) # Asking to directly return the image from state tests that additional_args are properly sent to server.
)

View File

@ -188,6 +188,7 @@ class MultiStepAgent:
self.tool_parser = tool_parser self.tool_parser = tool_parser
self.grammar = grammar self.grammar = grammar
self.planning_interval = planning_interval self.planning_interval = planning_interval
self.state = {}
self.managed_agents = {} self.managed_agents = {}
if managed_agents is not None: if managed_agents is not None:
@ -370,8 +371,7 @@ class MultiStepAgent:
return self.model(self.input_messages) return self.model(self.input_messages)
except Exception as e: except Exception as e:
error_msg = f"Error in generating final LLM output:\n{e}" error_msg = f"Error in generating final LLM output:\n{e}"
console.print(f"[bold red]{error_msg}[/bold red]") raise AgentGenerationError(error_msg)
return error_msg
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
""" """
@ -385,7 +385,6 @@ class MultiStepAgent:
available_tools = {**self.toolbox.tools, **self.managed_agents} available_tools = {**self.toolbox.tools, **self.managed_agents}
if tool_name not in available_tools: if tool_name not in available_tools:
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
try: try:
@ -398,7 +397,6 @@ class MultiStepAgent:
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
else: else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
return observation return observation
except Exception as e: except Exception as e:
@ -410,14 +408,12 @@ class MultiStepAgent:
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}" f"As a reminder, this tool's description is the following:\n{tool_description}"
) )
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents: elif tool_name in self.managed_agents:
error_msg = ( error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
) )
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
def step(self, log_entry: ActionStep) -> Union[None, Any]: def step(self, log_entry: ActionStep) -> Union[None, Any]:
@ -430,7 +426,7 @@ class MultiStepAgent:
stream: bool = False, stream: bool = False,
reset: bool = True, reset: bool = True,
single_step: bool = False, single_step: bool = False,
**kwargs, additional_args: Optional[Dict] = None,
): ):
""" """
Runs the agent for the given task. Runs the agent for the given task.
@ -440,6 +436,7 @@ class MultiStepAgent:
stream (`bool`): Wether to run in a streaming way. stream (`bool`): Wether to run in a streaming way.
reset (`bool`): Wether to reset the conversation or keep it going from previous run. reset (`bool`): Wether to reset the conversation or keep it going from previous run.
single_step (`bool`): Should the agent run in one shot or multi-step fashion? single_step (`bool`): Should the agent run in one shot or multi-step fashion?
additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
Example: Example:
```py ```py
@ -449,11 +446,11 @@ class MultiStepAgent:
``` ```
""" """
self.task = task self.task = task
if len(kwargs) > 0: if additional_args is not None:
self.task += ( self.state.update(additional_args)
f"\nYou have been provided with these initial arguments: {str(kwargs)}." self.task += f"""
) You have been provided with these additional arguments, that you can access as variables in your python code using the keys:
self.state = kwargs.copy() {str(additional_args)}."""
self.initialize_system_prompt() self.initialize_system_prompt()
system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
@ -468,14 +465,9 @@ class MultiStepAgent:
else: else:
self.logs.append(system_prompt_step) self.logs.append(system_prompt_step)
# console.print(
# Group(
# Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task)
# )
# )
console.print( console.print(
Panel( Panel(
f"\n[bold]{task.strip()}\n", f"\n[bold]{self.task.strip()}\n",
title="[bold]New run", title="[bold]New run",
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}", subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}",
border_style=YELLOW_HEX, border_style=YELLOW_HEX,
@ -891,17 +883,6 @@ class CodeAgent(MultiStepAgent):
console.print_exception() console.print_exception()
raise AgentGenerationError(f"Error in generating model output:\n{e}") raise AgentGenerationError(f"Error in generating model output:\n{e}")
# from rich.live import Live
# from rich.markdown import Markdown
# import time
# with Live(console=console, vertical_overflow="visible") as live:
# message = ""
# for i in range(100):
# time.sleep(0.02)
# message += str(i)
# live.update(Markdown(message))
if self.verbose: if self.verbose:
console.print( console.print(
Group( Group(
@ -946,6 +927,7 @@ class CodeAgent(MultiStepAgent):
try: try:
output, execution_logs = self.python_executor( output, execution_logs = self.python_executor(
code_action, code_action,
self.state,
) )
execution_outputs_console = [] execution_outputs_console = []
if len(execution_logs) > 0: if len(execution_logs) > 0:

View File

@ -295,7 +295,7 @@ class SpeechToTextTool(PipelineTool):
pre_processor_class = WhisperProcessor pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration model_class = WhisperForConditionalGeneration
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}}
output_type = "string" output_type = "string"
def encode(self, audio): def encode(self, audio):

View File

@ -17,13 +17,14 @@
from dotenv import load_dotenv from dotenv import load_dotenv
import textwrap import textwrap
import base64 import base64
import pickle
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from e2b_code_interpreter import Sandbox from e2b_code_interpreter import Sandbox
from typing import List, Tuple, Any from typing import List, Tuple, Any
from .tool_validation import validate_tool_attributes from .tool_validation import validate_tool_attributes
from .utils import instance_to_source, BASE_BUILTIN_MODULES from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
from .tools import Tool from .tools import Tool
load_dotenv() load_dotenv()
@ -40,6 +41,7 @@ class E2BExecutor:
# timeout=300 # timeout=300
# ) # )
# print("Installation of agents package finished.") # print("Installation of agents package finished.")
additional_imports = additional_imports + ["pickle5"]
if len(additional_imports) > 0: if len(additional_imports) > 0:
execution = self.sbx.commands.run( execution = self.sbx.commands.run(
"pip install " + " ".join(additional_imports) "pip install " + " ".join(additional_imports)
@ -47,7 +49,7 @@ class E2BExecutor:
if execution.error: if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}") raise Exception(f"Error installing dependencies: {execution.error}")
else: else:
print("Installation succeeded!") console.print(f"Installation of {additional_imports} succeeded!")
tool_codes = [] tool_codes = []
for tool in tools: for tool in tools:
@ -71,21 +73,44 @@ class E2BExecutor:
tool_definition_code += "\n\n".join(tool_codes) tool_definition_code += "\n\n".join(tool_codes)
tool_definition_execution = self.run_code_raise_errors(tool_definition_code) tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
print(tool_definition_execution.logs) console.print(tool_definition_execution.logs)
def run_code_raise_errors(self, code: str): def run_code_raise_errors(self, code: str):
execution = self.sbx.run_code( execution = self.sbx.run_code(
code, code,
) )
if execution.error: if execution.error:
logs = "Executing code yielded an error:" execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
logs = execution_logs
logs += "Executing code yielded an error:"
logs += execution.error.name logs += execution.error.name
logs += execution.error.value logs += execution.error.value
logs += execution.error.traceback logs += execution.error.traceback
raise ValueError(logs) raise ValueError(logs)
return execution return execution
def __call__(self, code_action: str) -> Tuple[Any, Any]: def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
if len(additional_args) > 0:
# Pickle additional_args to server
import tempfile
with tempfile.NamedTemporaryFile() as f:
pickle.dump(additional_args, f)
f.flush()
with open(f.name, "rb") as file:
self.sbx.files.write("/home/state.pkl", file)
remote_unloading_code = """import pickle
import os
print("File path", os.path.getsize('/home/state.pkl'))
with open('/home/state.pkl', 'rb') as f:
pickle_dict = pickle.load(f)
locals().update({key: value for key, value in pickle_dict.items()})
"""
execution = self.run_code_raise_errors(remote_unloading_code)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
console.print(execution_logs)
execution = self.run_code_raise_errors(code_action) execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results: if not execution.results:

View File

@ -1058,7 +1058,8 @@ class LocalPythonInterpreter:
} }
# TODO: assert self.authorized imports are all installed locally # TODO: assert self.authorized imports are all installed locally
def __call__(self, code_action: str) -> Tuple[Any, str]: def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]:
self.state.update(additional_variables)
output = evaluate_python_code( output = evaluate_python_code(
code_action, code_action,
static_tools=self.static_tools, static_tools=self.static_tools,

View File

@ -201,7 +201,8 @@ class Tool:
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
# Validate forward function signature # Validate forward function signature, except for PipelineTool
if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True):
signature = inspect.signature(self.forward) signature = inspect.signature(self.forward)
if not set(signature.parameters.keys()) == set(self.inputs.keys()): if not set(signature.parameters.keys()) == set(self.inputs.keys()):
@ -1074,6 +1075,7 @@ class PipelineTool(Tool):
name = "pipeline" name = "pipeline"
inputs = {"prompt": str} inputs = {"prompt": str}
output_type = str output_type = str
is_pipeline_tool = True
def __init__( def __init__(
self, self,