Formatting

This commit is contained in:
Aymeric 2024-12-16 15:46:47 +01:00
parent 1751bf03ac
commit 06066437fd
18 changed files with 275 additions and 204 deletions

View File

@ -18,11 +18,7 @@ __version__ = "0.1.0"
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from transformers.utils import ( from transformers.utils import _LazyModule
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
from transformers.utils.import_utils import define_import_structure from transformers.utils.import_utils import define_import_structure
@ -43,4 +39,6 @@ else:
import sys import sys
_file = globals()["__file__"] _file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) sys.modules[__name__] = _LazyModule(
__name__, _file, define_import_structure(_file), module_spec=__spec__
)

View File

@ -79,8 +79,9 @@ class AgentGenerationError(AgentError):
pass pass
@dataclass @dataclass
class ToolCall(): class ToolCall:
tool_name: str tool_name: str
tool_arguments: Any tool_arguments: Any
@ -146,13 +147,17 @@ Here is a list of the team members that you can call:"""
def format_prompt_with_managed_agents_descriptions( def format_prompt_with_managed_agents_descriptions(
prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None prompt_template,
managed_agents,
agent_descriptions_placeholder: Optional[str] = None,
) -> str: ) -> str:
if agent_descriptions_placeholder is None: if agent_descriptions_placeholder is None:
agent_descriptions_placeholder = "{{managed_agents_descriptions}}" agent_descriptions_placeholder = "{{managed_agents_descriptions}}"
if agent_descriptions_placeholder not in prompt_template: if agent_descriptions_placeholder not in prompt_template:
print("PROMPT TEMPLLL", prompt_template) print("PROMPT TEMPLLL", prompt_template)
raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'") raise ValueError(
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
)
if len(managed_agents.keys()) > 0: if len(managed_agents.keys()) > 0:
return prompt_template.replace( return prompt_template.replace(
agent_descriptions_placeholder, show_agents_descriptions(managed_agents) agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
@ -970,7 +975,9 @@ class CodeAgent(ReactAgent):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action) log_entry.tool_call = ToolCall(
tool_name="python_interpreter", tool_arguments=code_action
)
# Execute # Execute
if self.verbose: if self.verbose:
@ -1075,4 +1082,13 @@ And even if your task resolution is not successful, please return as much contex
else: else:
return output return output
__all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
__all__ = [
"AgentError",
"BaseAgent",
"ManagedAgent",
"ReactAgent",
"CodeAgent",
"JsonAgent",
"Toolbox",
]

View File

@ -127,7 +127,10 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter" name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations." description = "This is a tool that evaluates python code. It can be used to perform calculations."
inputs = { inputs = {
"code": {"type": "string", "description": "The python code to run in interpreter"} "code": {
"type": "string",
"description": "The python code to run in interpreter",
}
} }
output_type = "string" output_type = "string"
@ -186,4 +189,5 @@ class UserInputTool(Tool):
user_input = input(f"{question} => ") user_input = input(f"{question} => ")
return user_input return user_input
__all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"] __all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"]

View File

@ -9,12 +9,13 @@ from typing import Optional, Dict, Tuple, Set, Any
import types import types
from .default_tools import BASE_PYTHON_TOOLS from .default_tools import BASE_PYTHON_TOOLS
class StateManager: class StateManager:
def __init__(self, work_dir: Path): def __init__(self, work_dir: Path):
self.work_dir = work_dir self.work_dir = work_dir
self.state_file = work_dir / "interpreter_state.pickle" self.state_file = work_dir / "interpreter_state.pickle"
self.imports_file = work_dir / "imports.txt" self.imports_file = work_dir / "imports.txt"
self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$') self.import_pattern = re.compile(r"^(?:from\s+[\w.]+\s+)?import\s+.+$")
self.imports: Set[str] = set() self.imports: Set[str] = set()
def is_import_statement(self, code: str) -> bool: def is_import_statement(self, code: str) -> bool:
@ -23,7 +24,7 @@ class StateManager:
def track_imports(self, code: str): def track_imports(self, code: str):
"""Track import statements for later use.""" """Track import statements for later use."""
for line in code.split('\n'): for line in code.split("\n"):
if self.is_import_statement(line.strip()): if self.is_import_statement(line.strip()):
self.imports.add(line.strip()) self.imports.add(line.strip())
@ -37,20 +38,21 @@ class StateManager:
""" """
# Filter out modules, functions, and special variables # Filter out modules, functions, and special variables
state_dict = { state_dict = {
'variables': { "variables": {
k: v for k, v in locals_dict.items() k: v
for k, v in locals_dict.items()
if not ( if not (
k.startswith('_') k.startswith("_")
or callable(v) or callable(v)
or isinstance(v, type) or isinstance(v, type)
or isinstance(v, types.ModuleType) or isinstance(v, types.ModuleType)
) )
}, },
'imports': list(self.imports), "imports": list(self.imports),
'source': executor "source": executor,
} }
with open(self.state_file, 'wb') as f: with open(self.state_file, "wb") as f:
pickle.dump(state_dict, f) pickle.dump(state_dict, f)
def load_state(self, executor: str) -> Dict[str, Any]: def load_state(self, executor: str) -> Dict[str, Any]:
@ -66,14 +68,14 @@ class StateManager:
if not self.state_file.exists(): if not self.state_file.exists():
return {} return {}
with open(self.state_file, 'rb') as f: with open(self.state_file, "rb") as f:
state_dict = pickle.load(f) state_dict = pickle.load(f)
# First handle imports # First handle imports
for import_stmt in state_dict['imports']: for import_stmt in state_dict["imports"]:
exec(import_stmt, globals()) exec(import_stmt, globals())
return state_dict['variables'] return state_dict["variables"]
def read_multiplexed_response(socket): def read_multiplexed_response(socket):
@ -81,10 +83,10 @@ def read_multiplexed_response(socket):
socket.settimeout(10.0) socket.settimeout(10.0)
i = 0 i = 0
while True and i<1000: while True and i < 1000:
# Stream output from socket # Stream output from socket
response_data = socket.recv(4096) response_data = socket.recv(4096)
responses = response_data.split(b'\x01\x00\x00\x00\x00\x00') responses = response_data.split(b"\x01\x00\x00\x00\x00\x00")
# The last non-empty chunk should be our JSON response # The last non-empty chunk should be our JSON response
if len(responses) > 0: if len(responses) > 0:
@ -92,15 +94,15 @@ def read_multiplexed_response(socket):
if chunk and len(chunk.strip()) > 0: if chunk and len(chunk.strip()) > 0:
try: try:
# Find the start of valid JSON by looking for '{' # Find the start of valid JSON by looking for '{'
json_start = chunk.find(b'{') json_start = chunk.find(b"{")
if json_start != -1: if json_start != -1:
decoded = chunk[json_start:].decode('utf-8') decoded = chunk[json_start:].decode("utf-8")
result = json.loads(decoded) result = json.loads(decoded)
if "output" in result: if "output" in result:
return decoded return decoded
except json.JSONDecodeError: except json.JSONDecodeError:
continue continue
i+=1 i += 1
class DockerPythonInterpreter: class DockerPythonInterpreter:
@ -113,7 +115,6 @@ class DockerPythonInterpreter:
self.socket = None self.socket = None
self.state_manager = StateManager(work_dir) self.state_manager = StateManager(work_dir)
def create_interpreter_script(self) -> str: def create_interpreter_script(self) -> str:
"""Create the interpreter script that will run inside the container""" """Create the interpreter script that will run inside the container"""
script = """ script = """
@ -230,9 +231,7 @@ if __name__ == '__main__':
self.create_interpreter_script() self.create_interpreter_script()
# Setup volume mapping # Setup volume mapping
volumes = { volumes = {str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}}
str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}
}
for container in self.client.containers.list(all=True): for container in self.client.containers.list(all=True):
if container_name == container.name: if container_name == container.name:
@ -241,7 +240,7 @@ if __name__ == '__main__':
container.start() container.start()
self.container = container self.container = container
break break
else: # Create new container else: # Create new container
self.container = self.client.containers.run( self.container = self.client.containers.run(
"python:3.9", "python:3.9",
name=container_name, name=container_name,
@ -250,22 +249,20 @@ if __name__ == '__main__':
tty=True, tty=True,
stdin_open=True, stdin_open=True,
working_dir="/workspace", working_dir="/workspace",
volumes=volumes volumes=volumes,
) )
# Install packages in the new container # Install packages in the new container
print("Installing packages...") print("Installing packages...")
packages = ["pandas", "numpy", "pickle5"] # Add your required packages here packages = ["pandas", "numpy", "pickle5"] # Add your required packages here
result = self.container.exec_run( result = self.container.exec_run(
f"pip install {' '.join(packages)}", f"pip install {' '.join(packages)}", workdir="/workspace"
workdir="/workspace"
) )
if result.exit_code != 0: if result.exit_code != 0:
print(f"Warning: Failed to install: {result.output.decode()}") print(f"Warning: Failed to install: {result.output.decode()}")
else: else:
print(f"Installed {packages}.") print(f"Installed {packages}.")
if not self.wait_for_ready(self.container): if not self.wait_for_ready(self.container):
raise Exception("Failed to start container") raise Exception("Failed to start container")
@ -276,14 +273,12 @@ if __name__ == '__main__':
stdin=True, stdin=True,
stdout=True, stdout=True,
stderr=True, stderr=True,
tty=True tty=True,
) )
# Connect to the exec instance # Connect to the exec instance
self.socket = self.client.api.exec_start( self.socket = self.client.api.exec_start(
self.exec_id['Id'], self.exec_id["Id"], socket=True, demux=True
socket=True,
demux=True
)._sock )._sock
def _raw_execute(self, code: str) -> Tuple[str, bool]: def _raw_execute(self, code: str) -> Tuple[str, bool]:
@ -296,14 +291,14 @@ if __name__ == '__main__':
if not self.socket: if not self.socket:
raise Exception("Socket not started") raise Exception("Socket not started")
command = json.dumps({'code': code}) + '\n' command = json.dumps({"code": code}) + "\n"
self.socket.send(command.encode()) self.socket.send(command.encode())
response = read_multiplexed_response(self.socket) response = read_multiplexed_response(self.socket)
try: try:
result = json.loads(response) result = json.loads(response)
return result['output'], result['more'] return result["output"], result["more"]
except json.JSONDecodeError: except json.JSONDecodeError:
return f"Error: Invalid response from interpreter: {response}", False return f"Error: Invalid response from interpreter: {response}", False
@ -311,7 +306,7 @@ if __name__ == '__main__':
"""Get the current locals dictionary from the interpreter by pickling directly from Docker.""" """Get the current locals dictionary from the interpreter by pickling directly from Docker."""
pickle_path = self.work_dir / "locals.pickle" pickle_path = self.work_dir / "locals.pickle"
if pickle_path.exists(): if pickle_path.exists():
with open(pickle_path, 'rb') as f: with open(pickle_path, "rb") as f:
try: try:
return pickle.load(f) return pickle.load(f)
except Exception as e: except Exception as e:
@ -326,10 +321,7 @@ if __name__ == '__main__':
output, more = self._raw_execute(code) output, more = self._raw_execute(code)
# Save state after execution # Save state after execution
self.state_manager.save_state( self.state_manager.save_state(self.get_locals_dict(), "docker")
self.get_locals_dict(),
'docker'
)
return output, more return output, more
def stop(self, remove: bool = False): def stop(self, remove: bool = False):
@ -349,6 +341,7 @@ if __name__ == '__main__':
print(f"Error stopping container: {e}") print(f"Error stopping container: {e}")
raise raise
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES
@ -359,7 +352,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
state_manager.track_imports(code) state_manager.track_imports(code)
# Load state from Docker if available # Load state from Docker if available
locals_dict = state_manager.load_state('local') locals_dict = state_manager.load_state("local")
# Execute in a new namespace with loaded state # Execute in a new namespace with loaded state
namespace = {} namespace = {}
@ -374,7 +367,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
) )
# Save state for Docker # Save state for Docker
state_manager.save_state(namespace, 'local') state_manager.save_state(namespace, "local")
return output return output
@ -382,14 +375,17 @@ def create_tools_regex(tool_names):
# Escape any special regex characters in tool names # Escape any special regex characters in tool names
escaped_names = [re.escape(name) for name in tool_names] escaped_names = [re.escape(name) for name in tool_names]
# Join with | and add word boundaries # Join with | and add word boundaries
pattern = r'\b(' + '|'.join(escaped_names) + r')\b' pattern = r"\b(" + "|".join(escaped_names) + r")\b"
return re.compile(pattern) return re.compile(pattern)
def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter): def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
"""Execute code with automatic switching between Docker and local.""" """Execute code with automatic switching between Docker and local."""
lines = code.split('\n') lines = code.split("\n")
current_block = [] current_block = []
tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing tool_regex = create_tools_regex(
list(tools.keys()) + ["print"]
) # Added print for testing
tools = { tools = {
**BASE_PYTHON_TOOLS.copy(), **BASE_PYTHON_TOOLS.copy(),
@ -400,20 +396,20 @@ def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
if tool_regex.search(line): if tool_regex.search(line):
# Execute accumulated Docker code if any # Execute accumulated Docker code if any
if current_block: if current_block:
output, more = interpreter.execute('\n'.join(current_block)) output, more = interpreter.execute("\n".join(current_block))
print(output, end='') print(output, end="")
current_block = [] current_block = []
output = execute_locally(line, work_dir, tools) output = execute_locally(line, work_dir, tools)
if output: if output:
print(output, end='') print(output, end="")
else: else:
current_block.append(line) current_block.append(line)
# Execute any remaining Docker code # Execute any remaining Docker code
if current_block: if current_block:
output, more = interpreter.execute('\n'.join(current_block)) output, more = interpreter.execute("\n".join(current_block))
print(output, end='') print(output, end="")
__all__ = ["DockerPythonInterpreter", "execute_code"] __all__ = ["DockerPythonInterpreter", "execute_code"]

View File

@ -111,4 +111,5 @@ class GradioUI:
demo.launch() demo.launch()
__all__ = ["stream_to_gradio", "GradioUI"] __all__ = ["stream_to_gradio", "GradioUI"]

View File

@ -37,6 +37,7 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>", "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
} }
class MessageRole(str, Enum): class MessageRole(str, Enum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
@ -48,6 +49,7 @@ class MessageRole(str, Enum):
def roles(cls): def roles(cls):
return [r.value for r in cls] return [r.value for r in cls]
openai_role_conversions = { openai_role_conversions = {
MessageRole.TOOL_RESPONSE: MessageRole.USER, MessageRole.TOOL_RESPONSE: MessageRole.USER,
} }
@ -56,6 +58,7 @@ llama_role_conversions = {
MessageRole.TOOL_RESPONSE: MessageRole.USER, MessageRole.TOOL_RESPONSE: MessageRole.USER,
} }
def get_clean_message_list( def get_clean_message_list(
message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
): ):
@ -118,7 +121,7 @@ class HfEngine:
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None, grammar: Optional[str] = None,
max_tokens: int = 1500 max_tokens: int = 1500,
): ):
raise NotImplementedError raise NotImplementedError
@ -276,7 +279,12 @@ class TransformersEngine(HfEngine):
class OpenAIEngine: class OpenAIEngine:
def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None): def __init__(
self,
model_name: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
"""Creates a LLM Engine that follows OpenAI format. """Creates a LLM Engine that follows OpenAI format.
Args: Args:
@ -301,7 +309,9 @@ class OpenAIEngine:
grammar: Optional[str] = None, grammar: Optional[str] = None,
max_tokens: int = 1500, max_tokens: int = 1500,
) -> str: ) -> str:
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions) messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions
)
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
@ -337,7 +347,9 @@ class AnthropicEngine:
grammar: Optional[str] = None, grammar: Optional[str] = None,
max_tokens: int = 1500, max_tokens: int = 1500,
) -> str: ) -> str:
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions) messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions
)
index_system_message, system_prompt = None, None index_system_message, system_prompt = None, None
for index, message in enumerate(messages): for index, message in enumerate(messages):
if message["role"] == MessageRole.SYSTEM: if message["role"] == MessageRole.SYSTEM:
@ -346,7 +358,9 @@ class AnthropicEngine:
if system_prompt is None: if system_prompt is None:
raise Exception("No system prompt found!") raise Exception("No system prompt found!")
filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message] filtered_messages = [
message for i, message in enumerate(messages) if i != index_system_message
]
if len(filtered_messages) == 0: if len(filtered_messages) == 0:
print("Error, no user message:", messages) print("Error, no user message:", messages)
assert False assert False
@ -366,4 +380,13 @@ class AnthropicEngine:
return full_response_text return full_response_text
__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"] __all__ = [
"MessageRole",
"llama_role_conversions",
"get_clean_message_list",
"HfEngine",
"TransformersEngine",
"HfApiEngine",
"OpenAIEngine",
"AnthropicEngine",
]

View File

@ -1000,4 +1000,5 @@ def evaluate_python_code(
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg) raise InterpreterError(msg)
__all__ = ["evaluate_python_code"] __all__ = ["evaluate_python_code"]

View File

@ -44,4 +44,5 @@ class Monitor:
console.print(f"- Input tokens: {self.total_input_token_count:,}") console.print(f"- Input tokens: {self.total_input_token_count:,}")
console.print(f"- Output tokens: {self.total_output_token_count:,}") console.print(f"- Output tokens: {self.total_output_token_count:,}")
__all__ = ["Monitor"] __all__ = ["Monitor"]

View File

@ -491,4 +491,10 @@ Here is my new/updated plan of action to solve the task:
{plan_update} {plan_update}
```""" ```"""
__all__ = ["USER_PROMPT_PLAN_UPDATE", "PLAN_UPDATE_FINAL_PLAN_REDACTION", "ONESHOT_CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT", "JSON_SYSTEM_PROMPT"] __all__ = [
"USER_PROMPT_PLAN_UPDATE",
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
"ONESHOT_CODE_SYSTEM_PROMPT",
"CODE_SYSTEM_PROMPT",
"JSON_SYSTEM_PROMPT",
]

View File

@ -78,4 +78,5 @@ class VisitWebpageTool(Tool):
except Exception as e: except Exception as e:
return f"An unexpected error occurred: {str(e)}" return f"An unexpected error occurred: {str(e)}"
__all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"] __all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"]

View File

@ -39,12 +39,6 @@ from huggingface_hub import (
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version from packaging import version
from transformers.dynamic_module_utils import (
custom_object_save,
get_class_from_dynamic_module,
get_imports,
)
from transformers import AutoProcessor
from transformers.utils import ( from transformers.utils import (
TypeHintParsingException, TypeHintParsingException,
cached_file, cached_file,
@ -62,11 +56,10 @@ logger = logging.getLogger(__name__)
if is_torch_available(): if is_torch_available():
import torch pass
if is_accelerate_available(): if is_accelerate_available():
from accelerate import PartialState pass
from accelerate.utils import send_to_device
TOOL_CONFIG_FILE = "tool_config.json" TOOL_CONFIG_FILE = "tool_config.json"
@ -123,6 +116,7 @@ def validate_after_init(cls, do_validate_forward: bool = True):
cls.__init__ = new_init cls.__init__ = new_init
return cls return cls
def validate_args_are_self_contained(source_code): def validate_args_are_self_contained(source_code):
"""Validates that all names in forward method are properly defined. """Validates that all names in forward method are properly defined.
In particular it will check that all imports are done within the function.""" In particular it will check that all imports are done within the function."""
@ -150,7 +144,7 @@ def validate_args_are_self_contained(source_code):
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node):
"""Handle from imports like 'from datetime import datetime'.""" """Handle from imports like 'from datetime import datetime'."""
module = node.module or '' module = node.module or ""
for name in node.names: for name in node.names:
actual_name = name.asname or name.name actual_name = name.asname or name.name
self.from_imports[actual_name] = (module, name.name, actual_name) self.from_imports[actual_name] = (module, name.name, actual_name)
@ -187,9 +181,11 @@ def validate_args_are_self_contained(source_code):
self.assigned_names.update(target_names) self.assigned_names.update(target_names)
# Special handling for enumerate # Special handling for enumerate
if (isinstance(node.iter, ast.Call) and if (
isinstance(node.iter.func, ast.Name) and isinstance(node.iter, ast.Call)
node.iter.func.id == 'enumerate'): and isinstance(node.iter.func, ast.Name)
and node.iter.func.id == "enumerate"
):
# For enumerate, if we have "for i, x in enumerate(...)", # For enumerate, if we have "for i, x in enumerate(...)",
# both i and x should be marked as assigned # both i and x should be marked as assigned
if isinstance(node.target, ast.Tuple): if isinstance(node.target, ast.Tuple):
@ -201,19 +197,19 @@ def validate_args_are_self_contained(source_code):
self.generic_visit(node) self.generic_visit(node)
def visit_Name(self, node): def visit_Name(self, node):
if (isinstance(node.ctx, ast.Load) and not ( if isinstance(node.ctx, ast.Load) and not (
node.id == "tool" or node.id == "tool"
node.id in builtin_names or or node.id in builtin_names
node.id in arg_names or or node.id in arg_names
node.id == 'self' or or node.id == "self"
node.id in self.assigned_names or node.id in self.assigned_names
)): ):
if node.id not in self.from_imports and node.id not in self.imports: if node.id not in self.from_imports and node.id not in self.imports:
self.undefined_names.add(node.id) self.undefined_names.add(node.id)
def visit_Attribute(self, node): def visit_Attribute(self, node):
# Skip self.something # Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == 'self'): if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node) self.generic_visit(node)
checker = NameChecker() checker = NameChecker()
@ -226,6 +222,7 @@ def validate_args_are_self_contained(source_code):
""" """
) )
AUTHORIZED_TYPES = [ AUTHORIZED_TYPES = [
"string", "string",
"boolean", "boolean",
@ -273,7 +270,6 @@ class Tool:
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
validate_after_init(cls, do_validate_forward=False) validate_after_init(cls, do_validate_forward=False)
def validate_arguments(self, do_validate_forward: bool = True): def validate_arguments(self, do_validate_forward: bool = True):
required_attributes = { required_attributes = {
"description": str, "description": str,
@ -359,13 +355,13 @@ class {class_name}(Tool):
def add_self_argument(source_code: str) -> str: def add_self_argument(source_code: str) -> str:
"""Add 'self' as first argument to a function definition if not present.""" """Add 'self' as first argument to a function definition if not present."""
pattern = r'def forward\(((?!self)[^)]*)\)' pattern = r"def forward\(((?!self)[^)]*)\)"
def replacement(match): def replacement(match):
args = match.group(1).strip() args = match.group(1).strip()
if args: # If there are other arguments if args: # If there are other arguments
return f'def forward(self, {args})' return f"def forward(self, {args})"
return 'def forward(self)' return "def forward(self)"
return re.sub(pattern, replacement, source_code) return re.sub(pattern, replacement, source_code)
@ -391,11 +387,7 @@ class {class_name}(Tool):
# Save app file # Save app file
app_file = os.path.join(output_dir, "app.py") app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f: with open(app_file, "w", encoding="utf-8") as f:
f.write( f.write(APP_FILE_TEMPLATE.format(class_name=class_name))
APP_FILE_TEMPLATE.format(
class_name=class_name
)
)
# Save requirements file # Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt") requirements_file = os.path.join(output_dir, "requirements.txt")
@ -457,7 +449,7 @@ class {class_name}(Tool):
self.save(work_dir) self.save(work_dir)
print(work_dir) print(work_dir)
with open(work_dir + "/tool.py", "r") as f: with open(work_dir + "/tool.py", "r") as f:
print('\n'.join(f.readlines())) print("\n".join(f.readlines()))
logger.info( logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
) )
@ -575,7 +567,6 @@ class {class_name}(Tool):
return tool_class(**kwargs) return tool_class(**kwargs)
@staticmethod @staticmethod
def from_space( def from_space(
space_id: str, space_id: str,
@ -702,7 +693,11 @@ class {class_name}(Tool):
return output return output
return SpaceToolWrapper( return SpaceToolWrapper(
space_id=space_id, name=name, description=description, api_name=api_name, token=token space_id=space_id,
name=name,
description=description,
api_name=api_name,
token=token,
) )
@staticmethod @staticmethod
@ -855,12 +850,12 @@ TOOL_MAPPING = {
def load_tool( def load_tool(
task_or_repo_id, task_or_repo_id,
model_repo_id: Optional[str] = None, model_repo_id: Optional[str] = None,
token: Optional[str] = None, token: Optional[str] = None,
trust_remote_code: bool=False, trust_remote_code: bool = False,
**kwargs **kwargs,
): ):
""" """
Main function to quickly load a tool, be it on the Hub or in the Transformers library. Main function to quickly load a tool, be it on the Hub or in the Transformers library.
@ -909,7 +904,11 @@ def load_tool(
f"code that you have checked." f"code that you have checked."
) )
return Tool.from_hub( return Tool.from_hub(
task_or_repo_id, model_repo_id=model_repo_id, token=token, trust_remote_code=trust_remote_code, **kwargs task_or_repo_id,
model_repo_id=model_repo_id,
token=token,
trust_remote_code=trust_remote_code,
**kwargs,
) )
@ -1028,7 +1027,7 @@ def tool(tool_function: Callable) -> Tool:
raise TypeHintParsingException( raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!" "Tool return type not found: make sure your function has a return type hint!"
) )
class_name = ''.join([el.title() for el in parameters['name'].split('_')]) class_name = "".join([el.title() for el in parameters["name"].split("_")])
if parameters["return"]["type"] == "object": if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any" parameters["return"]["type"] = "any"
@ -1086,7 +1085,9 @@ class Toolbox:
"""Get all tools currently in the toolbox""" """Get all tools currently in the toolbox"""
return self._tools return self._tools
def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str: def show_tool_descriptions(
self, tool_description_template: Optional[str] = None
) -> str:
""" """
Returns the description of all tools in the toolbox Returns the description of all tools in the toolbox
@ -1151,4 +1152,12 @@ class Toolbox:
toolbox_description += f"\t{tool.name}: {tool.description}\n" toolbox_description += f"\t{tool.name}: {tool.description}\n"
return toolbox_description return toolbox_description
__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
__all__ = [
"AUTHORIZED_TYPES",
"Tool",
"tool",
"load_tool",
"launch_gradio_demo",
"Toolbox",
]

View File

@ -267,4 +267,5 @@ def handle_agent_outputs(output, output_type=None):
return _v(output) return _v(output)
return output return output
__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"] __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]

View File

@ -18,6 +18,7 @@ import json
import re import re
from typing import Tuple, Dict, Union from typing import Tuple, Dict, Union
import ast import ast
from rich.console import Console
from transformers.utils.import_utils import _is_package_available from transformers.utils.import_utils import _is_package_available
@ -28,8 +29,6 @@ def is_pygments_available():
return _pygments_available return _pygments_available
from rich.console import Console
console = Console() console = Console()
@ -111,6 +110,7 @@ def truncate_content(
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
) )
class ImportFinder(ast.NodeVisitor): class ImportFinder(ast.NodeVisitor):
def __init__(self): def __init__(self):
self.packages = set() self.packages = set()
@ -118,13 +118,14 @@ class ImportFinder(ast.NodeVisitor):
def visit_Import(self, node): def visit_Import(self, node):
for alias in node.names: for alias in node.names:
# Get the base package name (before any dots) # Get the base package name (before any dots)
base_package = alias.name.split('.')[0] base_package = alias.name.split(".")[0]
self.packages.add(base_package) self.packages.add(base_package)
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node):
if node.module: # for "from x import y" statements if node.module: # for "from x import y" statements
# Get the base package name (before any dots) # Get the base package name (before any dots)
base_package = node.module.split('.')[0] base_package = node.module.split(".")[0]
self.packages.add(base_package) self.packages.add(base_package)
__all__ = [] __all__ = []

View File

@ -27,12 +27,13 @@ from agents.agents import (
CodeAgent, CodeAgent,
JsonAgent, JsonAgent,
Toolbox, Toolbox,
ToolCall ToolCall,
) )
from agents.tools import tool from agents.tools import tool
from agents.default_tools import PythonInterpreterTool from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir
def get_new_path(suffix="") -> str: def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()
return os.path.join(directory, str(uuid.uuid4()) + suffix) return os.path.join(directory, str(uuid.uuid4()) + suffix)
@ -60,6 +61,7 @@ Action:
} }
""" """
def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str: def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages) prompt = str(messages)
@ -82,6 +84,7 @@ Action:
} }
""" """
def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str: def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages) prompt = str(messages)
if "special_marker" not in prompt: if "special_marker" not in prompt:
@ -179,9 +182,7 @@ class AgentTests(unittest.TestCase):
assert output == "7.2904" assert output == "7.2904"
def test_fake_json_agent(self): def test_fake_json_agent(self):
agent = JsonAgent( agent = JsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm)
tools=[PythonInterpreterTool()], llm_engine=fake_json_llm
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str) assert isinstance(output, str)
assert output == "7.2904" assert output == "7.2904"
@ -209,9 +210,7 @@ Action:
Args: Args:
prompt: The prompt prompt: The prompt
""" """
return Image.open( return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
Path(get_tests_dir("fixtures")) / "000000039769.png"
)
agent = JsonAgent( agent = JsonAgent(
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
@ -221,9 +220,7 @@ Action:
assert isinstance(agent.state["image.png"], Image.Image) assert isinstance(agent.state["image.png"], Image.Image)
def test_fake_code_agent(self): def test_fake_code_agent(self):
agent = CodeAgent( agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float) assert isinstance(output, float)
assert output == 7.2904 assert output == 7.2904
@ -234,9 +231,7 @@ Action:
) )
def test_reset_conversations(self): def test_reset_conversations(self):
agent = CodeAgent( agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
)
output = agent.run("What is 2 multiplied by 3.6452?", reset=True) output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
assert output == 7.2904 assert output == 7.2904
assert len(agent.logs) == 4 assert len(agent.logs) == 4
@ -299,9 +294,7 @@ Action:
# check that python_interpreter base tool does not get added to code agents # check that python_interpreter base tool does not get added to code agents
agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True) agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True)
assert ( assert len(agent.toolbox.tools) == 2 # added final_answer tool + search
len(agent.toolbox.tools) == 2
) # added final_answer tool + search
def test_function_persistence_across_steps(self): def test_function_persistence_across_steps(self):
agent = CodeAgent( agent = CodeAgent(

View File

@ -25,6 +25,7 @@ from pathlib import Path
from typing import List from typing import List
from dotenv import load_dotenv from dotenv import load_dotenv
class SubprocessCallException(Exception): class SubprocessCallException(Exception):
pass pass
@ -59,7 +60,7 @@ class DocCodeExtractor:
@staticmethod @staticmethod
def extract_python_code(content: str) -> List[str]: def extract_python_code(content: str) -> List[str]:
"""Extract Python code blocks from markdown content.""" """Extract Python code blocks from markdown content."""
pattern = r'```(?:python|py)\n(.*?)\n```' pattern = r"```(?:python|py)\n(.*?)\n```"
matches = re.finditer(pattern, content, re.DOTALL) matches = re.finditer(pattern, content, re.DOTALL)
return [match.group(1).strip() for match in matches] return [match.group(1).strip() for match in matches]
@ -118,18 +119,27 @@ class TestDocs:
# Create and execute test script # Create and execute test script
try: try:
excluded_snippets = ["ToolCollection", "image_generation_tool", "from_langchain"] excluded_snippets = [
"ToolCollection",
"image_generation_tool",
"from_langchain",
]
code_blocks = [ code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token) for block in code_blocks block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token)
if not any([snippet in block for snippet in excluded_snippets]) # Exclude these tools that take longer to run and add dependencies for block in code_blocks
if not any(
[snippet in block for snippet in excluded_snippets]
) # Exclude these tools that take longer to run and add dependencies
] ]
test_script = self.extractor.create_test_script(code_blocks, self._tmpdir) test_script = self.extractor.create_test_script(code_blocks, self._tmpdir)
run_command(self.launch_args + [str(test_script)]) run_command(self.launch_args + [str(test_script)])
except SubprocessCallException as e: except SubprocessCallException as e:
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}") pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
except Exception as e: except Exception:
pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}") pytest.fail(
f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _setup(self): def _setup(self):
@ -152,7 +162,5 @@ def pytest_generate_tests(metafunc):
# Parameterize with the markdown files # Parameterize with the markdown files
metafunc.parametrize( metafunc.parametrize(
"doc_path", "doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]
test_class.md_files,
ids=[f.stem for f in test_class.md_files]
) )

View File

@ -20,9 +20,10 @@ import numpy as np
from PIL import Image from PIL import Image
from transformers import is_torch_available from transformers import is_torch_available
from agents.types import AGENT_TYPE_MAPPING
from agents.default_tools import FinalAnswerTool
from transformers.testing_utils import get_tests_dir, require_torch from transformers.testing_utils import get_tests_dir, require_torch
from agents.types import AGENT_TYPE_MAPPING
from agents.default_tools import FinalAnswerTool
from .test_tools_common import ToolTesterMixin from .test_tools_common import ToolTesterMixin

View File

@ -98,6 +98,7 @@ class ToolTesterMixin:
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type)) self.assertTrue(isinstance(output, agent_type))
class ToolTests(unittest.TestCase): class ToolTests(unittest.TestCase):
def test_tool_init_with_decorator(self): def test_tool_init_with_decorator(self):
@tool @tool
@ -163,40 +164,46 @@ class ToolTests(unittest.TestCase):
assert coolfunc.output_type == "number" assert coolfunc.output_type == "number"
assert "docstring has no description for the argument" in str(e) assert "docstring has no description for the argument" in str(e)
def test_tool_definition_needs_imports_in_function(self): def test_tool_definition_raises_error_imports_outside_function(self):
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
from datetime import datetime from datetime import datetime
@tool @tool
def get_current_time() -> str: def get_current_time() -> str:
""" """
Gets the current time. Gets the current time.
""" """
return str(datetime.now()) return str(datetime.now())
assert "datetime" in str(e) assert "datetime" in str(e)
# Also test with classic definition # Also test with classic definition
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
class GetCurrentTimeTool(Tool): class GetCurrentTimeTool(Tool):
name="get_current_time_tool" name = "get_current_time_tool"
description="Gets the current time" description = "Gets the current time"
inputs = {} inputs = {}
output_type = "string" output_type = "string"
def forward(self): def forward(self):
return str(datetime.now()) return str(datetime.now())
assert "datetime" in str(e) assert "datetime" in str(e)
def test_tool_definition_raises_no_error_imports_in_function(self):
@tool @tool
def get_current_time() -> str: def get_current_time() -> str:
""" """
Gets the current time. Gets the current time.
""" """
from datetime import datetime from datetime import datetime
return str(datetime.now()) return str(datetime.now())
class GetCurrentTimeTool(Tool): class GetCurrentTimeTool(Tool):
name="get_current_time_tool" name = "get_current_time_tool"
description="Gets the current time" description = "Gets the current time"
inputs = {} inputs = {}
output_type = "string" output_type = "string"

View File

@ -5,6 +5,7 @@ import tempfile
from pathlib import Path from pathlib import Path
def str_to_bool(value) -> int: def str_to_bool(value) -> int:
""" """
Converts a string representation of truth to `True` (1) or `False` (0). Converts a string representation of truth to `True` (1) or `False` (0).
@ -28,10 +29,13 @@ def get_int_from_env(env_keys, default):
return val return val
return default return default
def parse_flag_from_env(key, default=False): def parse_flag_from_env(key, default=False):
"""Returns truthy value for `key` from the env if available else the default.""" """Returns truthy value for `key` from the env if available else the default."""
value = os.environ.get(key, str(default)) value = os.environ.get(key, str(default))
return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int... return (
str_to_bool(value) == 1
) # As its name indicates `str_to_bool` actually returns an int...
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)