Add E2B code interpreter 🥳

This commit is contained in:
Aymeric 2024-12-20 16:18:40 +01:00
parent 7b0b01d8f3
commit c18bc9037d
24 changed files with 400 additions and 243 deletions

7
.gitignore vendored
View File

@ -36,6 +36,7 @@ sdist/
var/ var/
wheels/ wheels/
share/python-wheels/ share/python-wheels/
node_modules/
*.egg-info/ *.egg-info/
.installed.cfg .installed.cfg
*.egg *.egg
@ -156,4 +157,8 @@ dmypy.json
cython_debug/ cython_debug/
# PyCharm # PyCharm
#.idea/ #.idea/
# Archive
archive/
savedir/

View File

@ -1,5 +1,5 @@
# Base Python image # Base Python image
FROM python:3.9-slim FROM python:3.12-slim
# Set working directory # Set working directory
WORKDIR /app WORKDIR /app
@ -7,8 +7,6 @@ WORKDIR /app
# Install build dependencies # Install build dependencies
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
build-essential \ build-essential \
gcc \
g++ \
zlib1g-dev \ zlib1g-dev \
libjpeg-dev \ libjpeg-dev \
libpng-dev \ libpng-dev \

5
e2b.Dockerfile Normal file
View File

@ -0,0 +1,5 @@
# You can use most Debian-based base images
FROM e2bdev/code-interpreter:latest
# Install dependencies and customize sandbox
RUN pip install git+https://github.com/huggingface/agents.git

16
e2b.toml Normal file
View File

@ -0,0 +1,16 @@
# This is a config for E2B sandbox template.
# You can use template ID (qywp2ctmu2q7jzprcf4j) to create a sandbox:
# Python SDK
# from e2b import Sandbox, AsyncSandbox
# sandbox = Sandbox("qywp2ctmu2q7jzprcf4j") # Sync sandbox
# sandbox = await AsyncSandbox.create("qywp2ctmu2q7jzprcf4j") # Async sandbox
# JS SDK
# import { Sandbox } from 'e2b'
# const sandbox = await Sandbox.create('qywp2ctmu2q7jzprcf4j')
team_id = "f8776d3a-df2f-4a1d-af48-68c2e13b3b87"
start_cmd = "/root/.jupyter/start-up.sh"
dockerfile = "e2b.Dockerfile"
template_id = "qywp2ctmu2q7jzprcf4j"

View File

@ -1,8 +1,8 @@
from agents.tools.search import DuckDuckGoSearchTool from agents.default_tools.search import DuckDuckGoSearchTool
from agents.docker_alternative import DockerPythonInterpreter from agents.docker_alternative import DockerPythonInterpreter
from agents.tool import Tool from agents.tools import Tool
class DummyTool(Tool): class DummyTool(Tool):
name = "echo" name = "echo"

View File

@ -1,4 +1,4 @@
from agents.tool import Tool from agents.tools import Tool
class DummyTool(Tool): class DummyTool(Tool):

44
examples/e2b_example.py Normal file
View File

@ -0,0 +1,44 @@
from agents import Tool, CodeAgent
from agents.default_tools.search import VisitWebpageTool
from dotenv import load_dotenv
load_dotenv()
LAUNCH_GRADIO = False
class GetCatImageTool(Tool):
name="get_cat_image"
description = "Get a cat image"
inputs = {}
output_type = "image"
def __init__(self):
super().__init__()
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
def forward(self):
from PIL import Image
import requests
from io import BytesIO
response = requests.get(self.url)
return Image.open(BytesIO(response.content))
get_cat_image = GetCatImageTool()
agent = CodeAgent(
tools = [get_cat_image, VisitWebpageTool()],
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
use_e2b_executor=False
)
if LAUNCH_GRADIO:
from agents.gradio_ui import GradioUI
GradioUI(agent).launch()
else:
agent.run(
"Return me an image of Lincoln's preferred pet",
additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/"
)

View File

@ -24,22 +24,28 @@ from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import * from .agents import *
from .default_tools import * from .default_tools.base import *
from .default_tools.search import *
from .gradio_ui import * from .gradio_ui import *
from .llm_engines import * from .llm_engines import *
from .local_python_executor import * from .local_python_executor import *
from .monitoring import * from .monitoring import *
from .prompts import * from .prompts import *
from .tools.search import * from .tools import *
from .tool import *
from .types import * from .types import *
from .utils import * from .utils import *
from .default_tools.search import *
else: else:
import sys import sys
_file = globals()["__file__"] _file = globals()["__file__"]
import_structure = define_import_structure(_file)
import_structure[""]= {"__version__": __version__}
sys.modules[__name__] = _LazyModule( sys.modules[__name__] = _LazyModule(
__name__, _file, define_import_structure(_file), module_spec=__spec__ __name__,
_file,
import_structure,
module_spec=__spec__,
extra_objects={"__version__": __version__}
) )

View File

@ -27,7 +27,7 @@ from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
from .types import AgentAudio, AgentImage from .types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool from .default_tools.base import FinalAnswerTool
from .llm_engines import HfApiEngine, MessageRole from .llm_engines import HfApiEngine, MessageRole
from .monitoring import Monitor from .monitoring import Monitor
from .prompts import ( from .prompts import (
@ -42,8 +42,9 @@ from .prompts import (
SYSTEM_PROMPT_PLAN_UPDATE, SYSTEM_PROMPT_PLAN_UPDATE,
SYSTEM_PROMPT_PLAN, SYSTEM_PROMPT_PLAN,
) )
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor
from .tool import ( from .e2b_executor import E2BExecutor
from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE, DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool, Tool,
get_tool_description_with_args, get_tool_description_with_args,
@ -169,17 +170,6 @@ def format_prompt_with_managed_agents_descriptions(
else: else:
return prompt_template.replace(agent_descriptions_placeholder, "") return prompt_template.replace(agent_descriptions_placeholder, "")
def format_prompt_with_imports(
prompt_template: str, authorized_imports: List[str]
) -> str:
if "<<authorized_imports>>" not in prompt_template:
raise AgentError(
"Tag '<<authorized_imports>>' should be provided in the prompt."
)
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
class BaseAgent: class BaseAgent:
def __init__( def __init__(
self, self,
@ -264,11 +254,6 @@ class BaseAgent:
self.system_prompt = format_prompt_with_managed_agents_descriptions( self.system_prompt = format_prompt_with_managed_agents_descriptions(
self.system_prompt, self.managed_agents self.system_prompt, self.managed_agents
) )
if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports(
self.system_prompt,
list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
)
return self.system_prompt return self.system_prompt
@ -439,9 +424,7 @@ class ReactAgent(BaseAgent):
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool. arguments (Dict[str, str]): Arguments passed to the Tool.
""" """
available_tools = self.toolbox.tools available_tools = {**self.toolbox.tools, **self.managed_agents}
if self.managed_agents is not None:
available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools: if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}") console.print(f"[bold red]{error_msg}")
@ -674,8 +657,6 @@ Now begin!""",
), ),
managed_agents_descriptions=( managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
), ),
answer_facts=answer_facts, answer_facts=answer_facts,
), ),
@ -729,8 +710,6 @@ Now begin!""",
), ),
managed_agents_descriptions=( managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
), ),
facts_update=facts_update, facts_update=facts_update,
remaining_steps=(self.max_iterations - iteration), remaining_steps=(self.max_iterations - iteration),
@ -891,6 +870,7 @@ class CodeAgent(ReactAgent):
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
use_e2b_executor: bool = False,
**kwargs, **kwargs,
): ):
if llm_engine is None: if llm_engine is None:
@ -909,17 +889,24 @@ class CodeAgent(ReactAgent):
**kwargs, **kwargs,
) )
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = ( self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else [] additional_authorized_imports if additional_authorized_imports else []
) )
all_tools = {**self.toolbox.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(self.additional_authorized_imports, list(all_tools.values()))
else:
self.python_executor = LocalPythonExecutor(self.additional_authorized_imports, all_tools)
self.authorized_imports = list( self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
) )
if "{{authorized_imports}}" not in self.system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
self.system_prompt = self.system_prompt.replace( self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}", str(self.authorized_imports) "{{authorized_imports}}", str(self.authorized_imports)
) )
self.custom_tools = {}
def step(self, log_entry: ActionStep) -> Union[None, Any]: def step(self, log_entry: ActionStep) -> Union[None, Any]:
""" """
@ -991,22 +978,12 @@ class CodeAgent(ReactAgent):
) )
try: try:
static_tools = { output, execution_logs = self.python_executor(
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
}
if self.managed_agents is not None:
static_tools = {**static_tools, **self.managed_agents}
output = self.python_evaluator(
code_action, code_action,
static_tools=static_tools,
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
) )
if len(self.state["print_outputs"]) > 0: if len(execution_logs) > 0:
console.print(Group(Text("Print outputs:", style="bold"), Text(self.state["print_outputs"]))) console.print(Group(Text("Execution logs:", style="bold"), Text(execution_logs)))
observation = "Print outputs:\n" + self.state["print_outputs"] observation = "Execution logs:\n" + execution_logs
if output is not None: if output is not None:
truncated_output = truncate_content( truncated_output = truncate_content(
str(output) str(output)
@ -1026,7 +1003,7 @@ class CodeAgent(ReactAgent):
console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green"))) console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green")))
log_entry.action_output = output log_entry.action_output = output
return output return output
return None
class ManagedAgent: class ManagedAgent:

View File

@ -15,75 +15,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import math
from dataclasses import dataclass from dataclasses import dataclass
from math import sqrt
from typing import Dict from typing import Dict
from huggingface_hub import hf_hub_download, list_spaces from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode from transformers.utils import is_offline_mode
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code from ..local_python_executor import BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code
from .tool import TOOL_CONFIG_FILE, Tool from ..tools import TOOL_CONFIG_FILE, Tool
def custom_print(*args):
return None
BASE_PYTHON_TOOLS = {
"print": custom_print,
"isinstance": isinstance,
"range": range,
"float": float,
"int": int,
"bool": bool,
"str": str,
"set": set,
"list": list,
"dict": dict,
"tuple": tuple,
"round": round,
"ceil": math.ceil,
"floor": math.floor,
"log": math.log,
"exp": math.exp,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"asin": math.asin,
"acos": math.acos,
"atan": math.atan,
"atan2": math.atan2,
"degrees": math.degrees,
"radians": math.radians,
"pow": math.pow,
"sqrt": sqrt,
"len": len,
"sum": sum,
"max": max,
"min": min,
"abs": abs,
"enumerate": enumerate,
"zip": zip,
"reversed": reversed,
"sorted": sorted,
"all": all,
"any": any,
"map": map,
"filter": filter,
"ord": ord,
"chr": chr,
"next": next,
"iter": iter,
"divmod": divmod,
"callable": callable,
"getattr": getattr,
"hasattr": hasattr,
"setattr": setattr,
"issubclass": issubclass,
"type": type,
}
@dataclass @dataclass
@ -136,10 +75,10 @@ class PythonInterpreterTool(Tool):
def __init__(self, *args, authorized_imports=None, **kwargs): def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None: if authorized_imports is None:
self.authorized_imports = list(set(LIST_SAFE_MODULES)) self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
else: else:
self.authorized_imports = list( self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(authorized_imports) set(BASE_BUILTIN_MODULES) | set(authorized_imports)
) )
self.inputs = { self.inputs = {
"code": { "code": {

View File

@ -16,15 +16,11 @@
# limitations under the License. # limitations under the License.
import re import re
import requests
from requests.exceptions import RequestException
from ..tools import Tool from ..tools import Tool
class DuckDuckGoSearchTool(Tool): class DuckDuckGoSearchTool(Tool):
name = "web_search" name = "web_search"
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'.""" Each result has keys 'title', 'href' and 'body'."""
inputs = { inputs = {
"query": {"type": "string", "description": "The search query to perform."} "query": {"type": "string", "description": "The search query to perform."}
@ -56,9 +52,11 @@ class VisitWebpageTool(Tool):
def forward(self, url: str) -> str: def forward(self, url: str) -> str:
try: try:
from markdownify import markdownify from markdownify import markdownify
import requests
from requests.exceptions import RequestException
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"You must install package `markdownify` to run this tool: for instance run `pip install markdownify`." "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
) )
try: try:
# Send a GET request to the URL # Send a GET request to the URL

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import warnings import warnings
import socket import socket
from agents.tool import Tool from agents.tools import Tool
class DockerPythonInterpreter: class DockerPythonInterpreter:
def __init__(self): def __init__(self):

View File

@ -343,7 +343,7 @@ if __name__ == '__main__':
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES from .local_python_executor import evaluate_python_code, BASE_BUILTIN_MODULES
"""Execute code locally with state transfer.""" """Execute code locally with state transfer."""
state_manager = StateManager(work_dir) state_manager = StateManager(work_dir)
@ -363,7 +363,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
tools, tools,
{}, {},
namespace, namespace,
LIST_SAFE_MODULES, BASE_BUILTIN_MODULES,
) )
# Save state for Docker # Save state for Docker

103
src/agents/e2b_executor.py Normal file
View File

@ -0,0 +1,103 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dotenv import load_dotenv
import textwrap
import base64
from io import BytesIO
from PIL import Image
from e2b_code_interpreter import Sandbox
from typing import Dict, List, Callable, Tuple, Any
from .tool_validation import validate_tool_attributes
from .utils import instance_to_source, BASE_BUILTIN_MODULES
from .tools import Tool
from .types import AgentImage
load_dotenv()
class E2BExecutor():
def __init__(self, additional_imports: List[str], tools: List[Tool]):
self.custom_tools = {}
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
# TODO: validate installing agents package or not
# print("Installing agents package on remote executor...")
# self.sbx.commands.run(
# "pip install git+https://github.com/huggingface/agents.git",
# timeout=300
# )
# print("Installation of agents package finished.")
if len(additional_imports) > 0:
execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}")
else:
print("Installation succeeded!")
tool_codes = []
for tool in tools:
validate_tool_attributes(tool.__class__, check_imports=False)
tool_code = instance_to_source(tool, base_cls=Tool)
tool_code = tool_code.replace("from agents.tools import Tool", "")
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
tool_codes.append(tool_code)
tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
tool_definition_code += textwrap.dedent("""
class Tool:
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(self, *args, **kwargs):
pass # to be implemented in child class
""")
tool_definition_code += "\n\n".join(tool_codes)
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
print(tool_definition_execution.logs)
def run_code_raise_errors(self, code: str):
execution = self.sbx.run_code(
code,
)
if execution.error:
logs = 'Executing code yielded an error:'
logs += execution.error.name
logs += execution.error.value
logs += execution.error.traceback
raise ValueError(logs)
return execution
def __call__(self, code_action: str) -> Tuple[Any, Any]:
execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results:
return None, execution_logs
else:
for result in execution.results:
if result.is_main_result:
for attribute_name in ['jpeg', 'png']:
if getattr(result, attribute_name) is not None:
image_output = getattr(result, attribute_name)
decoded_bytes = base64.b64decode(image_output.encode('utf-8'))
return Image.open(BytesIO(decoded_bytes)), execution_logs
for attribute_name in ['chart', 'data', 'html', 'javascript', 'json', 'latex', 'markdown', 'pdf', 'svg', 'text']:
if getattr(result, attribute_name) is not None:
return getattr(result, attribute_name), execution_logs
raise ValueError("No main result returned by executor!")
__all__ = ["E2BExecutor"]

View File

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .types import AgentAudio, AgentImage, AgentText from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from .agents import BaseAgent, AgentStep, ActionStep from .agents import BaseAgent, AgentStep, ActionStep
import gradio as gr import gradio as gr
@ -58,7 +58,7 @@ def stream_to_gradio(
for message in pull_messages_from_step(step_log, test_mode=test_mode): for message in pull_messages_from_step(step_log, test_mode=test_mode):
yield message yield message
final_answer = step_log # Last log is the run's final_answer final_answer = handle_agent_output_types(step_log) # Last log is the run's final_answer
if isinstance(final_answer, AgentText): if isinstance(final_answer, AgentText):
yield gr.ChatMessage( yield gr.ChatMessage(
@ -93,7 +93,7 @@ class GradioUI:
yield messages yield messages
yield messages yield messages
def run(self): def launch(self):
with gr.Blocks() as demo: with gr.Blocks() as demo:
stored_message = gr.State([]) stored_message = gr.State([])
chatbot = gr.Chatbot( chatbot = gr.Chatbot(

View File

@ -19,12 +19,12 @@ import builtins
import difflib import difflib
from collections.abc import Mapping from collections.abc import Mapping
from importlib import import_module from importlib import import_module
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, Tuple
import math
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from .utils import truncate_content from .utils import truncate_content, BASE_BUILTIN_MODULES
class InterpreterError(ValueError): class InterpreterError(ValueError):
@ -43,24 +43,66 @@ ERRORS = {
and issubclass(getattr(builtins, name), BaseException) and issubclass(getattr(builtins, name), BaseException)
} }
LIST_SAFE_MODULES = [
"random",
"collections",
"math",
"time",
"queue",
"itertools",
"re",
"stat",
"statistics",
"unicodedata",
]
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000 PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000 OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
def custom_print(*args):
return None
BASE_PYTHON_TOOLS = {
"print": custom_print,
"isinstance": isinstance,
"range": range,
"float": float,
"int": int,
"bool": bool,
"str": str,
"set": set,
"list": list,
"dict": dict,
"tuple": tuple,
"round": round,
"ceil": math.ceil,
"floor": math.floor,
"log": math.log,
"exp": math.exp,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"asin": math.asin,
"acos": math.acos,
"atan": math.atan,
"atan2": math.atan2,
"degrees": math.degrees,
"radians": math.radians,
"pow": math.pow,
"sqrt": math.sqrt,
"len": len,
"sum": sum,
"max": max,
"min": min,
"abs": abs,
"enumerate": enumerate,
"zip": zip,
"reversed": reversed,
"sorted": sorted,
"all": all,
"any": any,
"map": map,
"filter": filter,
"ord": ord,
"chr": chr,
"next": next,
"iter": iter,
"divmod": divmod,
"callable": callable,
"getattr": getattr,
"hasattr": hasattr,
"setattr": setattr,
"issubclass": issubclass,
"type": type,
}
class BreakException(Exception): class BreakException(Exception):
pass pass
@ -771,7 +813,7 @@ def evaluate_ast(
state: Dict[str, Any], state: Dict[str, Any],
static_tools: Dict[str, Callable], static_tools: Dict[str, Callable],
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str] = LIST_SAFE_MODULES, authorized_imports: List[str] = BASE_BUILTIN_MODULES,
): ):
""" """
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
@ -949,7 +991,7 @@ def evaluate_python_code(
static_tools: Optional[Dict[str, Callable]] = None, static_tools: Optional[Dict[str, Callable]] = None,
custom_tools: Optional[Dict[str, Callable]] = None, custom_tools: Optional[Dict[str, Callable]] = None,
state: Optional[Dict[str, Any]] = None, state: Optional[Dict[str, Any]] = None,
authorized_imports: List[str] = LIST_SAFE_MODULES, authorized_imports: List[str] = BASE_BUILTIN_MODULES,
): ):
""" """
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
@ -1001,4 +1043,30 @@ def evaluate_python_code(
raise InterpreterError(msg) raise InterpreterError(msg)
__all__ = ["evaluate_python_code"] class LocalPythonExecutor():
def __init__(self, additional_authorized_imports: List[str], tools: Dict):
self.custom_tools = {}
self.state = {}
self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
# Add base trusted tools to list
self.static_tools = {
**tools,
**BASE_PYTHON_TOOLS.copy(),
}
# TODO: assert self.authorized imports are all installed locally
def __call__(self, code_action: str) -> Tuple[Any, str]:
output = evaluate_python_code(
code_action,
static_tools=self.static_tools,
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
)
logs = self.state["print_outputs"]
return output, logs
__all__ = ["evaluate_python_code", "LocalPythonExecutor"]

View File

@ -370,7 +370,7 @@ Here are the rules you should always follow to solve your task:
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables. 7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>> 8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it. 10. Don't give up! You're in charge of solving the task, not providing directions to solve it.

View File

@ -5,27 +5,30 @@ import builtins
from pathlib import Path from pathlib import Path
from typing import List, Set, Dict from typing import List, Set, Dict
import textwrap import textwrap
from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins)) _BUILTIN_NAMES = set(vars(builtins))
def is_local_import(module_name: str) -> bool: IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
def is_installed_package(module_name: str) -> bool:
""" """
Check if an import is from a local file or a package. Check if an import is from an installed package.
Returns True if it's a local file import. Returns False if it's not found or a local file import.
""" """
try: try:
spec = importlib.util.find_spec(module_name) spec = importlib.util.find_spec(module_name)
if spec is None: if spec is None:
return True # If we can't find the module, assume it's local return False # If we can't find the module, assume it's local
# If the module is found and has a file path, check if it's in site-packages # If the module is found and has a file path, check if it's in site-packages
if spec.origin and 'site-packages' not in spec.origin: if spec.origin and 'site-packages' not in spec.origin:
# Check if it's a .py file in the current directory or subdirectories # Check if it's a .py file in the current directory or subdirectories
return spec.origin.endswith('.py') return not spec.origin.endswith('.py')
return False return False
except ImportError: except ImportError:
return True # If there's an import error, assume it's local return False # If there's an import error, assume it's local
class MethodChecker(ast.NodeVisitor): class MethodChecker(ast.NodeVisitor):
""" """
@ -33,7 +36,7 @@ class MethodChecker(ast.NodeVisitor):
- only uses defined names - only uses defined names
- contains no local imports (e.g. numpy is ok but local_script is not) - contains no local imports (e.g. numpy is ok but local_script is not)
""" """
def __init__(self, class_attributes: Set[str]): def __init__(self, class_attributes: Set[str], check_imports: bool = True):
self.undefined_names = set() self.undefined_names = set()
self.imports = {} self.imports = {}
self.from_imports = {} self.from_imports = {}
@ -41,6 +44,7 @@ class MethodChecker(ast.NodeVisitor):
self.arg_names = set() self.arg_names = set()
self.class_attributes = class_attributes self.class_attributes = class_attributes
self.errors = [] self.errors = []
self.check_imports = check_imports
def visit_arguments(self, node): def visit_arguments(self, node):
"""Collect function arguments""" """Collect function arguments"""
@ -53,16 +57,16 @@ class MethodChecker(ast.NodeVisitor):
def visit_Import(self, node): def visit_Import(self, node):
for name in node.names: for name in node.names:
actual_name = name.asname or name.name actual_name = name.asname or name.name
if is_local_import(actual_name): if not is_installed_package(actual_name) and self.check_imports:
self.errors.append(f"Local import '{actual_name}'") self.errors.append(f"Package not found in importlib, might be a local install: '{actual_name}'")
self.imports[actual_name] = name.name self.imports[actual_name] = name.name
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node):
module = node.module or "" module = node.module or ""
for name in node.names: for name in node.names:
actual_name = name.asname or name.name actual_name = name.asname or name.name
if is_local_import(module): if not is_installed_package(module) and self.check_imports:
self.errors.append(f"Local import '{module}'") self.errors.append(f"Package not found in importlib, might be a local install: '{module}'")
self.from_imports[actual_name] = (module, name.name) self.from_imports[actual_name] = (module, name.name)
def visit_Assign(self, node): def visit_Assign(self, node):
@ -71,6 +75,20 @@ class MethodChecker(ast.NodeVisitor):
self.assigned_names.add(target.id) self.assigned_names.add(target.id)
self.visit(node.value) self.visit(node.value)
def visit_With(self, node):
"""Track aliases in 'with' statements (the 'y' in 'with X as y')"""
for item in node.items:
if item.optional_vars: # This is the 'y' in 'with X as y'
if isinstance(item.optional_vars, ast.Name):
self.assigned_names.add(item.optional_vars.id)
self.generic_visit(node)
def visit_ExceptHandler(self, node):
"""Track exception aliases (the 'e' in 'except Exception as e')"""
if node.name: # This is the 'e' in 'except Exception as e'
self.assigned_names.add(node.name)
self.generic_visit(node)
def visit_AnnAssign(self, node): def visit_AnnAssign(self, node):
"""Track annotated assignments.""" """Track annotated assignments."""
if isinstance(node.target, ast.Name): if isinstance(node.target, ast.Name):
@ -97,6 +115,7 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.ctx, ast.Load): if isinstance(node.ctx, ast.Load):
if not ( if not (
node.id in _BUILTIN_NAMES node.id in _BUILTIN_NAMES
or node.id in IMPORTED_PACKAGES
or node.id in self.arg_names or node.id in self.arg_names
or node.id == "self" or node.id == "self"
or node.id in self.class_attributes or node.id in self.class_attributes
@ -110,17 +129,18 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.func, ast.Name): if isinstance(node.func, ast.Name):
if not ( if not (
node.func.id in _BUILTIN_NAMES node.func.id in _BUILTIN_NAMES
or node.func.id in IMPORTED_PACKAGES
or node.func.id in self.arg_names or node.func.id in self.arg_names
or node.func.id == "self" or node.func.id == "self"
or node.func.id in self.class_attributes or node.func.id in self.class_attributes
or node.func.id in self.imports or node.func.id in self.imports
or node.func.id in self.from_imports or node.func.id in self.from_imports
or node.func.id in self.assigned_names or node.func.id in self.assigned_names
): ):
self.errors.append(f"Name '{node.func.id}' is undefined.") self.errors.append(f"Name '{node.func.id}' is undefined.")
self.generic_visit(node) self.generic_visit(node)
def validate_tool_attributes(cls) -> None: def validate_tool_attributes(cls, check_imports: bool = True) -> None:
""" """
Validates that a Tool class follows the proper patterns: Validates that a Tool class follows the proper patterns:
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!). 0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
@ -156,8 +176,17 @@ def validate_tool_attributes(cls) -> None:
self.imported_names = set() self.imported_names = set()
self.complex_attributes = set() self.complex_attributes = set()
self.class_attributes = set() self.class_attributes = set()
self.in_method = False
def visit_FunctionDef(self, node):
old_context = self.in_method
self.in_method = True
self.generic_visit(node)
self.in_method = old_context
def visit_Assign(self, node): def visit_Assign(self, node):
if self.in_method:
return
# Track class attributes # Track class attributes
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
@ -182,7 +211,7 @@ def validate_tool_attributes(cls) -> None:
# Run checks on all methods # Run checks on all methods
for node in class_node.body: for node in class_node.body:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
method_checker = MethodChecker(class_level_checker.class_attributes) method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
method_checker.visit(node) method_checker.visit(node)
errors += [f"- {node.name}: {error}" for error in method_checker.errors] errors += [f"- {node.name}: {error}" for error in method_checker.errors]

View File

@ -28,6 +28,7 @@ import textwrap
from functools import lru_cache, wraps from functools import lru_cache, wraps
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, Set from typing import Any, Callable, Dict, List, Optional, Union, Set
import math
from huggingface_hub import ( from huggingface_hub import (
create_repo, create_repo,
@ -48,7 +49,7 @@ from transformers.utils import (
is_vision_available, is_vision_available,
) )
from transformers.dynamic_module_utils import get_imports from transformers.dynamic_module_utils import get_imports
from .types import ImageType, handle_agent_inputs, handle_agent_outputs from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source from .utils import instance_to_source
from .tool_validation import validate_tool_attributes, MethodChecker from .tool_validation import validate_tool_attributes, MethodChecker
@ -66,7 +67,6 @@ if is_accelerate_available():
TOOL_CONFIG_FILE = "tool_config.json" TOOL_CONFIG_FILE = "tool_config.json"
def get_repo_type(repo_id, repo_type=None, **hub_kwargs): def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
if repo_type is not None: if repo_type is not None:
return repo_type return repo_type
@ -197,12 +197,15 @@ class Tool:
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.") return NotImplementedError("Write this method in your subclass of `Tool`.")
def __call__(self, *args, **kwargs): def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
if not self.is_initialized: if not self.is_initialized:
self.setup() self.setup()
args, kwargs = handle_agent_inputs(*args, **kwargs) if sanitize_inputs_outputs:
args, kwargs = handle_agent_input_types(*args, **kwargs)
outputs = self.forward(*args, **kwargs) outputs = self.forward(*args, **kwargs)
return handle_agent_outputs(outputs, self.output_type) if sanitize_inputs_outputs:
outputs = handle_agent_output_types(outputs, self.output_type)
return outputs
def setup(self): def setup(self):
""" """
@ -266,9 +269,7 @@ class Tool:
forward_source_code = add_self_argument(forward_source_code) forward_source_code = add_self_argument(forward_source_code)
forward_source_code = forward_source_code.replace("@tool", "").strip() forward_source_code = forward_source_code.replace("@tool", "").strip()
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ") tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]: if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
raise ValueError( raise ValueError(
@ -278,8 +279,9 @@ class Tool:
validate_tool_attributes(self.__class__) validate_tool_attributes(self.__class__)
tool_code = instance_to_source(self, base_cls=Tool) tool_code = instance_to_source(self, base_cls=Tool)
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code) with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
# Save app file # Save app file
app_file = os.path.join(output_dir, "app.py") app_file = os.path.join(output_dir, "app.py")
@ -719,6 +721,9 @@ def launch_gradio_demo(tool: Tool):
"number": gr.Textbox, "number": gr.Textbox,
} }
def fn(*args, **kwargs):
return tool(*args, **kwargs, sanitize_inputs_outputs=True)
gradio_inputs = [] gradio_inputs = []
for input_name, input_details in tool.inputs.items(): for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[ input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
@ -733,7 +738,7 @@ def launch_gradio_demo(tool: Tool):
gradio_output = output_gradio_componentclass(label="Output") gradio_output = output_gradio_componentclass(label="Output")
gr.Interface( gr.Interface(
fn=tool, # This works because `tool` has a __call__ method fn=fn,
inputs=gradio_inputs, inputs=gradio_inputs,
outputs=gradio_output, outputs=gradio_output,
title=tool.name, title=tool.name,
@ -823,61 +828,6 @@ def add_description(description):
return inner return inner
## Will move to the Hub
class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None):
self.headers = {
**build_hf_headers(token=token),
"Content-Type": "application/json",
}
self.endpoint_url = endpoint_url
@staticmethod
def encode_image(image):
_bytes = io.BytesIO()
image.save(_bytes, format="PNG")
b64 = base64.b64encode(_bytes.getvalue())
return b64.decode("utf-8")
@staticmethod
def decode_image(raw_image):
if not is_vision_available():
raise ImportError(
"This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
)
from PIL import Image
b64 = base64.b64decode(raw_image)
_bytes = io.BytesIO(b64)
return Image.open(_bytes)
def __call__(
self,
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
params: Optional[Dict] = None,
data: Optional[bytes] = None,
output_image: bool = False,
) -> Any:
# Build payload
payload = {}
if inputs:
payload["inputs"] = inputs
if params:
payload["parameters"] = params
# Make API call
response = get_session().post(
self.endpoint_url, headers=self.headers, json=payload, data=data
)
# By default, parse the response for the user.
if output_image:
return self.decode_image(response.content)
else:
return response.json()
class ToolCollection: class ToolCollection:
""" """
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox. Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
@ -1063,4 +1013,5 @@ __all__ = [
"load_tool", "load_tool",
"launch_gradio_demo", "launch_gradio_demo",
"Toolbox", "Toolbox",
"ToolCollection",
] ]

View File

@ -16,6 +16,7 @@ import os
import pathlib import pathlib
import tempfile import tempfile
import uuid import uuid
from io import BytesIO
import numpy as np import numpy as np
@ -105,6 +106,8 @@ class AgentImage(AgentType, ImageType):
if isinstance(value, ImageType): if isinstance(value, ImageType):
self._raw = value self._raw = value
elif isinstance(value, bytes):
self._raw = Image.open(BytesIO(value))
elif isinstance(value, (str, pathlib.Path)): elif isinstance(value, (str, pathlib.Path)):
self._path = value self._path = value
elif isinstance(value, torch.Tensor): elif isinstance(value, torch.Tensor):
@ -241,13 +244,13 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio} AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage} INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, np.ndarray: AgentAudio}
if is_torch_available(): if is_torch_available():
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
def handle_agent_inputs(*args, **kwargs): def handle_agent_input_types(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = { kwargs = {
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items() k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
@ -255,7 +258,7 @@ def handle_agent_inputs(*args, **kwargs):
return args, kwargs return args, kwargs
def handle_agent_outputs(output, output_type=None): def handle_agent_output_types(output, output_type=None):
if output_type in AGENT_TYPE_MAPPING: if output_type in AGENT_TYPE_MAPPING:
# If the class has defined outputs, we can map directly according to the class definition # If the class has defined outputs, we can map directly according to the class definition
decoded_outputs = AGENT_TYPE_MAPPING[output_type](output) decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)

View File

@ -34,7 +34,18 @@ def is_pygments_available():
console = Console() console = Console()
BASE_BUILTIN_MODULES = [
"random",
"collections",
"math",
"time",
"queue",
"itertools",
"re",
"stat",
"statistics",
"unicodedata",
]
def parse_json_blob(json_blob: str) -> Dict[str, str]: def parse_json_blob(json_blob: str) -> Dict[str, str]:
try: try:
first_accolade_index = json_blob.find("{") first_accolade_index = json_blob.find("{")
@ -190,7 +201,10 @@ def instance_to_source(instance, base_cls=None):
for name, value in class_attrs.items(): for name, value in class_attrs.items():
if isinstance(value, str): if isinstance(value, str):
class_lines.append(f' {name} = "{value}"') if "\n" in value:
class_lines.append(f' {name} = """{value}"""')
else:
class_lines.append(f' {name} = "{value}"')
else: else:
class_lines.append(f' {name} = {repr(value)}') class_lines.append(f' {name} = {repr(value)}')
@ -230,7 +244,8 @@ def instance_to_source(instance, base_cls=None):
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}") final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
# Add discovered imports # Add discovered imports
final_lines.extend(required_imports) for package in required_imports:
final_lines.append(f"import {package}")
if final_lines: # Add empty line after imports if final_lines: # Add empty line after imports
final_lines.append("") final_lines.append("")

View File

@ -29,7 +29,7 @@ from agents.agents import (
Toolbox, Toolbox,
ToolCall, ToolCall,
) )
from agents.tool import tool from agents.tools import tool
from agents.default_tools import PythonInterpreterTool from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir

View File

@ -26,7 +26,7 @@ from agents.types import (
AgentImage, AgentImage,
AgentText, AgentText,
) )
from agents.tool import Tool, tool, AUTHORIZED_TYPES from agents.tools import Tool, tool, AUTHORIZED_TYPES
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir