Add E2B code interpreter 🥳
This commit is contained in:
parent
7b0b01d8f3
commit
c18bc9037d
|
@ -36,6 +36,7 @@ sdist/
|
|||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
node_modules/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
@ -156,4 +157,8 @@ dmypy.json
|
|||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
#.idea/
|
||||
#.idea/
|
||||
|
||||
# Archive
|
||||
archive/
|
||||
savedir/
|
|
@ -1,5 +1,5 @@
|
|||
# Base Python image
|
||||
FROM python:3.9-slim
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
@ -7,8 +7,6 @@ WORKDIR /app
|
|||
# Install build dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
gcc \
|
||||
g++ \
|
||||
zlib1g-dev \
|
||||
libjpeg-dev \
|
||||
libpng-dev \
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# You can use most Debian-based base images
|
||||
FROM e2bdev/code-interpreter:latest
|
||||
|
||||
# Install dependencies and customize sandbox
|
||||
RUN pip install git+https://github.com/huggingface/agents.git
|
|
@ -0,0 +1,16 @@
|
|||
# This is a config for E2B sandbox template.
|
||||
# You can use template ID (qywp2ctmu2q7jzprcf4j) to create a sandbox:
|
||||
|
||||
# Python SDK
|
||||
# from e2b import Sandbox, AsyncSandbox
|
||||
# sandbox = Sandbox("qywp2ctmu2q7jzprcf4j") # Sync sandbox
|
||||
# sandbox = await AsyncSandbox.create("qywp2ctmu2q7jzprcf4j") # Async sandbox
|
||||
|
||||
# JS SDK
|
||||
# import { Sandbox } from 'e2b'
|
||||
# const sandbox = await Sandbox.create('qywp2ctmu2q7jzprcf4j')
|
||||
|
||||
team_id = "f8776d3a-df2f-4a1d-af48-68c2e13b3b87"
|
||||
start_cmd = "/root/.jupyter/start-up.sh"
|
||||
dockerfile = "e2b.Dockerfile"
|
||||
template_id = "qywp2ctmu2q7jzprcf4j"
|
|
@ -1,8 +1,8 @@
|
|||
from agents.tools.search import DuckDuckGoSearchTool
|
||||
from agents.default_tools.search import DuckDuckGoSearchTool
|
||||
from agents.docker_alternative import DockerPythonInterpreter
|
||||
|
||||
|
||||
from agents.tool import Tool
|
||||
from agents.tools import Tool
|
||||
|
||||
class DummyTool(Tool):
|
||||
name = "echo"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from agents.tool import Tool
|
||||
from agents.tools import Tool
|
||||
|
||||
|
||||
class DummyTool(Tool):
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
from agents import Tool, CodeAgent
|
||||
from agents.default_tools.search import VisitWebpageTool
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
LAUNCH_GRADIO = False
|
||||
|
||||
class GetCatImageTool(Tool):
|
||||
name="get_cat_image"
|
||||
description = "Get a cat image"
|
||||
inputs = {}
|
||||
output_type = "image"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
|
||||
|
||||
def forward(self):
|
||||
from PIL import Image
|
||||
import requests
|
||||
from io import BytesIO
|
||||
|
||||
response = requests.get(self.url)
|
||||
|
||||
return Image.open(BytesIO(response.content))
|
||||
|
||||
get_cat_image = GetCatImageTool()
|
||||
|
||||
|
||||
agent = CodeAgent(
|
||||
tools = [get_cat_image, VisitWebpageTool()],
|
||||
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
|
||||
use_e2b_executor=False
|
||||
)
|
||||
|
||||
if LAUNCH_GRADIO:
|
||||
from agents.gradio_ui import GradioUI
|
||||
|
||||
GradioUI(agent).launch()
|
||||
else:
|
||||
agent.run(
|
||||
"Return me an image of Lincoln's preferred pet",
|
||||
additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/"
|
||||
)
|
|
@ -24,22 +24,28 @@ from transformers.utils.import_utils import define_import_structure
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from .agents import *
|
||||
from .default_tools import *
|
||||
from .default_tools.base import *
|
||||
from .default_tools.search import *
|
||||
from .gradio_ui import *
|
||||
from .llm_engines import *
|
||||
from .local_python_executor import *
|
||||
from .monitoring import *
|
||||
from .prompts import *
|
||||
from .tools.search import *
|
||||
from .tool import *
|
||||
from .tools import *
|
||||
from .types import *
|
||||
from .utils import *
|
||||
from .default_tools.search import *
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
import_structure = define_import_structure(_file)
|
||||
import_structure[""]= {"__version__": __version__}
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__, _file, define_import_structure(_file), module_spec=__spec__
|
||||
__name__,
|
||||
_file,
|
||||
import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={"__version__": __version__}
|
||||
)
|
||||
|
|
|
@ -27,7 +27,7 @@ from transformers.utils import is_torch_available
|
|||
|
||||
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
|
||||
from .types import AgentAudio, AgentImage
|
||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
|
||||
from .default_tools.base import FinalAnswerTool
|
||||
from .llm_engines import HfApiEngine, MessageRole
|
||||
from .monitoring import Monitor
|
||||
from .prompts import (
|
||||
|
@ -42,8 +42,9 @@ from .prompts import (
|
|||
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||
SYSTEM_PROMPT_PLAN,
|
||||
)
|
||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tool import (
|
||||
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor
|
||||
from .e2b_executor import E2BExecutor
|
||||
from .tools import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
Tool,
|
||||
get_tool_description_with_args,
|
||||
|
@ -169,17 +170,6 @@ def format_prompt_with_managed_agents_descriptions(
|
|||
else:
|
||||
return prompt_template.replace(agent_descriptions_placeholder, "")
|
||||
|
||||
|
||||
def format_prompt_with_imports(
|
||||
prompt_template: str, authorized_imports: List[str]
|
||||
) -> str:
|
||||
if "<<authorized_imports>>" not in prompt_template:
|
||||
raise AgentError(
|
||||
"Tag '<<authorized_imports>>' should be provided in the prompt."
|
||||
)
|
||||
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -264,11 +254,6 @@ class BaseAgent:
|
|||
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
||||
self.system_prompt, self.managed_agents
|
||||
)
|
||||
if hasattr(self, "authorized_imports"):
|
||||
self.system_prompt = format_prompt_with_imports(
|
||||
self.system_prompt,
|
||||
list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
|
||||
)
|
||||
|
||||
return self.system_prompt
|
||||
|
||||
|
@ -439,9 +424,7 @@ class ReactAgent(BaseAgent):
|
|||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||
"""
|
||||
available_tools = self.toolbox.tools
|
||||
if self.managed_agents is not None:
|
||||
available_tools = {**available_tools, **self.managed_agents}
|
||||
available_tools = {**self.toolbox.tools, **self.managed_agents}
|
||||
if tool_name not in available_tools:
|
||||
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||
console.print(f"[bold red]{error_msg}")
|
||||
|
@ -674,8 +657,6 @@ Now begin!""",
|
|||
),
|
||||
managed_agents_descriptions=(
|
||||
show_agents_descriptions(self.managed_agents)
|
||||
if self.managed_agents is not None
|
||||
else ""
|
||||
),
|
||||
answer_facts=answer_facts,
|
||||
),
|
||||
|
@ -729,8 +710,6 @@ Now begin!""",
|
|||
),
|
||||
managed_agents_descriptions=(
|
||||
show_agents_descriptions(self.managed_agents)
|
||||
if self.managed_agents is not None
|
||||
else ""
|
||||
),
|
||||
facts_update=facts_update,
|
||||
remaining_steps=(self.max_iterations - iteration),
|
||||
|
@ -891,6 +870,7 @@ class CodeAgent(ReactAgent):
|
|||
grammar: Optional[Dict[str, str]] = None,
|
||||
additional_authorized_imports: Optional[List[str]] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
use_e2b_executor: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_engine is None:
|
||||
|
@ -909,17 +889,24 @@ class CodeAgent(ReactAgent):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
self.python_evaluator = evaluate_python_code
|
||||
self.additional_authorized_imports = (
|
||||
additional_authorized_imports if additional_authorized_imports else []
|
||||
)
|
||||
all_tools = {**self.toolbox.tools, **self.managed_agents}
|
||||
if use_e2b_executor:
|
||||
self.python_executor = E2BExecutor(self.additional_authorized_imports, list(all_tools.values()))
|
||||
else:
|
||||
self.python_executor = LocalPythonExecutor(self.additional_authorized_imports, all_tools)
|
||||
self.authorized_imports = list(
|
||||
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
|
||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||
)
|
||||
if "{{authorized_imports}}" not in self.system_prompt:
|
||||
raise AgentError(
|
||||
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
||||
)
|
||||
self.system_prompt = self.system_prompt.replace(
|
||||
"{{authorized_imports}}", str(self.authorized_imports)
|
||||
)
|
||||
self.custom_tools = {}
|
||||
|
||||
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||
"""
|
||||
|
@ -991,22 +978,12 @@ class CodeAgent(ReactAgent):
|
|||
)
|
||||
|
||||
try:
|
||||
static_tools = {
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
**self.toolbox.tools,
|
||||
}
|
||||
if self.managed_agents is not None:
|
||||
static_tools = {**static_tools, **self.managed_agents}
|
||||
output = self.python_evaluator(
|
||||
output, execution_logs = self.python_executor(
|
||||
code_action,
|
||||
static_tools=static_tools,
|
||||
custom_tools=self.custom_tools,
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
if len(self.state["print_outputs"]) > 0:
|
||||
console.print(Group(Text("Print outputs:", style="bold"), Text(self.state["print_outputs"])))
|
||||
observation = "Print outputs:\n" + self.state["print_outputs"]
|
||||
if len(execution_logs) > 0:
|
||||
console.print(Group(Text("Execution logs:", style="bold"), Text(execution_logs)))
|
||||
observation = "Execution logs:\n" + execution_logs
|
||||
if output is not None:
|
||||
truncated_output = truncate_content(
|
||||
str(output)
|
||||
|
@ -1026,7 +1003,7 @@ class CodeAgent(ReactAgent):
|
|||
console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green")))
|
||||
log_entry.action_output = output
|
||||
return output
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class ManagedAgent:
|
||||
|
|
|
@ -15,75 +15,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from math import sqrt
|
||||
from typing import Dict
|
||||
|
||||
from huggingface_hub import hf_hub_download, list_spaces
|
||||
|
||||
from transformers.utils import is_offline_mode
|
||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tool import TOOL_CONFIG_FILE, Tool
|
||||
|
||||
|
||||
def custom_print(*args):
|
||||
return None
|
||||
|
||||
|
||||
BASE_PYTHON_TOOLS = {
|
||||
"print": custom_print,
|
||||
"isinstance": isinstance,
|
||||
"range": range,
|
||||
"float": float,
|
||||
"int": int,
|
||||
"bool": bool,
|
||||
"str": str,
|
||||
"set": set,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"round": round,
|
||||
"ceil": math.ceil,
|
||||
"floor": math.floor,
|
||||
"log": math.log,
|
||||
"exp": math.exp,
|
||||
"sin": math.sin,
|
||||
"cos": math.cos,
|
||||
"tan": math.tan,
|
||||
"asin": math.asin,
|
||||
"acos": math.acos,
|
||||
"atan": math.atan,
|
||||
"atan2": math.atan2,
|
||||
"degrees": math.degrees,
|
||||
"radians": math.radians,
|
||||
"pow": math.pow,
|
||||
"sqrt": sqrt,
|
||||
"len": len,
|
||||
"sum": sum,
|
||||
"max": max,
|
||||
"min": min,
|
||||
"abs": abs,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"reversed": reversed,
|
||||
"sorted": sorted,
|
||||
"all": all,
|
||||
"any": any,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"ord": ord,
|
||||
"chr": chr,
|
||||
"next": next,
|
||||
"iter": iter,
|
||||
"divmod": divmod,
|
||||
"callable": callable,
|
||||
"getattr": getattr,
|
||||
"hasattr": hasattr,
|
||||
"setattr": setattr,
|
||||
"issubclass": issubclass,
|
||||
"type": type,
|
||||
}
|
||||
from ..local_python_executor import BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code
|
||||
from ..tools import TOOL_CONFIG_FILE, Tool
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -136,10 +75,10 @@ class PythonInterpreterTool(Tool):
|
|||
|
||||
def __init__(self, *args, authorized_imports=None, **kwargs):
|
||||
if authorized_imports is None:
|
||||
self.authorized_imports = list(set(LIST_SAFE_MODULES))
|
||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
|
||||
else:
|
||||
self.authorized_imports = list(
|
||||
set(LIST_SAFE_MODULES) | set(authorized_imports)
|
||||
set(BASE_BUILTIN_MODULES) | set(authorized_imports)
|
||||
)
|
||||
self.inputs = {
|
||||
"code": {
|
|
@ -16,15 +16,11 @@
|
|||
# limitations under the License.
|
||||
import re
|
||||
|
||||
import requests
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from ..tools import Tool
|
||||
|
||||
|
||||
class DuckDuckGoSearchTool(Tool):
|
||||
name = "web_search"
|
||||
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
||||
description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
||||
Each result has keys 'title', 'href' and 'body'."""
|
||||
inputs = {
|
||||
"query": {"type": "string", "description": "The search query to perform."}
|
||||
|
@ -56,9 +52,11 @@ class VisitWebpageTool(Tool):
|
|||
def forward(self, url: str) -> str:
|
||||
try:
|
||||
from markdownify import markdownify
|
||||
import requests
|
||||
from requests.exceptions import RequestException
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You must install package `markdownify` to run this tool: for instance run `pip install markdownify`."
|
||||
"You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
|
||||
)
|
||||
try:
|
||||
# Send a GET request to the URL
|
|
@ -3,7 +3,7 @@ from typing import List, Optional
|
|||
import warnings
|
||||
import socket
|
||||
|
||||
from agents.tool import Tool
|
||||
from agents.tools import Tool
|
||||
|
||||
class DockerPythonInterpreter:
|
||||
def __init__(self):
|
||||
|
|
|
@ -343,7 +343,7 @@ if __name__ == '__main__':
|
|||
|
||||
|
||||
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
|
||||
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES
|
||||
from .local_python_executor import evaluate_python_code, BASE_BUILTIN_MODULES
|
||||
|
||||
"""Execute code locally with state transfer."""
|
||||
state_manager = StateManager(work_dir)
|
||||
|
@ -363,7 +363,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
|
|||
tools,
|
||||
{},
|
||||
namespace,
|
||||
LIST_SAFE_MODULES,
|
||||
BASE_BUILTIN_MODULES,
|
||||
)
|
||||
|
||||
# Save state for Docker
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dotenv import load_dotenv
|
||||
import textwrap
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from e2b_code_interpreter import Sandbox
|
||||
from typing import Dict, List, Callable, Tuple, Any
|
||||
from .tool_validation import validate_tool_attributes
|
||||
from .utils import instance_to_source, BASE_BUILTIN_MODULES
|
||||
from .tools import Tool
|
||||
from .types import AgentImage
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class E2BExecutor():
|
||||
def __init__(self, additional_imports: List[str], tools: List[Tool]):
|
||||
self.custom_tools = {}
|
||||
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
|
||||
# TODO: validate installing agents package or not
|
||||
# print("Installing agents package on remote executor...")
|
||||
# self.sbx.commands.run(
|
||||
# "pip install git+https://github.com/huggingface/agents.git",
|
||||
# timeout=300
|
||||
# )
|
||||
# print("Installation of agents package finished.")
|
||||
if len(additional_imports) > 0:
|
||||
execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
|
||||
if execution.error:
|
||||
raise Exception(f"Error installing dependencies: {execution.error}")
|
||||
else:
|
||||
print("Installation succeeded!")
|
||||
|
||||
tool_codes = []
|
||||
for tool in tools:
|
||||
validate_tool_attributes(tool.__class__, check_imports=False)
|
||||
tool_code = instance_to_source(tool, base_cls=Tool)
|
||||
tool_code = tool_code.replace("from agents.tools import Tool", "")
|
||||
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
|
||||
tool_codes.append(tool_code)
|
||||
|
||||
tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
|
||||
tool_definition_code += textwrap.dedent("""
|
||||
class Tool:
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
pass # to be implemented in child class
|
||||
""")
|
||||
tool_definition_code += "\n\n".join(tool_codes)
|
||||
|
||||
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
|
||||
print(tool_definition_execution.logs)
|
||||
|
||||
def run_code_raise_errors(self, code: str):
|
||||
execution = self.sbx.run_code(
|
||||
code,
|
||||
)
|
||||
if execution.error:
|
||||
logs = 'Executing code yielded an error:'
|
||||
logs += execution.error.name
|
||||
logs += execution.error.value
|
||||
logs += execution.error.traceback
|
||||
raise ValueError(logs)
|
||||
return execution
|
||||
|
||||
def __call__(self, code_action: str) -> Tuple[Any, Any]:
|
||||
execution = self.run_code_raise_errors(code_action)
|
||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||
if not execution.results:
|
||||
return None, execution_logs
|
||||
else:
|
||||
for result in execution.results:
|
||||
if result.is_main_result:
|
||||
for attribute_name in ['jpeg', 'png']:
|
||||
if getattr(result, attribute_name) is not None:
|
||||
image_output = getattr(result, attribute_name)
|
||||
decoded_bytes = base64.b64decode(image_output.encode('utf-8'))
|
||||
return Image.open(BytesIO(decoded_bytes)), execution_logs
|
||||
for attribute_name in ['chart', 'data', 'html', 'javascript', 'json', 'latex', 'markdown', 'pdf', 'svg', 'text']:
|
||||
if getattr(result, attribute_name) is not None:
|
||||
return getattr(result, attribute_name), execution_logs
|
||||
raise ValueError("No main result returned by executor!")
|
||||
|
||||
__all__ = ["E2BExecutor"]
|
|
@ -14,7 +14,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .types import AgentAudio, AgentImage, AgentText
|
||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||
from .agents import BaseAgent, AgentStep, ActionStep
|
||||
import gradio as gr
|
||||
|
||||
|
@ -58,7 +58,7 @@ def stream_to_gradio(
|
|||
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
||||
yield message
|
||||
|
||||
final_answer = step_log # Last log is the run's final_answer
|
||||
final_answer = handle_agent_output_types(step_log) # Last log is the run's final_answer
|
||||
|
||||
if isinstance(final_answer, AgentText):
|
||||
yield gr.ChatMessage(
|
||||
|
@ -93,7 +93,7 @@ class GradioUI:
|
|||
yield messages
|
||||
yield messages
|
||||
|
||||
def run(self):
|
||||
def launch(self):
|
||||
with gr.Blocks() as demo:
|
||||
stored_message = gr.State([])
|
||||
chatbot = gr.Chatbot(
|
||||
|
|
|
@ -19,12 +19,12 @@ import builtins
|
|||
import difflib
|
||||
from collections.abc import Mapping
|
||||
from importlib import import_module
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .utils import truncate_content
|
||||
from .utils import truncate_content, BASE_BUILTIN_MODULES
|
||||
|
||||
|
||||
class InterpreterError(ValueError):
|
||||
|
@ -43,24 +43,66 @@ ERRORS = {
|
|||
and issubclass(getattr(builtins, name), BaseException)
|
||||
}
|
||||
|
||||
|
||||
LIST_SAFE_MODULES = [
|
||||
"random",
|
||||
"collections",
|
||||
"math",
|
||||
"time",
|
||||
"queue",
|
||||
"itertools",
|
||||
"re",
|
||||
"stat",
|
||||
"statistics",
|
||||
"unicodedata",
|
||||
]
|
||||
|
||||
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
||||
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
||||
|
||||
def custom_print(*args):
|
||||
return None
|
||||
|
||||
|
||||
BASE_PYTHON_TOOLS = {
|
||||
"print": custom_print,
|
||||
"isinstance": isinstance,
|
||||
"range": range,
|
||||
"float": float,
|
||||
"int": int,
|
||||
"bool": bool,
|
||||
"str": str,
|
||||
"set": set,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"round": round,
|
||||
"ceil": math.ceil,
|
||||
"floor": math.floor,
|
||||
"log": math.log,
|
||||
"exp": math.exp,
|
||||
"sin": math.sin,
|
||||
"cos": math.cos,
|
||||
"tan": math.tan,
|
||||
"asin": math.asin,
|
||||
"acos": math.acos,
|
||||
"atan": math.atan,
|
||||
"atan2": math.atan2,
|
||||
"degrees": math.degrees,
|
||||
"radians": math.radians,
|
||||
"pow": math.pow,
|
||||
"sqrt": math.sqrt,
|
||||
"len": len,
|
||||
"sum": sum,
|
||||
"max": max,
|
||||
"min": min,
|
||||
"abs": abs,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"reversed": reversed,
|
||||
"sorted": sorted,
|
||||
"all": all,
|
||||
"any": any,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"ord": ord,
|
||||
"chr": chr,
|
||||
"next": next,
|
||||
"iter": iter,
|
||||
"divmod": divmod,
|
||||
"callable": callable,
|
||||
"getattr": getattr,
|
||||
"hasattr": hasattr,
|
||||
"setattr": setattr,
|
||||
"issubclass": issubclass,
|
||||
"type": type,
|
||||
}
|
||||
class BreakException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -771,7 +813,7 @@ def evaluate_ast(
|
|||
state: Dict[str, Any],
|
||||
static_tools: Dict[str, Callable],
|
||||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
|
||||
):
|
||||
"""
|
||||
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
||||
|
@ -949,7 +991,7 @@ def evaluate_python_code(
|
|||
static_tools: Optional[Dict[str, Callable]] = None,
|
||||
custom_tools: Optional[Dict[str, Callable]] = None,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
|
||||
):
|
||||
"""
|
||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||
|
@ -1001,4 +1043,30 @@ def evaluate_python_code(
|
|||
raise InterpreterError(msg)
|
||||
|
||||
|
||||
__all__ = ["evaluate_python_code"]
|
||||
class LocalPythonExecutor():
|
||||
def __init__(self, additional_authorized_imports: List[str], tools: Dict):
|
||||
self.custom_tools = {}
|
||||
self.state = {}
|
||||
self.additional_authorized_imports = additional_authorized_imports
|
||||
self.authorized_imports = list(
|
||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||
)
|
||||
# Add base trusted tools to list
|
||||
self.static_tools = {
|
||||
**tools,
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
}
|
||||
# TODO: assert self.authorized imports are all installed locally
|
||||
|
||||
def __call__(self, code_action: str) -> Tuple[Any, str]:
|
||||
output = evaluate_python_code(
|
||||
code_action,
|
||||
static_tools=self.static_tools,
|
||||
custom_tools=self.custom_tools,
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
logs = self.state["print_outputs"]
|
||||
return output, logs
|
||||
|
||||
__all__ = ["evaluate_python_code", "LocalPythonExecutor"]
|
||||
|
|
|
@ -370,7 +370,7 @@ Here are the rules you should always follow to solve your task:
|
|||
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
||||
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
||||
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
|
||||
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||
|
||||
|
|
|
@ -5,27 +5,30 @@ import builtins
|
|||
from pathlib import Path
|
||||
from typing import List, Set, Dict
|
||||
import textwrap
|
||||
from .utils import BASE_BUILTIN_MODULES
|
||||
|
||||
_BUILTIN_NAMES = set(vars(builtins))
|
||||
|
||||
def is_local_import(module_name: str) -> bool:
|
||||
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
|
||||
|
||||
def is_installed_package(module_name: str) -> bool:
|
||||
"""
|
||||
Check if an import is from a local file or a package.
|
||||
Returns True if it's a local file import.
|
||||
Check if an import is from an installed package.
|
||||
Returns False if it's not found or a local file import.
|
||||
"""
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None:
|
||||
return True # If we can't find the module, assume it's local
|
||||
return False # If we can't find the module, assume it's local
|
||||
|
||||
# If the module is found and has a file path, check if it's in site-packages
|
||||
if spec.origin and 'site-packages' not in spec.origin:
|
||||
# Check if it's a .py file in the current directory or subdirectories
|
||||
return spec.origin.endswith('.py')
|
||||
|
||||
return not spec.origin.endswith('.py')
|
||||
|
||||
return False
|
||||
except ImportError:
|
||||
return True # If there's an import error, assume it's local
|
||||
return False # If there's an import error, assume it's local
|
||||
|
||||
class MethodChecker(ast.NodeVisitor):
|
||||
"""
|
||||
|
@ -33,7 +36,7 @@ class MethodChecker(ast.NodeVisitor):
|
|||
- only uses defined names
|
||||
- contains no local imports (e.g. numpy is ok but local_script is not)
|
||||
"""
|
||||
def __init__(self, class_attributes: Set[str]):
|
||||
def __init__(self, class_attributes: Set[str], check_imports: bool = True):
|
||||
self.undefined_names = set()
|
||||
self.imports = {}
|
||||
self.from_imports = {}
|
||||
|
@ -41,6 +44,7 @@ class MethodChecker(ast.NodeVisitor):
|
|||
self.arg_names = set()
|
||||
self.class_attributes = class_attributes
|
||||
self.errors = []
|
||||
self.check_imports = check_imports
|
||||
|
||||
def visit_arguments(self, node):
|
||||
"""Collect function arguments"""
|
||||
|
@ -53,16 +57,16 @@ class MethodChecker(ast.NodeVisitor):
|
|||
def visit_Import(self, node):
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
if is_local_import(actual_name):
|
||||
self.errors.append(f"Local import '{actual_name}'")
|
||||
if not is_installed_package(actual_name) and self.check_imports:
|
||||
self.errors.append(f"Package not found in importlib, might be a local install: '{actual_name}'")
|
||||
self.imports[actual_name] = name.name
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
module = node.module or ""
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
if is_local_import(module):
|
||||
self.errors.append(f"Local import '{module}'")
|
||||
if not is_installed_package(module) and self.check_imports:
|
||||
self.errors.append(f"Package not found in importlib, might be a local install: '{module}'")
|
||||
self.from_imports[actual_name] = (module, name.name)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
|
@ -71,6 +75,20 @@ class MethodChecker(ast.NodeVisitor):
|
|||
self.assigned_names.add(target.id)
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_With(self, node):
|
||||
"""Track aliases in 'with' statements (the 'y' in 'with X as y')"""
|
||||
for item in node.items:
|
||||
if item.optional_vars: # This is the 'y' in 'with X as y'
|
||||
if isinstance(item.optional_vars, ast.Name):
|
||||
self.assigned_names.add(item.optional_vars.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ExceptHandler(self, node):
|
||||
"""Track exception aliases (the 'e' in 'except Exception as e')"""
|
||||
if node.name: # This is the 'e' in 'except Exception as e'
|
||||
self.assigned_names.add(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_AnnAssign(self, node):
|
||||
"""Track annotated assignments."""
|
||||
if isinstance(node.target, ast.Name):
|
||||
|
@ -97,6 +115,7 @@ class MethodChecker(ast.NodeVisitor):
|
|||
if isinstance(node.ctx, ast.Load):
|
||||
if not (
|
||||
node.id in _BUILTIN_NAMES
|
||||
or node.id in IMPORTED_PACKAGES
|
||||
or node.id in self.arg_names
|
||||
or node.id == "self"
|
||||
or node.id in self.class_attributes
|
||||
|
@ -110,17 +129,18 @@ class MethodChecker(ast.NodeVisitor):
|
|||
if isinstance(node.func, ast.Name):
|
||||
if not (
|
||||
node.func.id in _BUILTIN_NAMES
|
||||
or node.func.id in IMPORTED_PACKAGES
|
||||
or node.func.id in self.arg_names
|
||||
or node.func.id == "self"
|
||||
or node.func.id in self.class_attributes
|
||||
or node.func.id in self.imports
|
||||
or node.func.id in self.from_imports
|
||||
or node.func.id in self.assigned_names
|
||||
):
|
||||
):
|
||||
self.errors.append(f"Name '{node.func.id}' is undefined.")
|
||||
self.generic_visit(node)
|
||||
|
||||
def validate_tool_attributes(cls) -> None:
|
||||
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
||||
"""
|
||||
Validates that a Tool class follows the proper patterns:
|
||||
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
|
||||
|
@ -156,8 +176,17 @@ def validate_tool_attributes(cls) -> None:
|
|||
self.imported_names = set()
|
||||
self.complex_attributes = set()
|
||||
self.class_attributes = set()
|
||||
self.in_method = False
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
old_context = self.in_method
|
||||
self.in_method = True
|
||||
self.generic_visit(node)
|
||||
self.in_method = old_context
|
||||
|
||||
def visit_Assign(self, node):
|
||||
if self.in_method:
|
||||
return
|
||||
# Track class attributes
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
|
@ -182,7 +211,7 @@ def validate_tool_attributes(cls) -> None:
|
|||
# Run checks on all methods
|
||||
for node in class_node.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
method_checker = MethodChecker(class_level_checker.class_attributes)
|
||||
method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
|
||||
method_checker.visit(node)
|
||||
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ import textwrap
|
|||
from functools import lru_cache, wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, Set
|
||||
import math
|
||||
|
||||
from huggingface_hub import (
|
||||
create_repo,
|
||||
|
@ -48,7 +49,7 @@ from transformers.utils import (
|
|||
is_vision_available,
|
||||
)
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
|
||||
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
|
||||
from .utils import instance_to_source
|
||||
from .tool_validation import validate_tool_attributes, MethodChecker
|
||||
|
||||
|
@ -66,7 +67,6 @@ if is_accelerate_available():
|
|||
|
||||
TOOL_CONFIG_FILE = "tool_config.json"
|
||||
|
||||
|
||||
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
||||
if repo_type is not None:
|
||||
return repo_type
|
||||
|
@ -197,12 +197,15 @@ class Tool:
|
|||
def forward(self, *args, **kwargs):
|
||||
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
|
||||
if not self.is_initialized:
|
||||
self.setup()
|
||||
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
||||
if sanitize_inputs_outputs:
|
||||
args, kwargs = handle_agent_input_types(*args, **kwargs)
|
||||
outputs = self.forward(*args, **kwargs)
|
||||
return handle_agent_outputs(outputs, self.output_type)
|
||||
if sanitize_inputs_outputs:
|
||||
outputs = handle_agent_output_types(outputs, self.output_type)
|
||||
return outputs
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
|
@ -266,9 +269,7 @@ class Tool:
|
|||
forward_source_code = add_self_argument(forward_source_code)
|
||||
forward_source_code = forward_source_code.replace("@tool", "").strip()
|
||||
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
|
||||
|
||||
with open(tool_file, "w", encoding="utf-8") as f:
|
||||
f.write(tool_code)
|
||||
|
||||
else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
|
||||
if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
|
||||
raise ValueError(
|
||||
|
@ -278,8 +279,9 @@ class Tool:
|
|||
validate_tool_attributes(self.__class__)
|
||||
|
||||
tool_code = instance_to_source(self, base_cls=Tool)
|
||||
with open(tool_file, "w", encoding="utf-8") as f:
|
||||
f.write(tool_code)
|
||||
|
||||
with open(tool_file, "w", encoding="utf-8") as f:
|
||||
f.write(tool_code)
|
||||
|
||||
# Save app file
|
||||
app_file = os.path.join(output_dir, "app.py")
|
||||
|
@ -719,6 +721,9 @@ def launch_gradio_demo(tool: Tool):
|
|||
"number": gr.Textbox,
|
||||
}
|
||||
|
||||
def fn(*args, **kwargs):
|
||||
return tool(*args, **kwargs, sanitize_inputs_outputs=True)
|
||||
|
||||
gradio_inputs = []
|
||||
for input_name, input_details in tool.inputs.items():
|
||||
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
||||
|
@ -733,7 +738,7 @@ def launch_gradio_demo(tool: Tool):
|
|||
gradio_output = output_gradio_componentclass(label="Output")
|
||||
|
||||
gr.Interface(
|
||||
fn=tool, # This works because `tool` has a __call__ method
|
||||
fn=fn,
|
||||
inputs=gradio_inputs,
|
||||
outputs=gradio_output,
|
||||
title=tool.name,
|
||||
|
@ -823,61 +828,6 @@ def add_description(description):
|
|||
return inner
|
||||
|
||||
|
||||
## Will move to the Hub
|
||||
class EndpointClient:
|
||||
def __init__(self, endpoint_url: str, token: Optional[str] = None):
|
||||
self.headers = {
|
||||
**build_hf_headers(token=token),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
@staticmethod
|
||||
def encode_image(image):
|
||||
_bytes = io.BytesIO()
|
||||
image.save(_bytes, format="PNG")
|
||||
b64 = base64.b64encode(_bytes.getvalue())
|
||||
return b64.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def decode_image(raw_image):
|
||||
if not is_vision_available():
|
||||
raise ImportError(
|
||||
"This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
|
||||
)
|
||||
|
||||
from PIL import Image
|
||||
|
||||
b64 = base64.b64decode(raw_image)
|
||||
_bytes = io.BytesIO(b64)
|
||||
return Image.open(_bytes)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
|
||||
params: Optional[Dict] = None,
|
||||
data: Optional[bytes] = None,
|
||||
output_image: bool = False,
|
||||
) -> Any:
|
||||
# Build payload
|
||||
payload = {}
|
||||
if inputs:
|
||||
payload["inputs"] = inputs
|
||||
if params:
|
||||
payload["parameters"] = params
|
||||
|
||||
# Make API call
|
||||
response = get_session().post(
|
||||
self.endpoint_url, headers=self.headers, json=payload, data=data
|
||||
)
|
||||
|
||||
# By default, parse the response for the user.
|
||||
if output_image:
|
||||
return self.decode_image(response.content)
|
||||
else:
|
||||
return response.json()
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""
|
||||
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
||||
|
@ -1063,4 +1013,5 @@ __all__ = [
|
|||
"load_tool",
|
||||
"launch_gradio_demo",
|
||||
"Toolbox",
|
||||
"ToolCollection",
|
||||
]
|
|
@ -16,6 +16,7 @@ import os
|
|||
import pathlib
|
||||
import tempfile
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -105,6 +106,8 @@ class AgentImage(AgentType, ImageType):
|
|||
|
||||
if isinstance(value, ImageType):
|
||||
self._raw = value
|
||||
elif isinstance(value, bytes):
|
||||
self._raw = Image.open(BytesIO(value))
|
||||
elif isinstance(value, (str, pathlib.Path)):
|
||||
self._path = value
|
||||
elif isinstance(value, torch.Tensor):
|
||||
|
@ -241,13 +244,13 @@ class AgentAudio(AgentType, str):
|
|||
|
||||
|
||||
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, np.ndarray: AgentAudio}
|
||||
|
||||
if is_torch_available():
|
||||
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
|
||||
|
||||
|
||||
def handle_agent_inputs(*args, **kwargs):
|
||||
def handle_agent_input_types(*args, **kwargs):
|
||||
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
||||
kwargs = {
|
||||
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
|
||||
|
@ -255,7 +258,7 @@ def handle_agent_inputs(*args, **kwargs):
|
|||
return args, kwargs
|
||||
|
||||
|
||||
def handle_agent_outputs(output, output_type=None):
|
||||
def handle_agent_output_types(output, output_type=None):
|
||||
if output_type in AGENT_TYPE_MAPPING:
|
||||
# If the class has defined outputs, we can map directly according to the class definition
|
||||
decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
|
||||
|
|
|
@ -34,7 +34,18 @@ def is_pygments_available():
|
|||
|
||||
console = Console()
|
||||
|
||||
|
||||
BASE_BUILTIN_MODULES = [
|
||||
"random",
|
||||
"collections",
|
||||
"math",
|
||||
"time",
|
||||
"queue",
|
||||
"itertools",
|
||||
"re",
|
||||
"stat",
|
||||
"statistics",
|
||||
"unicodedata",
|
||||
]
|
||||
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||
try:
|
||||
first_accolade_index = json_blob.find("{")
|
||||
|
@ -190,7 +201,10 @@ def instance_to_source(instance, base_cls=None):
|
|||
|
||||
for name, value in class_attrs.items():
|
||||
if isinstance(value, str):
|
||||
class_lines.append(f' {name} = "{value}"')
|
||||
if "\n" in value:
|
||||
class_lines.append(f' {name} = """{value}"""')
|
||||
else:
|
||||
class_lines.append(f' {name} = "{value}"')
|
||||
else:
|
||||
class_lines.append(f' {name} = {repr(value)}')
|
||||
|
||||
|
@ -230,7 +244,8 @@ def instance_to_source(instance, base_cls=None):
|
|||
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
|
||||
|
||||
# Add discovered imports
|
||||
final_lines.extend(required_imports)
|
||||
for package in required_imports:
|
||||
final_lines.append(f"import {package}")
|
||||
|
||||
if final_lines: # Add empty line after imports
|
||||
final_lines.append("")
|
||||
|
|
|
@ -29,7 +29,7 @@ from agents.agents import (
|
|||
Toolbox,
|
||||
ToolCall,
|
||||
)
|
||||
from agents.tool import tool
|
||||
from agents.tools import tool
|
||||
from agents.default_tools import PythonInterpreterTool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from agents.types import (
|
|||
AgentImage,
|
||||
AgentText,
|
||||
)
|
||||
from agents.tool import Tool, tool, AUTHORIZED_TYPES
|
||||
from agents.tools import Tool, tool, AUTHORIZED_TYPES
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue