Formatting
parent 1751bf03ac
commit 06066437fd

@@ -18,11 +18,7 @@ __version__ = "0.1.0"
from typing import TYPE_CHECKING
from transformers.utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure

@@ -43,4 +39,6 @@ else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
sys.modules[__name__] = _LazyModule(
__name__, _file, define_import_structure(_file), module_spec=__spec__
)

@@ -79,8 +79,9 @@ class AgentGenerationError(AgentError):
pass
@dataclass
class ToolCall():
class ToolCall:
tool_name: str
tool_arguments: Any

@@ -146,13 +147,17 @@ Here is a list of the team members that you can call:"""
def format_prompt_with_managed_agents_descriptions(
prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None
prompt_template,
managed_agents,
agent_descriptions_placeholder: Optional[str] = None,
) -> str:
if agent_descriptions_placeholder is None:
agent_descriptions_placeholder = "{{managed_agents_descriptions}}"
if agent_descriptions_placeholder not in prompt_template:
print("PROMPT TEMPLLL", prompt_template)
raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'")
raise ValueError(
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
)
if len(managed_agents.keys()) > 0:
return prompt_template.replace(
agent_descriptions_placeholder, show_agents_descriptions(managed_agents)

@@ -970,7 +975,9 @@ class CodeAgent(ReactAgent):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action)
log_entry.tool_call = ToolCall(
tool_name="python_interpreter", tool_arguments=code_action
)
# Execute
if self.verbose:

@@ -1075,4 +1082,13 @@ And even if your task resolution is not successful, please return as much contex
else:
return output
__all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
__all__ = [
"AgentError",
"BaseAgent",
"ManagedAgent",
"ReactAgent",
"CodeAgent",
"JsonAgent",
"Toolbox",
]

@@ -127,7 +127,10 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."
inputs = {
"code": {"type": "string", "description": "The python code to run in interpreter"}
"code": {
"type": "string",
"description": "The python code to run in interpreter",
}
}
output_type = "string"

@@ -186,4 +189,5 @@ class UserInputTool(Tool):
user_input = input(f"{question} => ")
return user_input
__all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"]

@@ -9,12 +9,13 @@ from typing import Optional, Dict, Tuple, Set, Any
import types
from .default_tools import BASE_PYTHON_TOOLS
class StateManager:
def __init__(self, work_dir: Path):
self.work_dir = work_dir
self.state_file = work_dir / "interpreter_state.pickle"
self.imports_file = work_dir / "imports.txt"
self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$')
self.import_pattern = re.compile(r"^(?:from\s+[\w.]+\s+)?import\s+.+$")
self.imports: Set[str] = set()
def is_import_statement(self, code: str) -> bool:

@@ -23,7 +24,7 @@ class StateManager:
def track_imports(self, code: str):
"""Track import statements for later use."""
for line in code.split('\n'):
for line in code.split("\n"):
if self.is_import_statement(line.strip()):
self.imports.add(line.strip())

@@ -37,20 +38,21 @@ class StateManager:
"""
# Filter out modules, functions, and special variables
state_dict = {
'variables': {
k: v for k, v in locals_dict.items()
"variables": {
k: v
for k, v in locals_dict.items()
if not (
k.startswith('_')
k.startswith("_")
or callable(v)
or isinstance(v, type)
or isinstance(v, types.ModuleType)
)
},
'imports': list(self.imports),
'source': executor
"imports": list(self.imports),
"source": executor,
}
with open(self.state_file, 'wb') as f:
with open(self.state_file, "wb") as f:
pickle.dump(state_dict, f)
def load_state(self, executor: str) -> Dict[str, Any]:

@@ -66,14 +68,14 @@ class StateManager:
if not self.state_file.exists():
return {}
with open(self.state_file, 'rb') as f:
with open(self.state_file, "rb") as f:
state_dict = pickle.load(f)
# First handle imports
for import_stmt in state_dict['imports']:
for import_stmt in state_dict["imports"]:
exec(import_stmt, globals())
return state_dict['variables']
return state_dict["variables"]
def read_multiplexed_response(socket):

@@ -84,7 +86,7 @@ def read_multiplexed_response(socket):
while True and i < 1000:
# Stream output from socket
response_data = socket.recv(4096)
responses = response_data.split(b'\x01\x00\x00\x00\x00\x00')
responses = response_data.split(b"\x01\x00\x00\x00\x00\x00")
# The last non-empty chunk should be our JSON response
if len(responses) > 0:

@@ -92,9 +94,9 @@ def read_multiplexed_response(socket):
if chunk and len(chunk.strip()) > 0:
try:
# Find the start of valid JSON by looking for '{'
json_start = chunk.find(b'{')
json_start = chunk.find(b"{")
if json_start != -1:
decoded = chunk[json_start:].decode('utf-8')
decoded = chunk[json_start:].decode("utf-8")
result = json.loads(decoded)
if "output" in result:
return decoded

@@ -113,7 +115,6 @@ class DockerPythonInterpreter:
self.socket = None
self.state_manager = StateManager(work_dir)
def create_interpreter_script(self) -> str:
"""Create the interpreter script that will run inside the container"""
script = """

@@ -230,9 +231,7 @@ if __name__ == '__main__':
self.create_interpreter_script()
# Setup volume mapping
volumes = {
str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}
}
volumes = {str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}}
for container in self.client.containers.list(all=True):
if container_name == container.name:

@@ -250,22 +249,20 @@ if __name__ == '__main__':
tty=True,
stdin_open=True,
working_dir="/workspace",
volumes=volumes
volumes=volumes,
)
# Install packages in the new container
print("Installing packages...")
packages = ["pandas", "numpy", "pickle5"] # Add your required packages here
result = self.container.exec_run(
f"pip install {' '.join(packages)}",
workdir="/workspace"
f"pip install {' '.join(packages)}", workdir="/workspace"
)
if result.exit_code != 0:
print(f"Warning: Failed to install: {result.output.decode()}")
else:
print(f"Installed {packages}.")
if not self.wait_for_ready(self.container):
raise Exception("Failed to start container")

@@ -276,14 +273,12 @@ if __name__ == '__main__':
stdin=True,
stdout=True,
stderr=True,
tty=True
tty=True,
)
# Connect to the exec instance
self.socket = self.client.api.exec_start(
self.exec_id['Id'],
socket=True,
demux=True
self.exec_id["Id"], socket=True, demux=True
)._sock
def _raw_execute(self, code: str) -> Tuple[str, bool]:

@@ -296,14 +291,14 @@ if __name__ == '__main__':
if not self.socket:
raise Exception("Socket not started")
command = json.dumps({'code': code}) + '\n'
command = json.dumps({"code": code}) + "\n"
self.socket.send(command.encode())
response = read_multiplexed_response(self.socket)
try:
result = json.loads(response)
return result['output'], result['more']
return result["output"], result["more"]
except json.JSONDecodeError:
return f"Error: Invalid response from interpreter: {response}", False

@@ -311,7 +306,7 @@ if __name__ == '__main__':
"""Get the current locals dictionary from the interpreter by pickling directly from Docker."""
pickle_path = self.work_dir / "locals.pickle"
if pickle_path.exists():
with open(pickle_path, 'rb') as f:
with open(pickle_path, "rb") as f:
try:
return pickle.load(f)
except Exception as e:

@@ -326,10 +321,7 @@ if __name__ == '__main__':
output, more = self._raw_execute(code)
# Save state after execution
self.state_manager.save_state(
self.get_locals_dict(),
'docker'
)
self.state_manager.save_state(self.get_locals_dict(), "docker")
return output, more
def stop(self, remove: bool = False):

@@ -349,6 +341,7 @@ if __name__ == '__main__':
print(f"Error stopping container: {e}")
raise
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES

@@ -359,7 +352,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
state_manager.track_imports(code)
# Load state from Docker if available
locals_dict = state_manager.load_state('local')
locals_dict = state_manager.load_state("local")
# Execute in a new namespace with loaded state
namespace = {}

@@ -374,7 +367,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
)
# Save state for Docker
state_manager.save_state(namespace, 'local')
state_manager.save_state(namespace, "local")
return output

@@ -382,14 +375,17 @@ def create_tools_regex(tool_names):
# Escape any special regex characters in tool names
escaped_names = [re.escape(name) for name in tool_names]
# Join with | and add word boundaries
pattern = r'\b(' + '|'.join(escaped_names) + r')\b'
pattern = r"\b(" + "|".join(escaped_names) + r")\b"
return re.compile(pattern)
def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
"""Execute code with automatic switching between Docker and local."""
lines = code.split('\n')
lines = code.split("\n")
current_block = []
tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing
tool_regex = create_tools_regex(
list(tools.keys()) + ["print"]
) # Added print for testing
tools = {
**BASE_PYTHON_TOOLS.copy(),

@@ -400,20 +396,20 @@ def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
if tool_regex.search(line):
# Execute accumulated Docker code if any
if current_block:
output, more = interpreter.execute('\n'.join(current_block))
print(output, end='')
output, more = interpreter.execute("\n".join(current_block))
print(output, end="")
current_block = []
output = execute_locally(line, work_dir, tools)
if output:
print(output, end='')
print(output, end="")
else:
current_block.append(line)
# Execute any remaining Docker code
if current_block:
output, more = interpreter.execute('\n'.join(current_block))
print(output, end='')
output, more = interpreter.execute("\n".join(current_block))
print(output, end="")
__all__ = ["DockerPythonInterpreter", "execute_code"]

@@ -111,4 +111,5 @@ class GradioUI:
demo.launch()
__all__ = ["stream_to_gradio", "GradioUI"]

@@ -37,6 +37,7 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
}
class MessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"

@@ -48,6 +49,7 @@ class MessageRole(str, Enum):
def roles(cls):
return [r.value for r in cls]
openai_role_conversions = {
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}

@@ -56,6 +58,7 @@ llama_role_conversions = {
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}
def get_clean_message_list(
message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
):

@@ -118,7 +121,7 @@ class HfEngine:
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500
max_tokens: int = 1500,
):
raise NotImplementedError

@@ -276,7 +279,12 @@ class TransformersEngine(HfEngine):
class OpenAIEngine:
def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None):
def __init__(
self,
model_name: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
"""Creates a LLM Engine that follows OpenAI format.
Args:

@@ -301,7 +309,9 @@ class OpenAIEngine:
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions
)
response = self.client.chat.completions.create(
model=self.model_name,

@@ -337,7 +347,9 @@ class AnthropicEngine:
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
messages = get_clean_message_list(
messages, role_conversions=openai_role_conversions
)
index_system_message, system_prompt = None, None
for index, message in enumerate(messages):
if message["role"] == MessageRole.SYSTEM:

@@ -346,7 +358,9 @@ class AnthropicEngine:
if system_prompt is None:
raise Exception("No system prompt found!")
filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message]
filtered_messages = [
message for i, message in enumerate(messages) if i != index_system_message
]
if len(filtered_messages) == 0:
print("Error, no user message:", messages)
assert False

@@ -366,4 +380,13 @@ class AnthropicEngine:
return full_response_text
__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"]
__all__ = [
"MessageRole",
"llama_role_conversions",
"get_clean_message_list",
"HfEngine",
"TransformersEngine",
"HfApiEngine",
"OpenAIEngine",
"AnthropicEngine",
]

@@ -1000,4 +1000,5 @@ def evaluate_python_code(
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg)
__all__ = ["evaluate_python_code"]

@@ -44,4 +44,5 @@ class Monitor:
console.print(f"- Input tokens: {self.total_input_token_count:,}")
console.print(f"- Output tokens: {self.total_output_token_count:,}")
__all__ = ["Monitor"]

@@ -491,4 +491,10 @@ Here is my new/updated plan of action to solve the task:
{plan_update}
```"""
__all__ = ["USER_PROMPT_PLAN_UPDATE", "PLAN_UPDATE_FINAL_PLAN_REDACTION", "ONESHOT_CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT", "JSON_SYSTEM_PROMPT"]
__all__ = [
"USER_PROMPT_PLAN_UPDATE",
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
"ONESHOT_CODE_SYSTEM_PROMPT",
"CODE_SYSTEM_PROMPT",
"JSON_SYSTEM_PROMPT",
]

@@ -78,4 +78,5 @@ class VisitWebpageTool(Tool):
except Exception as e:
return f"An unexpected error occurred: {str(e)}"
__all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"]

@@ -39,12 +39,6 @@ from huggingface_hub import (
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version
from transformers.dynamic_module_utils import (
custom_object_save,
get_class_from_dynamic_module,
get_imports,
)
from transformers import AutoProcessor
from transformers.utils import (
TypeHintParsingException,
cached_file,

@@ -62,11 +56,10 @@ logger = logging.getLogger(__name__)
if is_torch_available():
import torch
pass
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import send_to_device
pass
TOOL_CONFIG_FILE = "tool_config.json"

@@ -123,6 +116,7 @@ def validate_after_init(cls, do_validate_forward: bool = True):
cls.__init__ = new_init
return cls
def validate_args_are_self_contained(source_code):
"""Validates that all names in forward method are properly defined.
In particular it will check that all imports are done within the function."""

@@ -150,7 +144,7 @@ def validate_args_are_self_contained(source_code):
def visit_ImportFrom(self, node):
"""Handle from imports like 'from datetime import datetime'."""
module = node.module or ''
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
self.from_imports[actual_name] = (module, name.name, actual_name)

@@ -187,9 +181,11 @@ def validate_args_are_self_contained(source_code):
self.assigned_names.update(target_names)
# Special handling for enumerate
if (isinstance(node.iter, ast.Call) and
isinstance(node.iter.func, ast.Name) and
node.iter.func.id == 'enumerate'):
if (
isinstance(node.iter, ast.Call)
and isinstance(node.iter.func, ast.Name)
and node.iter.func.id == "enumerate"
):
# For enumerate, if we have "for i, x in enumerate(...)",
# both i and x should be marked as assigned
if isinstance(node.target, ast.Tuple):

@@ -201,19 +197,19 @@ def validate_args_are_self_contained(source_code):
self.generic_visit(node)
def visit_Name(self, node):
if (isinstance(node.ctx, ast.Load) and not (
node.id == "tool" or
node.id in builtin_names or
node.id in arg_names or
node.id == 'self' or
node.id in self.assigned_names
)):
if isinstance(node.ctx, ast.Load) and not (
node.id == "tool"
or node.id in builtin_names
or node.id in arg_names
or node.id == "self"
or node.id in self.assigned_names
):
if node.id not in self.from_imports and node.id not in self.imports:
self.undefined_names.add(node.id)
def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)
checker = NameChecker()

@@ -226,6 +222,7 @@ def validate_args_are_self_contained(source_code):
"""
)
AUTHORIZED_TYPES = [
"string",
"boolean",

@@ -273,7 +270,6 @@ class Tool:
super().__init_subclass__(**kwargs)
validate_after_init(cls, do_validate_forward=False)
def validate_arguments(self, do_validate_forward: bool = True):
required_attributes = {
"description": str,

@@ -359,13 +355,13 @@ class {class_name}(Tool):
def add_self_argument(source_code: str) -> str:
"""Add 'self' as first argument to a function definition if not present."""
pattern = r'def forward\(((?!self)[^)]*)\)'
pattern = r"def forward\(((?!self)[^)]*)\)"
def replacement(match):
args = match.group(1).strip()
if args: # If there are other arguments
return f'def forward(self, {args})'
return 'def forward(self)'
return f"def forward(self, {args})"
return "def forward(self)"
return re.sub(pattern, replacement, source_code)

@@ -391,11 +387,7 @@ class {class_name}(Tool):
# Save app file
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(
APP_FILE_TEMPLATE.format(
class_name=class_name
)
)
f.write(APP_FILE_TEMPLATE.format(class_name=class_name))
# Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt")

@@ -457,7 +449,7 @@ class {class_name}(Tool):
self.save(work_dir)
print(work_dir)
with open(work_dir + "/tool.py", "r") as f:
print('\n'.join(f.readlines()))
print("\n".join(f.readlines()))
logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
)

@@ -575,7 +567,6 @@ class {class_name}(Tool):
return tool_class(**kwargs)
@staticmethod
def from_space(
space_id: str,

@@ -702,7 +693,11 @@ class {class_name}(Tool):
return output
return SpaceToolWrapper(
space_id=space_id, name=name, description=description, api_name=api_name, token=token
space_id=space_id,
name=name,
description=description,
api_name=api_name,
token=token,
)
@staticmethod

@@ -859,7 +854,7 @@ def load_tool(
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs
**kwargs,
):
"""
Main function to quickly load a tool, be it on the Hub or in the Transformers library.

@@ -909,7 +904,11 @@ def load_tool(
f"code that you have checked."
)
return Tool.from_hub(
task_or_repo_id, model_repo_id=model_repo_id, token=token, trust_remote_code=trust_remote_code, **kwargs
task_or_repo_id,
model_repo_id=model_repo_id,
token=token,
trust_remote_code=trust_remote_code,
**kwargs,
)

@@ -1028,7 +1027,7 @@ def tool(tool_function: Callable) -> Tool:
raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
class_name = ''.join([el.title() for el in parameters['name'].split('_')])
class_name = "".join([el.title() for el in parameters["name"].split("_")])
if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any"

@@ -1086,7 +1085,9 @@ class Toolbox:
"""Get all tools currently in the toolbox"""
return self._tools
def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
def show_tool_descriptions(
self, tool_description_template: Optional[str] = None
) -> str:
"""
Returns the description of all tools in the toolbox

@@ -1151,4 +1152,12 @@ class Toolbox:
toolbox_description += f"\t{tool.name}: {tool.description}\n"
return toolbox_description
__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
__all__ = [
"AUTHORIZED_TYPES",
"Tool",
"tool",
"load_tool",
"launch_gradio_demo",
"Toolbox",
]

@@ -267,4 +267,5 @@ def handle_agent_outputs(output, output_type=None):
return _v(output)
return output
__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]

@@ -18,6 +18,7 @@ import json
import re
from typing import Tuple, Dict, Union
import ast
from rich.console import Console
from transformers.utils.import_utils import _is_package_available

@@ -28,8 +29,6 @@ def is_pygments_available():
return _pygments_available
from rich.console import Console
console = Console()

@@ -111,6 +110,7 @@ def truncate_content(
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
)
class ImportFinder(ast.NodeVisitor):
def __init__(self):
self.packages = set()

@@ -118,13 +118,14 @@ class ImportFinder(ast.NodeVisitor):
def visit_Import(self, node):
for alias in node.names:
# Get the base package name (before any dots)
base_package = alias.name.split('.')[0]
base_package = alias.name.split(".")[0]
self.packages.add(base_package)
def visit_ImportFrom(self, node):
if node.module: # for "from x import y" statements
# Get the base package name (before any dots)
base_package = node.module.split('.')[0]
base_package = node.module.split(".")[0]
self.packages.add(base_package)
__all__ = []

@@ -27,12 +27,13 @@ from agents.agents import (
CodeAgent,
JsonAgent,
Toolbox,
ToolCall
ToolCall,
)
from agents.tools import tool
from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()
return os.path.join(directory, str(uuid.uuid4()) + suffix)

@@ -60,6 +61,7 @@ Action:
}
"""
def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)

@@ -82,6 +84,7 @@ Action:
}
"""
def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:

@@ -179,9 +182,7 @@ class AgentTests(unittest.TestCase):
assert output == "7.2904"
def test_fake_json_agent(self):
agent = JsonAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_json_llm
)
agent = JsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"

@@ -209,9 +210,7 @@ Action:
Args:
prompt: The prompt
"""
return Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
)
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
agent = JsonAgent(
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image

@@ -221,9 +220,7 @@ Action:
assert isinstance(agent.state["image.png"], Image.Image)
def test_fake_code_agent(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
)
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float)
assert output == 7.2904

@@ -234,9 +231,7 @@ Action:
)
def test_reset_conversations(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm
)
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
assert output == 7.2904
assert len(agent.logs) == 4

@@ -299,9 +294,7 @@ Action:
# check that python_interpreter base tool does not get added to code agents
agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True)
assert (
len(agent.toolbox.tools) == 2
) # added final_answer tool + search
assert len(agent.toolbox.tools) == 2 # added final_answer tool + search
def test_function_persistence_across_steps(self):
agent = CodeAgent(

@@ -25,6 +25,7 @@ from pathlib import Path
from typing import List
from dotenv import load_dotenv
class SubprocessCallException(Exception):
pass

@@ -59,7 +60,7 @@ class DocCodeExtractor:
@staticmethod
def extract_python_code(content: str) -> List[str]:
"""Extract Python code blocks from markdown content."""
pattern = r'```(?:python|py)\n(.*?)\n```'
pattern = r"```(?:python|py)\n(.*?)\n```"
matches = re.finditer(pattern, content, re.DOTALL)
return [match.group(1).strip() for match in matches]

@@ -118,18 +119,27 @@ class TestDocs:
# Create and execute test script
try:
excluded_snippets = ["ToolCollection", "image_generation_tool", "from_langchain"]
excluded_snippets = [
"ToolCollection",
"image_generation_tool",
"from_langchain",
]
code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token) for block in code_blocks
if not any([snippet in block for snippet in excluded_snippets]) # Exclude these tools that take longer to run and add dependencies
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token)
for block in code_blocks
if not any(
[snippet in block for snippet in excluded_snippets]
) # Exclude these tools that take longer to run and add dependencies
]
test_script = self.extractor.create_test_script(code_blocks, self._tmpdir)
run_command(self.launch_args + [str(test_script)])
except SubprocessCallException as e:
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
except Exception as e:
pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
except Exception:
pytest.fail(
f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
)
@pytest.fixture(autouse=True)
def _setup(self):

@@ -152,7 +162,5 @@ def pytest_generate_tests(metafunc):
# Parameterize with the markdown files
metafunc.parametrize(
"doc_path",
test_class.md_files,
ids=[f.stem for f in test_class.md_files]
"doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]
)

@@ -20,9 +20,10 @@ import numpy as np
from PIL import Image
from transformers import is_torch_available
from agents.types import AGENT_TYPE_MAPPING
from agents.default_tools import FinalAnswerTool
from transformers.testing_utils import get_tests_dir, require_torch
from agents.types import AGENT_TYPE_MAPPING
from agents.default_tools import FinalAnswerTool
from .test_tools_common import ToolTesterMixin

@@ -98,6 +98,7 @@ class ToolTesterMixin:
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type))
class ToolTests(unittest.TestCase):
def test_tool_init_with_decorator(self):
@tool

@@ -163,19 +164,22 @@ class ToolTests(unittest.TestCase):
assert coolfunc.output_type == "number"
assert "docstring has no description for the argument" in str(e)
def test_tool_definition_needs_imports_in_function(self):
def test_tool_definition_raises_error_imports_outside_function(self):
with pytest.raises(Exception) as e:
from datetime import datetime
@tool
def get_current_time() -> str:
"""
Gets the current time.
"""
return str(datetime.now())
assert "datetime" in str(e)
# Also test with classic definition
with pytest.raises(Exception) as e:
class GetCurrentTimeTool(Tool):
name = "get_current_time_tool"
description = "Gets the current time"

@@ -184,14 +188,17 @@ class ToolTests(unittest.TestCase):
def forward(self):
return str(datetime.now())
assert "datetime" in str(e)
def test_tool_definition_raises_no_error_imports_in_function(self):
@tool
def get_current_time() -> str:
"""
Gets the current time.
"""
from datetime import datetime
return str(datetime.now())
class GetCurrentTimeTool(Tool):

@@ -5,6 +5,7 @@ import tempfile
from pathlib import Path
def str_to_bool(value) -> int:
"""
Converts a string representation of truth to `True` (1) or `False` (0).

@@ -28,10 +29,13 @@ def get_int_from_env(env_keys, default):
return val
return default
def parse_flag_from_env(key, default=False):
"""Returns truthy value for `key` from the env if available else the default."""
value = os.environ.get(key, str(default))
return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int...
return (
str_to_bool(value) == 1
) # As its name indicates `str_to_bool` actually returns an int...
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)