Add E2B code interpreter 🥳

This commit is contained in:
Aymeric 2024-12-20 16:18:40 +01:00
parent 7b0b01d8f3
commit c18bc9037d
24 changed files with 400 additions and 243 deletions

7
.gitignore vendored
View File

@ -36,6 +36,7 @@ sdist/
var/
wheels/
share/python-wheels/
node_modules/
*.egg-info/
.installed.cfg
*.egg
@ -156,4 +157,8 @@ dmypy.json
cython_debug/
# PyCharm
#.idea/
#.idea/
# Archive
archive/
savedir/

View File

@ -1,5 +1,5 @@
# Base Python image
FROM python:3.9-slim
FROM python:3.12-slim
# Set working directory
WORKDIR /app
@ -7,8 +7,6 @@ WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y \
build-essential \
gcc \
g++ \
zlib1g-dev \
libjpeg-dev \
libpng-dev \

5
e2b.Dockerfile Normal file
View File

@ -0,0 +1,5 @@
# You can use most Debian-based base images
FROM e2bdev/code-interpreter:latest
# Install dependencies and customize sandbox
RUN pip install git+https://github.com/huggingface/agents.git

16
e2b.toml Normal file
View File

@ -0,0 +1,16 @@
# This is a config for E2B sandbox template.
# You can use template ID (qywp2ctmu2q7jzprcf4j) to create a sandbox:
# Python SDK
# from e2b import Sandbox, AsyncSandbox
# sandbox = Sandbox("qywp2ctmu2q7jzprcf4j") # Sync sandbox
# sandbox = await AsyncSandbox.create("qywp2ctmu2q7jzprcf4j") # Async sandbox
# JS SDK
# import { Sandbox } from 'e2b'
# const sandbox = await Sandbox.create('qywp2ctmu2q7jzprcf4j')
team_id = "f8776d3a-df2f-4a1d-af48-68c2e13b3b87"
start_cmd = "/root/.jupyter/start-up.sh"
dockerfile = "e2b.Dockerfile"
template_id = "qywp2ctmu2q7jzprcf4j"

View File

@ -1,8 +1,8 @@
from agents.tools.search import DuckDuckGoSearchTool
from agents.default_tools.search import DuckDuckGoSearchTool
from agents.docker_alternative import DockerPythonInterpreter
from agents.tool import Tool
from agents.tools import Tool
class DummyTool(Tool):
name = "echo"

View File

@ -1,4 +1,4 @@
from agents.tool import Tool
from agents.tools import Tool
class DummyTool(Tool):

44
examples/e2b_example.py Normal file
View File

@ -0,0 +1,44 @@
from agents import Tool, CodeAgent
from agents.default_tools.search import VisitWebpageTool
from dotenv import load_dotenv
load_dotenv()
LAUNCH_GRADIO = False
# Example tool: downloads a fixed picture and returns it as a PIL image.
# It takes no inputs and declares an "image" output.
class GetCatImageTool(Tool):
    name = "get_cat_image"
    description = "Get a cat image"
    inputs = {}
    output_type = "image"

    def __init__(self):
        super().__init__()
        # NOTE: despite the tool name, this URL points at a robot-face emoji PNG.
        self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"

    def forward(self):
        # Imports are kept inside the method so the method body is
        # self-contained wherever the tool source ends up being executed.
        from PIL import Image
        import requests
        from io import BytesIO

        # A timeout keeps the agent from hanging forever on a stalled
        # connection; raise_for_status() turns an HTTP error page into an
        # explicit exception instead of an obscure image-decoding failure.
        response = requests.get(self.url, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
get_cat_image = GetCatImageTool()

# The agent gets the custom image tool plus the stock web-page reader.
agent_tools = [get_cat_image, VisitWebpageTool()]
extra_imports = ["Pillow", "requests", "markdownify"]  # "duckduckgo-search",

agent = CodeAgent(
    tools=agent_tools,
    additional_authorized_imports=extra_imports,
    use_e2b_executor=False,
)

if not LAUNCH_GRADIO:
    # Headless mode: run a single task straight from the script.
    agent.run(
        "Return me an image of Lincoln's preferred pet",
        additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/",
    )
else:
    # Interactive mode: serve the agent through the Gradio chat UI.
    from agents.gradio_ui import GradioUI

    GradioUI(agent).launch()

View File

@ -24,22 +24,28 @@ from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .agents import *
from .default_tools import *
from .default_tools.base import *
from .default_tools.search import *
from .gradio_ui import *
from .llm_engines import *
from .local_python_executor import *
from .monitoring import *
from .prompts import *
from .tools.search import *
from .tool import *
from .tools import *
from .types import *
from .utils import *
from .default_tools.search import *
else:
import sys
_file = globals()["__file__"]
import_structure = define_import_structure(_file)
import_structure[""]= {"__version__": __version__}
sys.modules[__name__] = _LazyModule(
__name__, _file, define_import_structure(_file), module_spec=__spec__
__name__,
_file,
import_structure,
module_spec=__spec__,
extra_objects={"__version__": __version__}
)

View File

@ -27,7 +27,7 @@ from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
from .types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
from .default_tools.base import FinalAnswerTool
from .llm_engines import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (
@ -42,8 +42,9 @@ from .prompts import (
SYSTEM_PROMPT_PLAN_UPDATE,
SYSTEM_PROMPT_PLAN,
)
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tool import (
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor
from .e2b_executor import E2BExecutor
from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool,
get_tool_description_with_args,
@ -169,17 +170,6 @@ def format_prompt_with_managed_agents_descriptions(
else:
return prompt_template.replace(agent_descriptions_placeholder, "")
def format_prompt_with_imports(
prompt_template: str, authorized_imports: List[str]
) -> str:
if "<<authorized_imports>>" not in prompt_template:
raise AgentError(
"Tag '<<authorized_imports>>' should be provided in the prompt."
)
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
class BaseAgent:
def __init__(
self,
@ -264,11 +254,6 @@ class BaseAgent:
self.system_prompt = format_prompt_with_managed_agents_descriptions(
self.system_prompt, self.managed_agents
)
if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports(
self.system_prompt,
list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
)
return self.system_prompt
@ -439,9 +424,7 @@ class ReactAgent(BaseAgent):
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
available_tools = self.toolbox.tools
if self.managed_agents is not None:
available_tools = {**available_tools, **self.managed_agents}
available_tools = {**self.toolbox.tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
@ -674,8 +657,6 @@ Now begin!""",
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
),
answer_facts=answer_facts,
),
@ -729,8 +710,6 @@ Now begin!""",
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
),
facts_update=facts_update,
remaining_steps=(self.max_iterations - iteration),
@ -891,6 +870,7 @@ class CodeAgent(ReactAgent):
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
use_e2b_executor: bool = False,
**kwargs,
):
if llm_engine is None:
@ -909,17 +889,24 @@ class CodeAgent(ReactAgent):
**kwargs,
)
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
all_tools = {**self.toolbox.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(self.additional_authorized_imports, list(all_tools.values()))
else:
self.python_executor = LocalPythonExecutor(self.additional_authorized_imports, all_tools)
self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in self.system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}", str(self.authorized_imports)
)
self.custom_tools = {}
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""
@ -991,22 +978,12 @@ class CodeAgent(ReactAgent):
)
try:
static_tools = {
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
}
if self.managed_agents is not None:
static_tools = {**static_tools, **self.managed_agents}
output = self.python_evaluator(
output, execution_logs = self.python_executor(
code_action,
static_tools=static_tools,
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
)
if len(self.state["print_outputs"]) > 0:
console.print(Group(Text("Print outputs:", style="bold"), Text(self.state["print_outputs"])))
observation = "Print outputs:\n" + self.state["print_outputs"]
if len(execution_logs) > 0:
console.print(Group(Text("Execution logs:", style="bold"), Text(execution_logs)))
observation = "Execution logs:\n" + execution_logs
if output is not None:
truncated_output = truncate_content(
str(output)
@ -1026,7 +1003,7 @@ class CodeAgent(ReactAgent):
console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green")))
log_entry.action_output = output
return output
return None
class ManagedAgent:

View File

@ -15,75 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from dataclasses import dataclass
from math import sqrt
from typing import Dict
from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tool import TOOL_CONFIG_FILE, Tool
def custom_print(*args):
return None
BASE_PYTHON_TOOLS = {
"print": custom_print,
"isinstance": isinstance,
"range": range,
"float": float,
"int": int,
"bool": bool,
"str": str,
"set": set,
"list": list,
"dict": dict,
"tuple": tuple,
"round": round,
"ceil": math.ceil,
"floor": math.floor,
"log": math.log,
"exp": math.exp,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"asin": math.asin,
"acos": math.acos,
"atan": math.atan,
"atan2": math.atan2,
"degrees": math.degrees,
"radians": math.radians,
"pow": math.pow,
"sqrt": sqrt,
"len": len,
"sum": sum,
"max": max,
"min": min,
"abs": abs,
"enumerate": enumerate,
"zip": zip,
"reversed": reversed,
"sorted": sorted,
"all": all,
"any": any,
"map": map,
"filter": filter,
"ord": ord,
"chr": chr,
"next": next,
"iter": iter,
"divmod": divmod,
"callable": callable,
"getattr": getattr,
"hasattr": hasattr,
"setattr": setattr,
"issubclass": issubclass,
"type": type,
}
from ..local_python_executor import BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code
from ..tools import TOOL_CONFIG_FILE, Tool
@dataclass
@ -136,10 +75,10 @@ class PythonInterpreterTool(Tool):
def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None:
self.authorized_imports = list(set(LIST_SAFE_MODULES))
self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
else:
self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(authorized_imports)
set(BASE_BUILTIN_MODULES) | set(authorized_imports)
)
self.inputs = {
"code": {

View File

@ -16,15 +16,11 @@
# limitations under the License.
import re
import requests
from requests.exceptions import RequestException
from ..tools import Tool
class DuckDuckGoSearchTool(Tool):
name = "web_search"
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'."""
inputs = {
"query": {"type": "string", "description": "The search query to perform."}
@ -56,9 +52,11 @@ class VisitWebpageTool(Tool):
def forward(self, url: str) -> str:
try:
from markdownify import markdownify
import requests
from requests.exceptions import RequestException
except ImportError:
raise ImportError(
"You must install package `markdownify` to run this tool: for instance run `pip install markdownify`."
"You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
)
try:
# Send a GET request to the URL

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import warnings
import socket
from agents.tool import Tool
from agents.tools import Tool
class DockerPythonInterpreter:
def __init__(self):

View File

@ -343,7 +343,7 @@ if __name__ == '__main__':
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES
from .local_python_executor import evaluate_python_code, BASE_BUILTIN_MODULES
"""Execute code locally with state transfer."""
state_manager = StateManager(work_dir)
@ -363,7 +363,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
tools,
{},
namespace,
LIST_SAFE_MODULES,
BASE_BUILTIN_MODULES,
)
# Save state for Docker

103
src/agents/e2b_executor.py Normal file
View File

@ -0,0 +1,103 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dotenv import load_dotenv
import textwrap
import base64
from io import BytesIO
from PIL import Image
from e2b_code_interpreter import Sandbox
from typing import Dict, List, Callable, Tuple, Any
from .tool_validation import validate_tool_attributes
from .utils import instance_to_source, BASE_BUILTIN_MODULES
from .tools import Tool
from .types import AgentImage
load_dotenv()
class E2BExecutor:
    """Execute agent-generated Python code inside a remote E2B sandbox.

    On construction the executor opens a sandbox, pip-installs the requested
    extra packages in it, and defines every tool there by sending the tool's
    regenerated source code. Calling the instance runs a code snippet in the
    sandbox and returns its main result together with the captured stdout.
    """

    def __init__(self, additional_imports: List[str], tools: List[Tool]):
        """Create the sandbox, install dependencies and define the tools remotely.

        Args:
            additional_imports: Requirement strings passed to `pip install`
                inside the sandbox.
            tools: Tool instances whose source is rebuilt and executed in the
                sandbox so that generated code can call them by name.

        Raises:
            Exception: If the pip installation reports an error.
            ValueError: If executing the tool definitions fails remotely.
        """
        import shlex  # local import: only needed to build the pip command line

        self.custom_tools = {}
        self.sbx = Sandbox()  # "qywp2ctmu2q7jzprcf4j")
        # TODO: validate installing agents package or not
        # print("Installing agents package on remote executor...")
        # self.sbx.commands.run(
        #     "pip install git+https://github.com/huggingface/agents.git",
        #     timeout=300
        # )
        # print("Installation of agents package finished.")
        if len(additional_imports) > 0:
            # Quote each requirement: the command is a shell string, so an
            # unquoted `pkg>=1.0` would be parsed as an output redirection and
            # any shell metacharacter in a name would be interpreted.
            quoted_requirements = " ".join(shlex.quote(package) for package in additional_imports)
            execution = self.sbx.commands.run("pip install " + quoted_requirements)
            if execution.error:
                raise Exception(f"Error installing dependencies: {execution.error}")
            else:
                print("Installation succeeded!")

        tool_codes = []
        for tool in tools:
            validate_tool_attributes(tool.__class__, check_imports=False)
            tool_code = instance_to_source(tool, base_cls=Tool)
            # The sandbox gets its own minimal `Tool` base class (defined
            # below), so the import of the real one is stripped.
            tool_code = tool_code.replace("from agents.tools import Tool", "")
            # Instantiate the tool under its public name so code can call it.
            tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
            tool_codes.append(tool_code)

        tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
        tool_definition_code += textwrap.dedent("""
        class Tool:
            def __call__(self, *args, **kwargs):
                return self.forward(*args, **kwargs)

            def forward(self, *args, **kwargs):
                pass # to be implemented in child class
        """)
        tool_definition_code += "\n\n".join(tool_codes)

        tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
        print(tool_definition_execution.logs)

    def run_code_raise_errors(self, code: str):
        """Run `code` in the sandbox and return the execution object.

        Raises:
            ValueError: If the execution reports an error; the message holds
                the error name, value and traceback on separate lines.
        """
        execution = self.sbx.run_code(
            code,
        )
        if execution.error:
            # One part per line; the parts used to be glued together with no
            # separator, which made the message hard to read.
            logs = "Executing code yielded an error:\n"
            logs += f"{execution.error.name}\n"
            logs += f"{execution.error.value}\n"
            logs += f"{execution.error.traceback}"
            raise ValueError(logs)
        return execution

    def __call__(self, code_action: str) -> Tuple[Any, str]:
        """Execute `code_action` remotely and return `(output, execution_logs)`.

        `output` is None when the execution produced no results. Otherwise it
        is taken from the result flagged as main: a PIL image when a jpeg/png
        payload is present, else the first populated representation among the
        other supported formats.

        Raises:
            ValueError: If the code errors remotely, or if results exist but
                no usable main result could be extracted from them.
        """
        execution = self.run_code_raise_errors(code_action)
        execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
        if not execution.results:
            return None, execution_logs

        for result in execution.results:
            if not result.is_main_result:
                continue
            # Image payloads are base64-encoded text: decode them into PIL images.
            for attribute_name in ['jpeg', 'png']:
                if getattr(result, attribute_name) is not None:
                    image_output = getattr(result, attribute_name)
                    decoded_bytes = base64.b64decode(image_output.encode('utf-8'))
                    return Image.open(BytesIO(decoded_bytes)), execution_logs
            # Otherwise return the first populated representation as-is.
            for attribute_name in ['chart', 'data', 'html', 'javascript', 'json', 'latex', 'markdown', 'pdf', 'svg', 'text']:
                if getattr(result, attribute_name) is not None:
                    return getattr(result, attribute_name), execution_logs
        raise ValueError("No main result returned by executor!")
__all__ = ["E2BExecutor"]

View File

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .types import AgentAudio, AgentImage, AgentText
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from .agents import BaseAgent, AgentStep, ActionStep
import gradio as gr
@ -58,7 +58,7 @@ def stream_to_gradio(
for message in pull_messages_from_step(step_log, test_mode=test_mode):
yield message
final_answer = step_log # Last log is the run's final_answer
final_answer = handle_agent_output_types(step_log) # Last log is the run's final_answer
if isinstance(final_answer, AgentText):
yield gr.ChatMessage(
@ -93,7 +93,7 @@ class GradioUI:
yield messages
yield messages
def run(self):
def launch(self):
with gr.Blocks() as demo:
stored_message = gr.State([])
chatbot = gr.Chatbot(

View File

@ -19,12 +19,12 @@ import builtins
import difflib
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
import math
import numpy as np
import pandas as pd
from .utils import truncate_content
from .utils import truncate_content, BASE_BUILTIN_MODULES
class InterpreterError(ValueError):
@ -43,24 +43,66 @@ ERRORS = {
and issubclass(getattr(builtins, name), BaseException)
}
LIST_SAFE_MODULES = [
"random",
"collections",
"math",
"time",
"queue",
"itertools",
"re",
"stat",
"statistics",
"unicodedata",
]
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
def custom_print(*args):
return None
BASE_PYTHON_TOOLS = {
"print": custom_print,
"isinstance": isinstance,
"range": range,
"float": float,
"int": int,
"bool": bool,
"str": str,
"set": set,
"list": list,
"dict": dict,
"tuple": tuple,
"round": round,
"ceil": math.ceil,
"floor": math.floor,
"log": math.log,
"exp": math.exp,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"asin": math.asin,
"acos": math.acos,
"atan": math.atan,
"atan2": math.atan2,
"degrees": math.degrees,
"radians": math.radians,
"pow": math.pow,
"sqrt": math.sqrt,
"len": len,
"sum": sum,
"max": max,
"min": min,
"abs": abs,
"enumerate": enumerate,
"zip": zip,
"reversed": reversed,
"sorted": sorted,
"all": all,
"any": any,
"map": map,
"filter": filter,
"ord": ord,
"chr": chr,
"next": next,
"iter": iter,
"divmod": divmod,
"callable": callable,
"getattr": getattr,
"hasattr": hasattr,
"setattr": setattr,
"issubclass": issubclass,
"type": type,
}
class BreakException(Exception):
pass
@ -771,7 +813,7 @@ def evaluate_ast(
state: Dict[str, Any],
static_tools: Dict[str, Callable],
custom_tools: Dict[str, Callable],
authorized_imports: List[str] = LIST_SAFE_MODULES,
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
):
"""
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
@ -949,7 +991,7 @@ def evaluate_python_code(
static_tools: Optional[Dict[str, Callable]] = None,
custom_tools: Optional[Dict[str, Callable]] = None,
state: Optional[Dict[str, Any]] = None,
authorized_imports: List[str] = LIST_SAFE_MODULES,
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
):
"""
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
@ -1001,4 +1043,30 @@ def evaluate_python_code(
raise InterpreterError(msg)
__all__ = ["evaluate_python_code"]
class LocalPythonExecutor:
    """Callable wrapper around `evaluate_python_code` that keeps its state between calls."""

    def __init__(self, additional_authorized_imports: List[str], tools: Dict):
        self.custom_tools = {}
        self.state = {}
        self.additional_authorized_imports = additional_authorized_imports

        allowed_modules = set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
        self.authorized_imports = list(allowed_modules)

        # Base trusted tools are merged in last, so on a name clash they
        # replace the caller-provided entry.
        merged_tools = dict(tools)
        merged_tools.update(BASE_PYTHON_TOOLS.copy())
        self.static_tools = merged_tools
        # TODO: assert self.authorized imports are all installed locally

    def __call__(self, code_action: str) -> Tuple[Any, str]:
        """Evaluate `code_action` and return `(output, print logs)`."""
        result = evaluate_python_code(
            code_action,
            static_tools=self.static_tools,
            custom_tools=self.custom_tools,
            state=self.state,
            authorized_imports=self.authorized_imports,
        )
        # The logs are read from the shared state after evaluation.
        return result, self.state["print_outputs"]
__all__ = ["evaluate_python_code", "LocalPythonExecutor"]

View File

@ -370,7 +370,7 @@ Here are the rules you should always follow to solve your task:
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.

View File

@ -5,27 +5,30 @@ import builtins
from pathlib import Path
from typing import List, Set, Dict
import textwrap
from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins))
def is_local_import(module_name: str) -> bool:
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
def is_installed_package(module_name: str) -> bool:
    """
    Check if an import is from an installed package.

    Standard-library modules, built-in/compiled modules and anything located
    in a `site-packages` / `dist-packages` directory count as installed.

    Returns False if it's not found or a local file import.
    """
    import os
    import sysconfig

    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        # ImportError: a parent package is missing or fails to import.
        # ValueError: empty module name (e.g. `from . import x`) or a module
        # that exposes no spec.
        return False
    if spec is None:
        return False  # The module cannot be found at all

    origin = spec.origin
    if not origin:
        # Namespace packages have no single file: judge them by their directories.
        candidate_paths = list(spec.submodule_search_locations or [])
    elif not origin.endswith(".py"):
        # "built-in", "frozen" or a compiled artifact: not a local source file.
        return True
    else:
        candidate_paths = [origin]

    stdlib_roots = {
        os.path.realpath(path)
        for path in (sysconfig.get_path("stdlib"), sysconfig.get_path("platstdlib"))
        if path
    }
    for path in candidate_paths:
        # Third-party distributions installed by pip or the system packager.
        if "site-packages" in path or "dist-packages" in path:
            return True
        # Pure-Python standard-library modules such as `json`.
        resolved = os.path.realpath(path)
        if any(resolved.startswith(root + os.sep) for root in stdlib_roots):
            return True
    # A plain .py file (or directory) outside of the known install locations.
    return False
class MethodChecker(ast.NodeVisitor):
"""
@ -33,7 +36,7 @@ class MethodChecker(ast.NodeVisitor):
- only uses defined names
- contains no local imports (e.g. numpy is ok but local_script is not)
"""
def __init__(self, class_attributes: Set[str]):
def __init__(self, class_attributes: Set[str], check_imports: bool = True):
self.undefined_names = set()
self.imports = {}
self.from_imports = {}
@ -41,6 +44,7 @@ class MethodChecker(ast.NodeVisitor):
self.arg_names = set()
self.class_attributes = class_attributes
self.errors = []
self.check_imports = check_imports
def visit_arguments(self, node):
"""Collect function arguments"""
@ -53,16 +57,16 @@ class MethodChecker(ast.NodeVisitor):
def visit_Import(self, node):
for name in node.names:
actual_name = name.asname or name.name
if is_local_import(actual_name):
self.errors.append(f"Local import '{actual_name}'")
if not is_installed_package(actual_name) and self.check_imports:
self.errors.append(f"Package not found in importlib, might be a local install: '{actual_name}'")
self.imports[actual_name] = name.name
def visit_ImportFrom(self, node):
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
if is_local_import(module):
self.errors.append(f"Local import '{module}'")
if not is_installed_package(module) and self.check_imports:
self.errors.append(f"Package not found in importlib, might be a local install: '{module}'")
self.from_imports[actual_name] = (module, name.name)
def visit_Assign(self, node):
@ -71,6 +75,20 @@ class MethodChecker(ast.NodeVisitor):
self.assigned_names.add(target.id)
self.visit(node.value)
def visit_With(self, node):
    """Track aliases in 'with' statements (the 'y' in 'with X as y').

    Every name bound by the target is recorded, including names nested in
    tuple/list targets such as 'with X as (a, b)'.
    """
    for item in node.items:
        if item.optional_vars is None:
            continue
        # Walk the whole target so destructuring targets are covered too.
        # Only names in Store context are bindings: in `with X as obj.attr`
        # the name `obj` is merely read, not assigned.
        for target in ast.walk(item.optional_vars):
            if isinstance(target, ast.Name) and isinstance(target.ctx, ast.Store):
                self.assigned_names.add(target.id)
    self.generic_visit(node)
def visit_ExceptHandler(self, node):
    """Record the name bound by an 'except ... as name' clause, then keep walking."""
    alias = node.name
    if alias:
        # A handler without an 'as' clause has no name and binds nothing.
        self.assigned_names.add(alias)
    self.generic_visit(node)
def visit_AnnAssign(self, node):
"""Track annotated assignments."""
if isinstance(node.target, ast.Name):
@ -97,6 +115,7 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.ctx, ast.Load):
if not (
node.id in _BUILTIN_NAMES
or node.id in IMPORTED_PACKAGES
or node.id in self.arg_names
or node.id == "self"
or node.id in self.class_attributes
@ -110,17 +129,18 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.func, ast.Name):
if not (
node.func.id in _BUILTIN_NAMES
or node.func.id in IMPORTED_PACKAGES
or node.func.id in self.arg_names
or node.func.id == "self"
or node.func.id in self.class_attributes
or node.func.id in self.imports
or node.func.id in self.from_imports
or node.func.id in self.assigned_names
):
):
self.errors.append(f"Name '{node.func.id}' is undefined.")
self.generic_visit(node)
def validate_tool_attributes(cls) -> None:
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
"""
Validates that a Tool class follows the proper patterns:
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
@ -156,8 +176,17 @@ def validate_tool_attributes(cls) -> None:
self.imported_names = set()
self.complex_attributes = set()
self.class_attributes = set()
self.in_method = False
def visit_FunctionDef(self, node):
old_context = self.in_method
self.in_method = True
self.generic_visit(node)
self.in_method = old_context
def visit_Assign(self, node):
if self.in_method:
return
# Track class attributes
for target in node.targets:
if isinstance(target, ast.Name):
@ -182,7 +211,7 @@ def validate_tool_attributes(cls) -> None:
# Run checks on all methods
for node in class_node.body:
if isinstance(node, ast.FunctionDef):
method_checker = MethodChecker(class_level_checker.class_attributes)
method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
method_checker.visit(node)
errors += [f"- {node.name}: {error}" for error in method_checker.errors]

View File

@ -28,6 +28,7 @@ import textwrap
from functools import lru_cache, wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, Set
import math
from huggingface_hub import (
create_repo,
@ -48,7 +49,7 @@ from transformers.utils import (
is_vision_available,
)
from transformers.dynamic_module_utils import get_imports
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source
from .tool_validation import validate_tool_attributes, MethodChecker
@ -66,7 +67,6 @@ if is_accelerate_available():
TOOL_CONFIG_FILE = "tool_config.json"
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
if repo_type is not None:
return repo_type
@ -197,12 +197,15 @@ class Tool:
def forward(self, *args, **kwargs):
    """Run the tool's logic; subclasses are expected to override this method.

    Raises:
        NotImplementedError: Always, when the method has not been overridden.
    """
    # Raise the error instead of returning the exception object: a returned
    # instance would be handed to the caller as if it were the tool's output.
    raise NotImplementedError("Write this method in your subclass of `Tool`.")
def __call__(self, *args, **kwargs):
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
if not self.is_initialized:
self.setup()
args, kwargs = handle_agent_inputs(*args, **kwargs)
if sanitize_inputs_outputs:
args, kwargs = handle_agent_input_types(*args, **kwargs)
outputs = self.forward(*args, **kwargs)
return handle_agent_outputs(outputs, self.output_type)
if sanitize_inputs_outputs:
outputs = handle_agent_output_types(outputs, self.output_type)
return outputs
def setup(self):
"""
@ -266,9 +269,7 @@ class Tool:
forward_source_code = add_self_argument(forward_source_code)
forward_source_code = forward_source_code.replace("@tool", "").strip()
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
raise ValueError(
@ -278,8 +279,9 @@ class Tool:
validate_tool_attributes(self.__class__)
tool_code = instance_to_source(self, base_cls=Tool)
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
# Save app file
app_file = os.path.join(output_dir, "app.py")
@ -719,6 +721,9 @@ def launch_gradio_demo(tool: Tool):
"number": gr.Textbox,
}
def fn(*args, **kwargs):
return tool(*args, **kwargs, sanitize_inputs_outputs=True)
gradio_inputs = []
for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
@ -733,7 +738,7 @@ def launch_gradio_demo(tool: Tool):
gradio_output = output_gradio_componentclass(label="Output")
gr.Interface(
fn=tool, # This works because `tool` has a __call__ method
fn=fn,
inputs=gradio_inputs,
outputs=gradio_output,
title=tool.name,
@ -823,61 +828,6 @@ def add_description(description):
return inner
## Will move to the Hub
class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None):
self.headers = {
**build_hf_headers(token=token),
"Content-Type": "application/json",
}
self.endpoint_url = endpoint_url
@staticmethod
def encode_image(image):
_bytes = io.BytesIO()
image.save(_bytes, format="PNG")
b64 = base64.b64encode(_bytes.getvalue())
return b64.decode("utf-8")
@staticmethod
def decode_image(raw_image):
if not is_vision_available():
raise ImportError(
"This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
)
from PIL import Image
b64 = base64.b64decode(raw_image)
_bytes = io.BytesIO(b64)
return Image.open(_bytes)
def __call__(
self,
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
params: Optional[Dict] = None,
data: Optional[bytes] = None,
output_image: bool = False,
) -> Any:
# Build payload
payload = {}
if inputs:
payload["inputs"] = inputs
if params:
payload["parameters"] = params
# Make API call
response = get_session().post(
self.endpoint_url, headers=self.headers, json=payload, data=data
)
# By default, parse the response for the user.
if output_image:
return self.decode_image(response.content)
else:
return response.json()
class ToolCollection:
"""
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
@ -1063,4 +1013,5 @@ __all__ = [
"load_tool",
"launch_gradio_demo",
"Toolbox",
"ToolCollection",
]

View File

@ -16,6 +16,7 @@ import os
import pathlib
import tempfile
import uuid
from io import BytesIO
import numpy as np
@ -105,6 +106,8 @@ class AgentImage(AgentType, ImageType):
if isinstance(value, ImageType):
self._raw = value
elif isinstance(value, bytes):
self._raw = Image.open(BytesIO(value))
elif isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
@ -241,13 +244,13 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, np.ndarray: AgentAudio}
if is_torch_available():
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
def handle_agent_inputs(*args, **kwargs):
def handle_agent_input_types(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = {
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
@ -255,7 +258,7 @@ def handle_agent_inputs(*args, **kwargs):
return args, kwargs
def handle_agent_outputs(output, output_type=None):
def handle_agent_output_types(output, output_type=None):
if output_type in AGENT_TYPE_MAPPING:
# If the class has defined outputs, we can map directly according to the class definition
decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)

View File

@ -34,7 +34,18 @@ def is_pygments_available():
console = Console()
BASE_BUILTIN_MODULES = [
"random",
"collections",
"math",
"time",
"queue",
"itertools",
"re",
"stat",
"statistics",
"unicodedata",
]
def parse_json_blob(json_blob: str) -> Dict[str, str]:
try:
first_accolade_index = json_blob.find("{")
@ -190,7 +201,10 @@ def instance_to_source(instance, base_cls=None):
for name, value in class_attrs.items():
if isinstance(value, str):
class_lines.append(f' {name} = "{value}"')
if "\n" in value:
class_lines.append(f' {name} = """{value}"""')
else:
class_lines.append(f' {name} = "{value}"')
else:
class_lines.append(f' {name} = {repr(value)}')
@ -230,7 +244,8 @@ def instance_to_source(instance, base_cls=None):
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
# Add discovered imports
final_lines.extend(required_imports)
for package in required_imports:
final_lines.append(f"import {package}")
if final_lines: # Add empty line after imports
final_lines.append("")

View File

@ -29,7 +29,7 @@ from agents.agents import (
Toolbox,
ToolCall,
)
from agents.tool import tool
from agents.tools import tool
from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir

View File

@ -26,7 +26,7 @@ from agents.types import (
AgentImage,
AgentText,
)
from agents.tool import Tool, tool, AUTHORIZED_TYPES
from agents.tools import Tool, tool, AUTHORIZED_TYPES
from transformers.testing_utils import get_tests_dir