Formatting

This commit is contained in:
Aymeric 2024-12-16 15:46:47 +01:00
parent 1751bf03ac
commit 06066437fd
18 changed files with 275 additions and 204 deletions

View File

@ -18,11 +18,7 @@ __version__ = "0.1.0"
from typing import TYPE_CHECKING
from transformers.utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure
@ -43,4 +39,6 @@ else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
sys.modules[__name__] = _LazyModule(
__name__, _file, define_import_structure(_file), module_spec=__spec__
)

View File

@ -79,8 +79,9 @@ class AgentGenerationError(AgentError):
pass
@dataclass
class ToolCall():
class ToolCall:
tool_name: str
tool_arguments: Any
@ -146,13 +147,17 @@ Here is a list of the team members that you can call:"""
def format_prompt_with_managed_agents_descriptions(
prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None
prompt_template,
managed_agents,
agent_descriptions_placeholder: Optional[str] = None,
) -> str:
if agent_descriptions_placeholder is None:
agent_descriptions_placeholder = "{{managed_agents_descriptions}}"
if agent_descriptions_placeholder not in prompt_template:
print("PROMPT TEMPLLL", prompt_template)
raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'")
raise ValueError(
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
)
if len(managed_agents.keys()) > 0:
return prompt_template.replace(
agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
@ -970,7 +975,9 @@ class CodeAgent(ReactAgent):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action)
log_entry.tool_call = ToolCall(
tool_name="python_interpreter", tool_arguments=code_action
)
# Execute
if self.verbose:
@ -1075,4 +1082,13 @@ And even if your task resolution is not successful, please return as much contex
else:
return output
__all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
__all__ = [
"AgentError",
"BaseAgent",
"ManagedAgent",
"ReactAgent",
"CodeAgent",
"JsonAgent",
"Toolbox",
]

View File

@ -127,7 +127,10 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."
inputs = {
"code": {"type": "string", "description": "The python code to run in interpreter"}
"code": {
"type": "string",
"description": "The python code to run in interpreter",
}
}
output_type = "string"
@ -186,4 +189,5 @@ class UserInputTool(Tool):
user_input = input(f"{question} => ")
return user_input
__all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"]

View File

@ -9,12 +9,13 @@ from typing import Optional, Dict, Tuple, Set, Any
import types
from .default_tools import BASE_PYTHON_TOOLS
class StateManager:
def __init__(self, work_dir: Path):
self.work_dir = work_dir
self.state_file = work_dir / "interpreter_state.pickle"
self.imports_file = work_dir / "imports.txt"
self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$')
self.import_pattern = re.compile(r"^(?:from\s+[\w.]+\s+)?import\s+.+$")
self.imports: Set[str] = set()
def is_import_statement(self, code: str) -> bool:
@ -23,7 +24,7 @@ class StateManager:
def track_imports(self, code: str):
"""Track import statements for later use."""
for line in code.split('\n'):
for line in code.split("\n"):
if self.is_import_statement(line.strip()):
self.imports.add(line.strip())
@ -37,20 +38,21 @@ class StateManager:
"""
# Filter out modules, functions, and special variables
state_dict = {
'variables': {
k: v for k, v in locals_dict.items()
"variables": {
k: v
for k, v in locals_dict.items()
if not (
k.startswith('_')
k.startswith("_")
or callable(v)
or isinstance(v, type)
or isinstance(v, types.ModuleType)
)
},
'imports': list(self.imports),
'source': executor
"imports": list(self.imports),
"source": executor,
}
with open(self.state_file, 'wb') as f:
with open(self.state_file, "wb") as f:
pickle.dump(state_dict, f)
def load_state(self, executor: str) -> Dict[str, Any]:
@ -66,14 +68,14 @@ class StateManager:
if not self.state_file.exists():
return {}
with open(self.state_file, 'rb') as f:
with open(self.state_file, "rb") as f:
state_dict = pickle.load(f)
# First handle imports
for import_stmt in state_dict['imports']:
for import_stmt in state_dict["imports"]:
exec(import_stmt, globals())
return state_dict['variables']
return state_dict["variables"]
def read_multiplexed_response(socket):
@ -81,10 +83,10 @@ def read_multiplexed_response(socket):
socket.settimeout(10.0)
i = 0
while True and i<1000:
while True and i < 1000:
# Stream output from socket
response_data = socket.recv(4096)
responses = response_data.split(b'\x01\x00\x00\x00\x00\x00')
responses = response_data.split(b"\x01\x00\x00\x00\x00\x00")
# The last non-empty chunk should be our JSON response
if len(responses) > 0:
@ -92,15 +94,15 @@ def read_multiplexed_response(socket):
if chunk and len(chunk.strip()) > 0:
try:
# Find the start of valid JSON by looking for '{'
json_start = chunk.find(b'{')
json_start = chunk.find(b"{")
if json_start != -1:
decoded = chunk[json_start:].decode('utf-8')
decoded = chunk[json_start:].decode("utf-8")
result = json.loads(decoded)
if "output" in result:
return decoded
except json.JSONDecodeError:
continue
i+=1
i += 1
class DockerPythonInterpreter:
@ -113,7 +115,6 @@ class DockerPythonInterpreter:
self.socket = None
self.state_manager = StateManager(work_dir)
def create_interpreter_script(self) -> str:
"""Create the interpreter script that will run inside the container"""
script = """
@ -230,9 +231,7 @@ if __name__ == '__main__':
self.create_interpreter_script()
# Setup volume mapping
volumes = {
str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}
}
volumes = {str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}}
for container in self.client.containers.list(all=True):
if container_name == container.name:
@ -241,7 +240,7 @@ if __name__ == '__main__':
container.start()
self.container = container
break
else: # Create new container
else: # Create new container
self.container = self.client.containers.run(
"python:3.9",
name=container_name,
@ -250,22 +249,20 @@ if __name__ == '__main__':
tty=True,
stdin_open=True,
working_dir="/workspace",
volumes=volumes
volumes=volumes,
)
# Install packages in the new container
print("Installing packages...")
packages = ["pandas", "numpy", "pickle5"] # Add your required packages here
result = self.container.exec_run(
f"pip install {' '.join(packages)}",
workdir="/workspace"
f"pip install {' '.join(packages)}", workdir="/workspace"
)
if result.exit_code != 0:
print(f"Warning: Failed to install: {result.output.decode()}")
else:
print(f"Installed {packages}.")
if not self.wait_for_ready(self.container):
raise Exception("Failed to start container")
@ -276,14 +273,12 @@ if __name__ == '__main__':
stdin=True,
stdout=True,
stderr=True,
tty=True
tty=True,
)
# Connect to the exec instance
self.socket = self.client.api.exec_start(
self.exec_id['Id'],
socket=True,
demux=True
self.exec_id["Id"], socket=True, demux=True
)._sock
def _raw_execute(self, code: str) -> Tuple[str, bool]:
@ -296,14 +291,14 @@ if __name__ == '__main__':
if not self.socket:
raise Exception("Socket not started")
command = json.dumps({'code': code}) + '\n'
command = json.dumps({"code": code}) + "\n"
self.socket.send(command.encode())
response = read_multiplexed_response(self.socket)
try:
result = json.loads(response)
return result['output'], result['more']
return result["output"], result["more"]
except json.JSONDecodeError:
return f"Error: Invalid response from interpreter: {response}", False
@ -311,7 +306,7 @@ if __name__ == '__main__':
"""Get the current locals dictionary from the interpreter by pickling directly from Docker."""
pickle_path = self.work_dir / "locals.pickle"
if pickle_path.exists():
with open(pickle_path, 'rb') as f:
with open(pickle_path, "rb") as f:
try:
return pickle.load(f)
except Exception as e:
@ -326,10 +321,7 @@ if __name__ == '__main__':
output, more = self._raw_execute(code)
# Save state after execution
self.state_manager.save_state(
self.get_locals_dict(),
'docker'
)
self.state_manager.save_state(self.get_locals_dict(), "docker")
return output, more
def stop(self, remove: bool = False):
@ -349,6 +341,7 @@ if __name__ == '__main__':
print(f"Error stopping container: {e}")
raise
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES
@ -359,7 +352,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
state_manager.track_imports(code)
# Load state from Docker if available
locals_dict = state_manager.load_state('local')
locals_dict = state_manager.load_state("local")
# Execute in a new namespace with loaded state
namespace = {}
@ -374,7 +367,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
)
# Save state for Docker
state_manager.save_state(namespace, 'local')
state_manager.save_state(namespace, "local")
return output
@ -382,14 +375,17 @@ def create_tools_regex(tool_names):
# Escape any special regex characters in tool names
escaped_names = [re.escape(name) for name in tool_names]
# Join with | and add word boundaries
pattern = r'\b(' + '|'.join(escaped_names) + r')\b'
pattern = r"\b(" + "|".join(escaped_names) + r")\b"
return re.compile(pattern)
def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
"""Execute code with automatic switching between Docker and local."""
lines = code.split('\n')
lines = code.split("\n")
current_block = []
tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing
tool_regex = create_tools_regex(
list(tools.keys()) + ["print"]
) # Added print for testing
tools = {
**BASE_PYTHON_TOOLS.copy(),
@ -400,20 +396,20 @@ def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
if tool_regex.search(line):
# Execute accumulated Docker code if any
if current_block:
output, more = interpreter.execute('\n'.join(current_block))
print(output, end='')
output, more = interpreter.execute("\n".join(current_block))
print(output, end="")
current_block = []
output = execute_locally(line, work_dir, tools)
if output:
print(output, end='')
print(output, end="")
else:
current_block.append(line)
# Execute any remaining Docker code
if current_block:
output, more = interpreter.execute('\n'.join(current_block))
print(output, end='')
output, more = interpreter.execute("\n".join(current_block))
print(output, end="")
__all__ = ["DockerPythonInterpreter", "execute_code"]

View File

@ -111,4 +111,5 @@ class GradioUI:
demo.launch()
__all__ = ["stream_to_gradio", "GradioUI"]

View File

@ -37,6 +37,7 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
}
class MessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
@ -48,6 +49,7 @@ class MessageRole(str, Enum):
def roles(cls):
return [r.value for r in cls]
openai_role_conversions = {
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}
@ -56,6 +58,7 @@ llama_role_conversions = {
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}
def get_clean_message_list(
message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
):
@ -118,7 +121,7 @@ class HfEngine:
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500
max_tokens: int = 1500,
):
raise NotImplementedError
@ -276,7 +279,12 @@ class TransformersEngine(HfEngine):
class OpenAIEngine:
def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None):
def __init__(
self,
model_name: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
"""Creates a LLM Engine that follows OpenAI format.
Args:
@ -301,7 +309,9 @@ class OpenAIEngine:
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions
)
response = self.client.chat.completions.create(
model=self.model_name,
@ -337,7 +347,9 @@ class AnthropicEngine:
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions
)
index_system_message, system_prompt = None, None
for index, message in enumerate(messages):
if message["role"] == MessageRole.SYSTEM:
@ -346,7 +358,9 @@ class AnthropicEngine:
if system_prompt is None:
raise Exception("No system prompt found!")
filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message]
filtered_messages = [
message for i, message in enumerate(messages) if i != index_system_message
]
if len(filtered_messages) == 0:
print("Error, no user message:", messages)
assert False
@ -366,4 +380,13 @@ class AnthropicEngine:
return full_response_text
__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"]
__all__ = [
"MessageRole",
"llama_role_conversions",
"get_clean_message_list",
"HfEngine",
"TransformersEngine",
"HfApiEngine",
"OpenAIEngine",
"AnthropicEngine",
]

View File

@ -1000,4 +1000,5 @@ def evaluate_python_code(
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg)
__all__ = ["evaluate_python_code"]

View File

@ -44,4 +44,5 @@ class Monitor:
console.print(f"- Input tokens: {self.total_input_token_count:,}")
console.print(f"- Output tokens: {self.total_output_token_count:,}")
__all__ = ["Monitor"]

View File

@ -491,4 +491,10 @@ Here is my new/updated plan of action to solve the task:
{plan_update}
```"""
__all__ = ["USER_PROMPT_PLAN_UPDATE", "PLAN_UPDATE_FINAL_PLAN_REDACTION", "ONESHOT_CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT", "JSON_SYSTEM_PROMPT"]
__all__ = [
"USER_PROMPT_PLAN_UPDATE",
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
"ONESHOT_CODE_SYSTEM_PROMPT",
"CODE_SYSTEM_PROMPT",
"JSON_SYSTEM_PROMPT",
]

View File

@ -78,4 +78,5 @@ class VisitWebpageTool(Tool):
except Exception as e:
return f"An unexpected error occurred: {str(e)}"
__all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"]

View File

@ -39,12 +39,6 @@ from huggingface_hub import (
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version
from transformers.dynamic_module_utils import (
custom_object_save,
get_class_from_dynamic_module,
get_imports,
)
from transformers import AutoProcessor
from transformers.utils import (
TypeHintParsingException,
cached_file,
@ -62,11 +56,10 @@ logger = logging.getLogger(__name__)
if is_torch_available():
import torch
pass
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import send_to_device
pass
TOOL_CONFIG_FILE = "tool_config.json"
@ -123,6 +116,7 @@ def validate_after_init(cls, do_validate_forward: bool = True):
cls.__init__ = new_init
return cls
def validate_args_are_self_contained(source_code):
"""Validates that all names in forward method are properly defined.
In particular it will check that all imports are done within the function."""
@ -150,7 +144,7 @@ def validate_args_are_self_contained(source_code):
def visit_ImportFrom(self, node):
"""Handle from imports like 'from datetime import datetime'."""
module = node.module or ''
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
self.from_imports[actual_name] = (module, name.name, actual_name)
@ -187,9 +181,11 @@ def validate_args_are_self_contained(source_code):
self.assigned_names.update(target_names)
# Special handling for enumerate
if (isinstance(node.iter, ast.Call) and
isinstance(node.iter.func, ast.Name) and
node.iter.func.id == 'enumerate'):
if (
isinstance(node.iter, ast.Call)
and isinstance(node.iter.func, ast.Name)
and node.iter.func.id == "enumerate"
):
# For enumerate, if we have "for i, x in enumerate(...)",
# both i and x should be marked as assigned
if isinstance(node.target, ast.Tuple):
@ -201,19 +197,19 @@ def validate_args_are_self_contained(source_code):
self.generic_visit(node)
def visit_Name(self, node):
if (isinstance(node.ctx, ast.Load) and not (
node.id == "tool" or
node.id in builtin_names or
node.id in arg_names or
node.id == 'self' or
node.id in self.assigned_names
)):
if isinstance(node.ctx, ast.Load) and not (
node.id == "tool"
or node.id in builtin_names
or node.id in arg_names
or node.id == "self"
or node.id in self.assigned_names
):
if node.id not in self.from_imports and node.id not in self.imports:
self.undefined_names.add(node.id)
def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)
checker = NameChecker()
@ -226,6 +222,7 @@ def validate_args_are_self_contained(source_code):
"""
)
AUTHORIZED_TYPES = [
"string",
"boolean",
@ -273,7 +270,6 @@ class Tool:
super().__init_subclass__(**kwargs)
validate_after_init(cls, do_validate_forward=False)
def validate_arguments(self, do_validate_forward: bool = True):
required_attributes = {
"description": str,
@ -359,13 +355,13 @@ class {class_name}(Tool):
def add_self_argument(source_code: str) -> str:
"""Add 'self' as first argument to a function definition if not present."""
pattern = r'def forward\(((?!self)[^)]*)\)'
pattern = r"def forward\(((?!self)[^)]*)\)"
def replacement(match):
args = match.group(1).strip()
if args: # If there are other arguments
return f'def forward(self, {args})'
return 'def forward(self)'
return f"def forward(self, {args})"
return "def forward(self)"
return re.sub(pattern, replacement, source_code)
@ -391,11 +387,7 @@ class {class_name}(Tool):
# Save app file
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(
APP_FILE_TEMPLATE.format(
class_name=class_name
)
)
f.write(APP_FILE_TEMPLATE.format(class_name=class_name))
# Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt")
@ -457,7 +449,7 @@ class {class_name}(Tool):
self.save(work_dir)
print(work_dir)
with open(work_dir + "/tool.py", "r") as f:
print('\n'.join(f.readlines()))
print("\n".join(f.readlines()))
logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
)
@ -575,7 +567,6 @@ class {class_name}(Tool):
return tool_class(**kwargs)
@staticmethod
def from_space(
space_id: str,
@ -702,7 +693,11 @@ class {class_name}(Tool):
return output
return SpaceToolWrapper(
space_id=space_id, name=name, description=description, api_name=api_name, token=token
space_id=space_id,
name=name,
description=description,
api_name=api_name,
token=token,
)
@staticmethod
@ -855,12 +850,12 @@ TOOL_MAPPING = {
def load_tool(
task_or_repo_id,
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
trust_remote_code: bool=False,
**kwargs
):
task_or_repo_id,
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs,
):
"""
Main function to quickly load a tool, be it on the Hub or in the Transformers library.
@ -909,7 +904,11 @@ def load_tool(
f"code that you have checked."
)
return Tool.from_hub(
task_or_repo_id, model_repo_id=model_repo_id, token=token, trust_remote_code=trust_remote_code, **kwargs
task_or_repo_id,
model_repo_id=model_repo_id,
token=token,
trust_remote_code=trust_remote_code,
**kwargs,
)
@ -1028,7 +1027,7 @@ def tool(tool_function: Callable) -> Tool:
raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
class_name = ''.join([el.title() for el in parameters['name'].split('_')])
class_name = "".join([el.title() for el in parameters["name"].split("_")])
if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any"
@ -1086,7 +1085,9 @@ class Toolbox:
"""Get all tools currently in the toolbox"""
return self._tools
def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
def show_tool_descriptions(
self, tool_description_template: Optional[str] = None
) -> str:
"""
Returns the description of all tools in the toolbox
@ -1151,4 +1152,12 @@ class Toolbox:
toolbox_description += f"\t{tool.name}: {tool.description}\n"
return toolbox_description
__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
__all__ = [
"AUTHORIZED_TYPES",
"Tool",
"tool",
"load_tool",
"launch_gradio_demo",
"Toolbox",
]

View File

@ -267,4 +267,5 @@ def handle_agent_outputs(output, output_type=None):
return _v(output)
return output
__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]

View File

@ -18,6 +18,7 @@ import json
import re
from typing import Tuple, Dict, Union
import ast
from rich.console import Console
from transformers.utils.import_utils import _is_package_available
@ -28,8 +29,6 @@ def is_pygments_available():
return _pygments_available
from rich.console import Console
console = Console()
@ -111,6 +110,7 @@ def truncate_content(
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
)
class ImportFinder(ast.NodeVisitor):
def __init__(self):
self.packages = set()
@ -118,13 +118,14 @@ class ImportFinder(ast.NodeVisitor):
def visit_Import(self, node):
for alias in node.names:
# Get the base package name (before any dots)
base_package = alias.name.split('.')[0]
base_package = alias.name.split(".")[0]
self.packages.add(base_package)
def visit_ImportFrom(self, node):
if node.module: # for "from x import y" statements
# Get the base package name (before any dots)
base_package = node.module.split('.')[0]
base_package = node.module.split(".")[0]
self.packages.add(base_package)
__all__ = []

View File

@ -27,12 +27,13 @@ from agents.agents import (
CodeAgent,
JsonAgent,
Toolbox,
ToolCall
ToolCall,
)
from agents.tools import tool
from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()
return os.path.join(directory, str(uuid.uuid4()) + suffix)
@ -60,6 +61,7 @@ Action:
}
"""
def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
@ -82,6 +84,7 @@ Action:
}
"""
def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
@ -179,9 +182,7 @@ class AgentTests(unittest.TestCase):
assert output == "7.2904"
def test_fake_json_agent(self):
agent = JsonAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_json_llm
)
agent = JsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"
@ -209,9 +210,7 @@ Action:
Args:
prompt: The prompt
"""
return Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
)
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
agent = JsonAgent(
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
@ -221,9 +220,7 @@ Action:
assert isinstance(agent.state["image.png"], Image.Image)
def test_fake_code_agent(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
)
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float)
assert output == 7.2904
@ -234,9 +231,7 @@ Action:
)
def test_reset_conversations(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
)
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
assert output == 7.2904
assert len(agent.logs) == 4
@ -299,9 +294,7 @@ Action:
# check that python_interpreter base tool does not get added to code agents
agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True)
assert (
len(agent.toolbox.tools) == 2
) # added final_answer tool + search
assert len(agent.toolbox.tools) == 2 # added final_answer tool + search
def test_function_persistence_across_steps(self):
agent = CodeAgent(

View File

@ -25,6 +25,7 @@ from pathlib import Path
from typing import List
from dotenv import load_dotenv
class SubprocessCallException(Exception):
pass
@ -59,7 +60,7 @@ class DocCodeExtractor:
@staticmethod
def extract_python_code(content: str) -> List[str]:
"""Extract Python code blocks from markdown content."""
pattern = r'```(?:python|py)\n(.*?)\n```'
pattern = r"```(?:python|py)\n(.*?)\n```"
matches = re.finditer(pattern, content, re.DOTALL)
return [match.group(1).strip() for match in matches]
@ -118,18 +119,27 @@ class TestDocs:
# Create and execute test script
try:
excluded_snippets = ["ToolCollection", "image_generation_tool", "from_langchain"]
excluded_snippets = [
"ToolCollection",
"image_generation_tool",
"from_langchain",
]
code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token) for block in code_blocks
if not any([snippet in block for snippet in excluded_snippets]) # Exclude these tools that take longer to run and add dependencies
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token)
for block in code_blocks
if not any(
[snippet in block for snippet in excluded_snippets]
) # Exclude these tools that take longer to run and add dependencies
]
test_script = self.extractor.create_test_script(code_blocks, self._tmpdir)
run_command(self.launch_args + [str(test_script)])
except SubprocessCallException as e:
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
except Exception as e:
pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
except Exception:
pytest.fail(
f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
)
@pytest.fixture(autouse=True)
def _setup(self):
@ -152,7 +162,5 @@ def pytest_generate_tests(metafunc):
# Parameterize with the markdown files
metafunc.parametrize(
"doc_path",
test_class.md_files,
ids=[f.stem for f in test_class.md_files]
"doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]
)

View File

@ -20,9 +20,10 @@ import numpy as np
from PIL import Image
from transformers import is_torch_available
from agents.types import AGENT_TYPE_MAPPING
from agents.default_tools import FinalAnswerTool
from transformers.testing_utils import get_tests_dir, require_torch
from agents.types import AGENT_TYPE_MAPPING
from agents.default_tools import FinalAnswerTool
from .test_tools_common import ToolTesterMixin

View File

@ -98,6 +98,7 @@ class ToolTesterMixin:
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type))
class ToolTests(unittest.TestCase):
def test_tool_init_with_decorator(self):
@tool
@ -163,40 +164,46 @@ class ToolTests(unittest.TestCase):
assert coolfunc.output_type == "number"
assert "docstring has no description for the argument" in str(e)
def test_tool_definition_needs_imports_in_function(self):
def test_tool_definition_raises_error_imports_outside_function(self):
with pytest.raises(Exception) as e:
from datetime import datetime
@tool
def get_current_time() -> str:
"""
Gets the current time.
"""
return str(datetime.now())
assert "datetime" in str(e)
# Also test with classic definition
with pytest.raises(Exception) as e:
class GetCurrentTimeTool(Tool):
name="get_current_time_tool"
description="Gets the current time"
name = "get_current_time_tool"
description = "Gets the current time"
inputs = {}
output_type = "string"
def forward(self):
return str(datetime.now())
assert "datetime" in str(e)
def test_tool_definition_raises_no_error_imports_in_function(self):
@tool
def get_current_time() -> str:
"""
Gets the current time.
"""
from datetime import datetime
return str(datetime.now())
class GetCurrentTimeTool(Tool):
name="get_current_time_tool"
description="Gets the current time"
name = "get_current_time_tool"
description = "Gets the current time"
inputs = {}
output_type = "string"

View File

@ -5,6 +5,7 @@ import tempfile
from pathlib import Path
def str_to_bool(value) -> int:
"""
Converts a string representation of truth to `True` (1) or `False` (0).
@ -28,10 +29,13 @@ def get_int_from_env(env_keys, default):
return val
return default
def parse_flag_from_env(key, default=False):
"""Returns truthy value for `key` from the env if available else the default."""
value = os.environ.get(key, str(default))
return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int...
return (
str_to_bool(value) == 1
) # As its name indicates `str_to_bool` actually returns an int...
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)