Fix additional args sent to e2b server
This commit is contained in:
parent
1abaf69b67
commit
f8b9cb34f9
|
@ -62,7 +62,7 @@ and [reaches higher performance on difficult benchmarks](https://huggingface.co/
|
||||||
|
|
||||||
Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime:
|
Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime:
|
||||||
- a secure python interpreter to run code more safely in your environment
|
- a secure python interpreter to run code more safely in your environment
|
||||||
- a sandboxed environment.
|
- a sandboxed environment using [E2B](https://e2b.dev/).
|
||||||
|
|
||||||
## How lightweight is it?
|
## How lightweight is it?
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ from huggingface_hub import login, InferenceClient
|
||||||
|
|
||||||
login("<YOUR_HUGGINGFACEHUB_API_TOKEN>")
|
login("<YOUR_HUGGINGFACEHUB_API_TOKEN>")
|
||||||
|
|
||||||
model_id = "Qwen/Qwen2.5-72B-Instruct"
|
model_id = "meta-llama/Llama-3.3-70B-Instruct"
|
||||||
|
|
||||||
client = InferenceClient(model=model_id)
|
client = InferenceClient(model=model_id)
|
||||||
|
|
||||||
|
@ -71,12 +71,19 @@ agent.run(
|
||||||
|
|
||||||
Note that we used an additional `additional_detail` argument: you can additional kwargs to `agent.run()`, they will be baked into the prompt as text.
|
Note that we used an additional `additional_detail` argument: you can additional kwargs to `agent.run()`, they will be baked into the prompt as text.
|
||||||
|
|
||||||
You can use this to indicate the path to local or remote files for the model to use:
|
You can use this to pass files in various formats:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
|
from smolagents import CodeAgent, HfApiModel
|
||||||
|
|
||||||
agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
|
model_id = "meta-llama/Llama-3.3-70B-Instruct"
|
||||||
|
|
||||||
|
agent = CodeAgent(tools=[], model=HfApiModel(model_id=model_id), add_base_tools=True)
|
||||||
|
|
||||||
|
agent.run(
|
||||||
|
"Why does Mike not know many people in New York?",
|
||||||
|
additional_args={"mp3_sound_file_url":'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3'}
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
It's important to explain as clearly as possible the task you want to perform.
|
It's important to explain as clearly as possible the task you want to perform.
|
||||||
|
|
|
@ -27,12 +27,11 @@ LAUNCH_GRADIO = False
|
||||||
|
|
||||||
get_cat_image = GetCatImageTool()
|
get_cat_image = GetCatImageTool()
|
||||||
|
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools = [get_cat_image, VisitWebpageTool()],
|
tools = [get_cat_image, VisitWebpageTool()],
|
||||||
model=HfApiModel(),
|
model=HfApiModel(),
|
||||||
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
|
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
|
||||||
use_e2b_executor=False
|
use_e2b_executor=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if LAUNCH_GRADIO:
|
if LAUNCH_GRADIO:
|
||||||
|
@ -41,6 +40,5 @@ if LAUNCH_GRADIO:
|
||||||
GradioUI(agent).launch()
|
GradioUI(agent).launch()
|
||||||
else:
|
else:
|
||||||
agent.run(
|
agent.run(
|
||||||
"Return me an image of Lincoln's preferred pet",
|
"Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()}
|
||||||
additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/"
|
) # Asking to directly return the image from state tests that additional_args are properly sent to server.
|
||||||
)
|
|
||||||
|
|
|
@ -188,6 +188,7 @@ class MultiStepAgent:
|
||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
self.grammar = grammar
|
self.grammar = grammar
|
||||||
self.planning_interval = planning_interval
|
self.planning_interval = planning_interval
|
||||||
|
self.state = {}
|
||||||
|
|
||||||
self.managed_agents = {}
|
self.managed_agents = {}
|
||||||
if managed_agents is not None:
|
if managed_agents is not None:
|
||||||
|
@ -370,8 +371,7 @@ class MultiStepAgent:
|
||||||
return self.model(self.input_messages)
|
return self.model(self.input_messages)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error in generating final LLM output:\n{e}"
|
error_msg = f"Error in generating final LLM output:\n{e}"
|
||||||
console.print(f"[bold red]{error_msg}[/bold red]")
|
raise AgentGenerationError(error_msg)
|
||||||
return error_msg
|
|
||||||
|
|
||||||
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
||||||
"""
|
"""
|
||||||
|
@ -385,7 +385,6 @@ class MultiStepAgent:
|
||||||
available_tools = {**self.toolbox.tools, **self.managed_agents}
|
available_tools = {**self.toolbox.tools, **self.managed_agents}
|
||||||
if tool_name not in available_tools:
|
if tool_name not in available_tools:
|
||||||
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -398,7 +397,6 @@ class MultiStepAgent:
|
||||||
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
||||||
else:
|
else:
|
||||||
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
return observation
|
return observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -410,14 +408,12 @@ class MultiStepAgent:
|
||||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||||
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
||||||
)
|
)
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
elif tool_name in self.managed_agents:
|
elif tool_name in self.managed_agents:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
||||||
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
||||||
)
|
)
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||||
|
@ -430,7 +426,7 @@ class MultiStepAgent:
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
reset: bool = True,
|
reset: bool = True,
|
||||||
single_step: bool = False,
|
single_step: bool = False,
|
||||||
**kwargs,
|
additional_args: Optional[Dict] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Runs the agent for the given task.
|
Runs the agent for the given task.
|
||||||
|
@ -440,6 +436,7 @@ class MultiStepAgent:
|
||||||
stream (`bool`): Wether to run in a streaming way.
|
stream (`bool`): Wether to run in a streaming way.
|
||||||
reset (`bool`): Wether to reset the conversation or keep it going from previous run.
|
reset (`bool`): Wether to reset the conversation or keep it going from previous run.
|
||||||
single_step (`bool`): Should the agent run in one shot or multi-step fashion?
|
single_step (`bool`): Should the agent run in one shot or multi-step fashion?
|
||||||
|
additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```py
|
```py
|
||||||
|
@ -449,11 +446,11 @@ class MultiStepAgent:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
self.task = task
|
self.task = task
|
||||||
if len(kwargs) > 0:
|
if additional_args is not None:
|
||||||
self.task += (
|
self.state.update(additional_args)
|
||||||
f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
self.task += f"""
|
||||||
)
|
You have been provided with these additional arguments, that you can access as variables in your python code using the keys:
|
||||||
self.state = kwargs.copy()
|
{str(additional_args)}."""
|
||||||
|
|
||||||
self.initialize_system_prompt()
|
self.initialize_system_prompt()
|
||||||
system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
|
system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
|
||||||
|
@ -468,14 +465,9 @@ class MultiStepAgent:
|
||||||
else:
|
else:
|
||||||
self.logs.append(system_prompt_step)
|
self.logs.append(system_prompt_step)
|
||||||
|
|
||||||
# console.print(
|
|
||||||
# Group(
|
|
||||||
# Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task)
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"\n[bold]{task.strip()}\n",
|
f"\n[bold]{self.task.strip()}\n",
|
||||||
title="[bold]New run",
|
title="[bold]New run",
|
||||||
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}",
|
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}",
|
||||||
border_style=YELLOW_HEX,
|
border_style=YELLOW_HEX,
|
||||||
|
@ -891,17 +883,6 @@ class CodeAgent(MultiStepAgent):
|
||||||
console.print_exception()
|
console.print_exception()
|
||||||
raise AgentGenerationError(f"Error in generating model output:\n{e}")
|
raise AgentGenerationError(f"Error in generating model output:\n{e}")
|
||||||
|
|
||||||
# from rich.live import Live
|
|
||||||
# from rich.markdown import Markdown
|
|
||||||
# import time
|
|
||||||
|
|
||||||
# with Live(console=console, vertical_overflow="visible") as live:
|
|
||||||
# message = ""
|
|
||||||
# for i in range(100):
|
|
||||||
# time.sleep(0.02)
|
|
||||||
# message += str(i)
|
|
||||||
# live.update(Markdown(message))
|
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
console.print(
|
console.print(
|
||||||
Group(
|
Group(
|
||||||
|
@ -946,6 +927,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
try:
|
try:
|
||||||
output, execution_logs = self.python_executor(
|
output, execution_logs = self.python_executor(
|
||||||
code_action,
|
code_action,
|
||||||
|
self.state,
|
||||||
)
|
)
|
||||||
execution_outputs_console = []
|
execution_outputs_console = []
|
||||||
if len(execution_logs) > 0:
|
if len(execution_logs) > 0:
|
||||||
|
|
|
@ -295,7 +295,7 @@ class SpeechToTextTool(PipelineTool):
|
||||||
pre_processor_class = WhisperProcessor
|
pre_processor_class = WhisperProcessor
|
||||||
model_class = WhisperForConditionalGeneration
|
model_class = WhisperForConditionalGeneration
|
||||||
|
|
||||||
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
|
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}}
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def encode(self, audio):
|
def encode(self, audio):
|
||||||
|
|
|
@ -17,13 +17,14 @@
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import textwrap
|
import textwrap
|
||||||
import base64
|
import base64
|
||||||
|
import pickle
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from e2b_code_interpreter import Sandbox
|
from e2b_code_interpreter import Sandbox
|
||||||
from typing import List, Tuple, Any
|
from typing import List, Tuple, Any
|
||||||
from .tool_validation import validate_tool_attributes
|
from .tool_validation import validate_tool_attributes
|
||||||
from .utils import instance_to_source, BASE_BUILTIN_MODULES
|
from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -40,6 +41,7 @@ class E2BExecutor:
|
||||||
# timeout=300
|
# timeout=300
|
||||||
# )
|
# )
|
||||||
# print("Installation of agents package finished.")
|
# print("Installation of agents package finished.")
|
||||||
|
additional_imports = additional_imports + ["pickle5"]
|
||||||
if len(additional_imports) > 0:
|
if len(additional_imports) > 0:
|
||||||
execution = self.sbx.commands.run(
|
execution = self.sbx.commands.run(
|
||||||
"pip install " + " ".join(additional_imports)
|
"pip install " + " ".join(additional_imports)
|
||||||
|
@ -47,7 +49,7 @@ class E2BExecutor:
|
||||||
if execution.error:
|
if execution.error:
|
||||||
raise Exception(f"Error installing dependencies: {execution.error}")
|
raise Exception(f"Error installing dependencies: {execution.error}")
|
||||||
else:
|
else:
|
||||||
print("Installation succeeded!")
|
console.print(f"Installation of {additional_imports} succeeded!")
|
||||||
|
|
||||||
tool_codes = []
|
tool_codes = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
|
@ -71,21 +73,44 @@ class E2BExecutor:
|
||||||
tool_definition_code += "\n\n".join(tool_codes)
|
tool_definition_code += "\n\n".join(tool_codes)
|
||||||
|
|
||||||
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
|
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
|
||||||
print(tool_definition_execution.logs)
|
console.print(tool_definition_execution.logs)
|
||||||
|
|
||||||
def run_code_raise_errors(self, code: str):
|
def run_code_raise_errors(self, code: str):
|
||||||
execution = self.sbx.run_code(
|
execution = self.sbx.run_code(
|
||||||
code,
|
code,
|
||||||
)
|
)
|
||||||
if execution.error:
|
if execution.error:
|
||||||
logs = "Executing code yielded an error:"
|
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||||
|
logs = execution_logs
|
||||||
|
logs += "Executing code yielded an error:"
|
||||||
logs += execution.error.name
|
logs += execution.error.name
|
||||||
logs += execution.error.value
|
logs += execution.error.value
|
||||||
logs += execution.error.traceback
|
logs += execution.error.traceback
|
||||||
raise ValueError(logs)
|
raise ValueError(logs)
|
||||||
return execution
|
return execution
|
||||||
|
|
||||||
def __call__(self, code_action: str) -> Tuple[Any, Any]:
|
def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
|
||||||
|
if len(additional_args) > 0:
|
||||||
|
# Pickle additional_args to server
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile() as f:
|
||||||
|
pickle.dump(additional_args, f)
|
||||||
|
f.flush()
|
||||||
|
with open(f.name, "rb") as file:
|
||||||
|
self.sbx.files.write("/home/state.pkl", file)
|
||||||
|
remote_unloading_code = """import pickle
|
||||||
|
import os
|
||||||
|
print("File path", os.path.getsize('/home/state.pkl'))
|
||||||
|
with open('/home/state.pkl', 'rb') as f:
|
||||||
|
pickle_dict = pickle.load(f)
|
||||||
|
locals().update({key: value for key, value in pickle_dict.items()})
|
||||||
|
"""
|
||||||
|
execution = self.run_code_raise_errors(remote_unloading_code)
|
||||||
|
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||||
|
console.print(execution_logs)
|
||||||
|
|
||||||
|
|
||||||
execution = self.run_code_raise_errors(code_action)
|
execution = self.run_code_raise_errors(code_action)
|
||||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||||
if not execution.results:
|
if not execution.results:
|
||||||
|
|
|
@ -1058,7 +1058,8 @@ class LocalPythonInterpreter:
|
||||||
}
|
}
|
||||||
# TODO: assert self.authorized imports are all installed locally
|
# TODO: assert self.authorized imports are all installed locally
|
||||||
|
|
||||||
def __call__(self, code_action: str) -> Tuple[Any, str]:
|
def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]:
|
||||||
|
self.state.update(additional_variables)
|
||||||
output = evaluate_python_code(
|
output = evaluate_python_code(
|
||||||
code_action,
|
code_action,
|
||||||
static_tools=self.static_tools,
|
static_tools=self.static_tools,
|
||||||
|
|
|
@ -201,7 +201,8 @@ class Tool:
|
||||||
|
|
||||||
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
|
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
|
||||||
|
|
||||||
# Validate forward function signature
|
# Validate forward function signature, except for PipelineTool
|
||||||
|
if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True):
|
||||||
signature = inspect.signature(self.forward)
|
signature = inspect.signature(self.forward)
|
||||||
|
|
||||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||||
|
@ -1074,6 +1075,7 @@ class PipelineTool(Tool):
|
||||||
name = "pipeline"
|
name = "pipeline"
|
||||||
inputs = {"prompt": str}
|
inputs = {"prompt": str}
|
||||||
output_type = str
|
output_type = str
|
||||||
|
is_pipeline_tool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue