Ruff formatting

parent 851e177e71
commit 67deb6808f
@@ -24,10 +24,25 @@ from transformers.utils import (
 _import_structure = {
-    "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"],
+    "agents": [
+        "Agent",
+        "CodeAgent",
+        "ManagedAgent",
+        "ReactAgent",
+        "CodeAgent",
+        "JsonAgent",
+        "Toolbox",
+    ],
     "llm_engine": ["HfApiEngine", "TransformersEngine"],
     "monitoring": ["stream_to_gradio"],
-    "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
+    "tools": [
+        "PipelineTool",
+        "Tool",
+        "ToolCollection",
+        "launch_gradio_demo",
+        "load_tool",
+        "tool",
+    ],
 }

 try:
@@ -45,10 +60,25 @@ else:
     _import_structure["translation"] = ["TranslationTool"]

 if TYPE_CHECKING:
-    from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, CodeAgent, JsonAgent, Toolbox
+    from .agents import (
+        Agent,
+        CodeAgent,
+        ManagedAgent,
+        ReactAgent,
+        CodeAgent,
+        JsonAgent,
+        Toolbox,
+    )
     from .llm_engine import HfApiEngine, TransformersEngine
     from .monitoring import stream_to_gradio
-    from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
+    from .tools import (
+        PipelineTool,
+        Tool,
+        ToolCollection,
+        launch_gradio_demo,
+        load_tool,
+        tool,
+    )

 try:
     if not is_torch_available():
@@ -66,4 +96,6 @@ if TYPE_CHECKING:
 else:
     import sys

-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
+    sys.modules[__name__] = _LazyModule(
+        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
+    )
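Aside: the `_LazyModule` registration reformatted above is what makes `_import_structure` work; submodules are only imported when one of their names is first accessed. Below is a minimal self-contained sketch of the idea, standard library only (the real `_LazyModule` in transformers handles more cases):

import importlib
import types


class LazyModule(types.ModuleType):
    """Defers submodule imports until an attribute is first accessed (sketch)."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(f"{self.__name__}.{self._attr_to_module[attr]}")
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so __getattr__ is not hit again
        return value


# Demo with a real stdlib package: os.path is only imported on first access.
lazy_os = LazyModule("os", {"path": ["join"]})
print(lazy_os.join("a", "b"))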
@@ -19,7 +19,11 @@ import uuid

 import numpy as np

-from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
+from transformers.utils import (
+    is_soundfile_availble,
+    is_torch_available,
+    is_vision_available,
+)
 import logging


@@ -108,7 +112,9 @@ class AgentImage(AgentType, ImageType):
         elif isinstance(value, np.ndarray):
             self._tensor = torch.from_numpy(value)
         else:
-            raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
+            raise TypeError(
+                f"Unsupported type for {self.__class__.__name__}: {type(value)}"
+            )

     def _ipython_display_(self, include=None, exclude=None):
         """
@@ -159,7 +165,7 @@ class AgentImage(AgentType, ImageType):

         return self._path

-    def save(self, output_bytes, format : str = None, **params):
+    def save(self, output_bytes, format: str = None, **params):
         """
         Saves the image to a file.
         Args:
@@ -243,7 +249,9 @@ if is_torch_available():

 def handle_agent_inputs(*args, **kwargs):
     args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
-    kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
+    kwargs = {
+        k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
+    }
     return args, kwargs
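For context, the argument unwrapping that `handle_agent_inputs` performs in the hunk above can be seen in isolation with a toy stand-in for `AgentType` (hypothetical class, not the real import):

class AgentType:
    def __init__(self, value):
        self._value = value

    def to_raw(self):
        return self._value


def handle_agent_inputs(*args, **kwargs):
    # Replace AgentType wrappers by their raw payloads before a tool runs.
    args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
    kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
    return args, kwargs


args, kwargs = handle_agent_inputs(AgentType("pixel data"), scale=2)
assert args == ["pixel data"] and kwargs == {"scale": 2}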
agents/agents.py (376 changed lines)
@@ -15,12 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
 from dataclasses import dataclass
 from rich.syntax import Syntax

 from langfuse.decorators import langfuse_context, observe

 from transformers.utils import is_torch_available

 from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
@@ -51,6 +49,7 @@ from .tools import (

 HUGGINGFACE_DEFAULT_TOOLS = {}


 class AgentError(Exception):
     """Base class for other agent-related exceptions"""
@@ -60,7 +59,6 @@ class AgentError(Exception):
         console.print(f"[bold red]{message}[/bold red]")


 class AgentParsingError(AgentError):
     """Exception raised for errors in parsing in the agent"""
@@ -84,12 +82,14 @@ class AgentGenerationError(AgentError):
     pass


 class AgentStep:
     pass


 @dataclass
 class ActionStep(AgentStep):
-    tool_call: str | None = None
+    tool_call: Dict[str, str] | None = None
     start_time: float | None = None
     step_end_time: float | None = None
     iteration: int | None = None
@@ -97,32 +97,43 @@ class ActionStep(AgentStep):
     error: AgentError | None = None
     step_duration: float | None = None
     llm_output: str | None = None
     observation: str | None = None
     agent_memory: List[Dict[str, str]] | None = None
     rationale: str | None = None


 @dataclass
 class PlanningStep(AgentStep):
     plan: str
     facts: str


 @dataclass
 class TaskStep(AgentStep):
     task: str


 @dataclass
 class SystemPromptStep(AgentStep):
     system_prompt: str


-def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
+def format_prompt_with_tools(
+    toolbox: Toolbox, prompt_template: str, tool_description_template: str
+) -> str:
     tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
     prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)

     if "{{tool_names}}" in prompt:
-        prompt = prompt.replace("{{tool_names}}", ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]))
+        prompt = prompt.replace(
+            "{{tool_names}}",
+            ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]),
+        )

     return prompt


-def show_agents_descriptions(managed_agents: list):
+def show_agents_descriptions(managed_agents: Dict):
     managed_agents_descriptions = """
 You can also give requests to team members.
 Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
@@ -133,16 +144,24 @@ Here is a list of the team members that you can call:"""
     return managed_agents_descriptions


-def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
+def format_prompt_with_managed_agents_descriptions(
+    prompt_template, managed_agents=None
+) -> str:
     if managed_agents is not None:
-        return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents))
+        return prompt_template.replace(
+            "<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)
+        )
     else:
         return prompt_template.replace("<<managed_agents_descriptions>>", "")


-def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
+def format_prompt_with_imports(
+    prompt_template: str, authorized_imports: List[str]
+) -> str:
     if "<<authorized_imports>>" not in prompt_template:
-        raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
+        raise AgentError(
+            "Tag '<<authorized_imports>>' should be provided in the prompt."
+        )
     return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))


@@ -150,7 +169,7 @@ class BaseAgent:
     def __init__(
         self,
         tools: Union[List[Tool], Toolbox],
-        llm_engine: Callable = None,
+        llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None,
         system_prompt: Optional[str] = None,
         tool_description_template: Optional[str] = None,
         additional_args: Dict = {},
@@ -159,10 +178,12 @@ class BaseAgent:
         add_base_tools: bool = False,
         verbose: bool = False,
         grammar: Optional[Dict[str, str]] = None,
-        managed_agents: Optional[List] = None,
+        managed_agents: Optional[Dict] = None,
         step_callbacks: Optional[List[Callable]] = None,
         monitor_metrics: bool = True,
     ):
         if llm_engine is None:
             llm_engine = HfApiEngine()
         if system_prompt is None:
             system_prompt = CODE_SYSTEM_PROMPT
         if tool_parser is None:
@@ -171,14 +192,16 @@ class BaseAgent:
         self.llm_engine = llm_engine
         self.system_prompt_template = system_prompt
         self.tool_description_template = (
-            tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+            tool_description_template
+            if tool_description_template
+            else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
         )
         self.additional_args = additional_args
         self.max_iterations = max_iterations
         self.tool_parser = tool_parser
         self.grammar = grammar

-        self.managed_agents = None
+        self.managed_agents = {}
         if managed_agents is not None:
             self.managed_agents = {agent.name: agent for agent in managed_agents}

@@ -186,9 +209,13 @@ class BaseAgent:
             self._toolbox = tools
             if add_base_tools:
                 if not is_torch_available():
-                    raise ImportError("Using the base tools requires torch to be installed.")
+                    raise ImportError(
+                        "Using the base tools requires torch to be installed."
+                    )

-                self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == JsonAgent))
+                self._toolbox.add_base_tools(
+                    add_python_interpreter=(self.__class__ == JsonAgent)
+                )
         else:
             self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
         self._toolbox.add_tool(FinalAnswerTool())
@@ -196,7 +223,9 @@ class BaseAgent:
         self.system_prompt = format_prompt_with_tools(
             self._toolbox, self.system_prompt_template, self.tool_description_template
         )
-        self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
+        self.system_prompt = format_prompt_with_managed_agents_descriptions(
+            self.system_prompt, self.managed_agents
+        )
         self.prompt_messages = None
         self.logs = []
         self.task = None
@@ -222,15 +251,20 @@ class BaseAgent:
             self.system_prompt_template,
             self.tool_description_template,
         )
-        self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
+        self.system_prompt = format_prompt_with_managed_agents_descriptions(
+            self.system_prompt, self.managed_agents
+        )
         if hasattr(self, "authorized_imports"):
             self.system_prompt = format_prompt_with_imports(
-                self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
+                self.system_prompt,
+                list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
             )

         return self.system_prompt

-    def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
+    def write_inner_memory_from_logs(
+        self, summary_mode: Optional[bool] = False
+    ) -> List[Dict[str, str]]:
         """
         Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
         that can be used as input to the LLM.
@@ -253,7 +287,10 @@ class BaseAgent:
                 memory.append(thought_message)

                 if not summary_mode:
-                    thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log.plan.strip()}
+                    thought_message = {
+                        "role": MessageRole.ASSISTANT,
+                        "content": "[PLAN]:\n" + step_log.plan.strip(),
+                    }
                     memory.append(thought_message)

             elif isinstance(step_log, TaskStep):
@@ -265,13 +302,17 @@ class BaseAgent:

             elif isinstance(step_log, ActionStep):
                 if step_log.llm_output is not None and not summary_mode:
-                    thought_message = {"role": MessageRole.ASSISTANT, "content": step_log.llm_output.strip()}
+                    thought_message = {
+                        "role": MessageRole.ASSISTANT,
+                        "content": step_log.llm_output.strip(),
+                    }
                     memory.append(thought_message)

                 if step_log.tool_call is not None and summary_mode:
                     tool_call_message = {
                         "role": MessageRole.ASSISTANT,
-                        "content": f"[STEP {i} TOOL CALL]: " + str(step_log.tool_call).strip(),
+                        "content": f"[STEP {i} TOOL CALL]: "
+                        + str(step_log.tool_call).strip(),
                     }
                     memory.append(tool_call_message)

@@ -284,15 +325,21 @@ class BaseAgent:
                 )
                 elif step_log.observation is not None:
                     message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}"
-                    tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
+                    tool_response_message = {
+                        "role": MessageRole.TOOL_RESPONSE,
+                        "content": message_content,
+                    }
                     memory.append(tool_response_message)

         return memory

     def get_succinct_logs(self):
-        return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
+        return [
+            {key: value for key, value in log.items() if key != "agent_memory"}
+            for log in self.logs
+        ]

-    def extract_action(self, llm_output: str, split_token: str) -> str:
+    def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
         """
         Parse action from the LLM output
@@ -312,54 +359,6 @@ class BaseAgent:
         )
         return rationale.strip(), action.strip()

-    def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
-        """
-        Execute tool with the provided input and returns the result.
-        This method replaces arguments with the actual values from the state if they refer to state variables.
-
-        Args:
-            tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
-            arguments (Dict[str, str]): Arguments passed to the Tool.
-        """
-        available_tools = self.toolbox.tools
-        if self.managed_agents is not None:
-            available_tools = {**available_tools, **self.managed_agents}
-        if tool_name not in available_tools:
-            error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
-            console.print(f"[bold red]{error_msg}")
-            raise AgentExecutionError(error_msg)
-
-        try:
-            if isinstance(arguments, str):
-                observation = available_tools[tool_name](arguments)
-            elif isinstance(arguments, dict):
-                for key, value in arguments.items():
-                    if isinstance(value, str) and value in self.state:
-                        arguments[key] = self.state[value]
-                observation = available_tools[tool_name](**arguments)
-            else:
-                error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
-                console.print(f"[bold red]{error_msg}")
-                raise AgentExecutionError(error_msg)
-            return observation
-        except Exception as e:
-            if tool_name in self.toolbox.tools:
-                tool_description = get_tool_description_with_args(available_tools[tool_name])
-                error_msg = (
-                    f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
-                    f"As a reminder, this tool's description is the following:\n{tool_description}"
-                )
-                console.print(f"[bold red]{error_msg}")
-                raise AgentExecutionError(error_msg)
-            elif tool_name in self.managed_agents:
-                error_msg = (
-                    f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
-                    f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
-                )
-                console.print(f"[bold red]{error_msg}")
-                raise AgentExecutionError(error_msg)
-
     def run(self, **kwargs):
         """To be implemented in the child class"""
         raise NotImplementedError
@@ -382,8 +381,6 @@ class ReactAgent(BaseAgent):
         planning_interval: Optional[int] = None,
         **kwargs,
     ):
-        if llm_engine is None:
-            llm_engine = HfApiEngine()
         if system_prompt is None:
             system_prompt = CODE_SYSTEM_PROMPT
         if tool_description_template is None:
@@ -423,8 +420,67 @@ class ReactAgent(BaseAgent):
             console.print(f"[bold red]{error_msg}[/bold red]")
             return error_msg

-    @observe
-    def run(self, task: str, stream: bool = False, reset: bool = True, oneshot: bool = False, **kwargs):
+    def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
+        """
+        Execute tool with the provided input and returns the result.
+        This method replaces arguments with the actual values from the state if they refer to state variables.
+
+        Args:
+            tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
+            arguments (Dict[str, str]): Arguments passed to the Tool.
+        """
+        available_tools = self.toolbox.tools
+        if self.managed_agents is not None:
+            available_tools = {**available_tools, **self.managed_agents}
+        if tool_name not in available_tools:
+            error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
+            console.print(f"[bold red]{error_msg}")
+            raise AgentExecutionError(error_msg)
+
+        try:
+            if isinstance(arguments, str):
+                observation = available_tools[tool_name](arguments)
+            elif isinstance(arguments, dict):
+                for key, value in arguments.items():
+                    if isinstance(value, str) and value in self.state:
+                        arguments[key] = self.state[value]
+                observation = available_tools[tool_name](**arguments)
+            else:
+                error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
+                console.print(f"[bold red]{error_msg}")
+                raise AgentExecutionError(error_msg)
+            return observation
+        except Exception as e:
+            if tool_name in self.toolbox.tools:
+                tool_description = get_tool_description_with_args(
+                    available_tools[tool_name]
+                )
+                error_msg = (
+                    f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
+                    f"As a reminder, this tool's description is the following:\n{tool_description}"
+                )
+                console.print(f"[bold red]{error_msg}")
+                raise AgentExecutionError(error_msg)
+            elif tool_name in self.managed_agents:
+                error_msg = (
+                    f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
+                    f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
+                )
+                console.print(f"[bold red]{error_msg}")
+                raise AgentExecutionError(error_msg)
+
+    def step(self, log_entry: ActionStep):
+        """To be implemented in children classes"""
+        pass
+
+    def run(
+        self,
+        task: str,
+        stream: bool = False,
+        reset: bool = True,
+        oneshot: bool = False,
+        **kwargs,
+    ):
         """
         Runs the agent for the given task.
@@ -441,10 +497,11 @@ class ReactAgent(BaseAgent):
         agent.run("What is the result of 2 power 3.7384?")
         ```
         """
-        print("LANGFUSE REF:", langfuse_context.get_current_trace_url())
         self.task = task
         if len(kwargs) > 0:
-            self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
+            self.task += (
+                f"\nYou have been provided with these initial arguments: {str(kwargs)}."
+            )
         self.state = kwargs.copy()

         self.initialize_system_prompt()
@@ -460,7 +517,7 @@ class ReactAgent(BaseAgent):
         else:
             self.logs.append(system_prompt_step)

-        console.rule("[bold]New task", characters='=')
+        console.rule("[bold]New task", characters="=")
         console.print(self.task)
         self.logs.append(TaskStep(task=task))
@@ -489,8 +546,13 @@ class ReactAgent(BaseAgent):
             step_start_time = time.time()
             step_log = ActionStep(iteration=iteration, start_time=step_start_time)
             try:
-                if self.planning_interval is not None and iteration % self.planning_interval == 0:
-                    self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
+                if (
+                    self.planning_interval is not None
+                    and iteration % self.planning_interval == 0
+                ):
+                    self.planning_step(
+                        task, is_first_step=(iteration == 0), iteration=iteration
+                    )
                 console.rule("[bold]New step")
                 self.step(step_log)
                 if step_log.final_answer is not None:
@@ -530,8 +592,13 @@ class ReactAgent(BaseAgent):
             step_start_time = time.time()
             step_log = ActionStep(iteration=iteration, start_time=step_start_time)
             try:
-                if self.planning_interval is not None and iteration % self.planning_interval == 0:
-                    self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
+                if (
+                    self.planning_interval is not None
+                    and iteration % self.planning_interval == 0
+                ):
+                    self.planning_step(
+                        task, is_first_step=(iteration == 0), iteration=iteration
+                    )
                 console.rule("[bold]New step")
                 self.step(step_log)
                 if step_log.final_answer is not None:
@@ -559,7 +626,7 @@ class ReactAgent(BaseAgent):

         return final_answer

-    def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
+    def planning_step(self, task, is_first_step: bool, iteration: int):
         """
         Used periodically by the agent to plan the next steps to reach the objective.
@@ -569,7 +636,10 @@ class ReactAgent(BaseAgent):
             iteration (`int`): The number of the current step, used as an indication for the LLM.
         """
         if is_first_step:
-            message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
+            message_prompt_facts = {
+                "role": MessageRole.SYSTEM,
+                "content": SYSTEM_PROMPT_FACTS,
+            }
             message_prompt_task = {
                 "role": MessageRole.USER,
                 "content": f"""Here is the task:
@@ -589,15 +659,20 @@ Now begin!""",
                 "role": MessageRole.USER,
                 "content": USER_PROMPT_PLAN.format(
                     task=task,
-                    tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
+                    tool_descriptions=self._toolbox.show_tool_descriptions(
+                        self.tool_description_template
+                    ),
                     managed_agents_descriptions=(
-                        show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
+                        show_agents_descriptions(self.managed_agents)
+                        if self.managed_agents is not None
+                        else ""
                     ),
                     answer_facts=answer_facts,
                 ),
             }
             answer_plan = self.llm_engine(
-                [message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
+                [message_system_prompt_plan, message_user_prompt_plan],
+                stop_sequences=["<end_plan>"],
             )

             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
@@ -608,10 +683,12 @@ Now begin!""",
             ```
             {answer_facts}
             ```""".strip()
-            self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
+            self.logs.append(
+                PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
+            )
             console.rule("[orange]Initial plan")
             console.print(final_plan_redaction)
-        else: # update plan
-            agent_memory = self.write_inner_memory_from_logs(summary_mode=False)  # This will not log the plan but will log facts
+        else:  # update plan
+            agent_memory = self.write_inner_memory_from_logs(
+                summary_mode=False
+            )  # This will not log the plan but will log facts
@@ -625,7 +702,9 @@ Now begin!""",
                 "role": MessageRole.USER,
                 "content": USER_PROMPT_FACTS_UPDATE,
             }
-            facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
+            facts_update = self.llm_engine(
+                [facts_update_system_prompt] + agent_memory + [facts_update_message]
+            )

             # Redact updated plan
             plan_update_message = {
@@ -636,25 +715,34 @@ Now begin!""",
                 "role": MessageRole.USER,
                 "content": USER_PROMPT_PLAN_UPDATE.format(
                     task=task,
-                    tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
+                    tool_descriptions=self._toolbox.show_tool_descriptions(
+                        self.tool_description_template
+                    ),
                     managed_agents_descriptions=(
-                        show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
+                        show_agents_descriptions(self.managed_agents)
+                        if self.managed_agents is not None
+                        else ""
                     ),
                     facts_update=facts_update,
                     remaining_steps=(self.max_iterations - iteration),
                 ),
             }
             plan_update = self.llm_engine(
-                [plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
+                [plan_update_message] + agent_memory + [plan_update_message_user],
+                stop_sequences=["<end_plan>"],
             )

             # Log final facts and plan
-            final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
+            final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
+                task=task, plan_update=plan_update
+            )
             final_facts_redaction = f"""Here is the updated list of the facts that I know:
             ```
             {facts_update}
             ```"""
-            self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
+            self.logs.append(
+                PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
+            )
             console.rule("[orange]Updated plan")
             console.print(final_plan_redaction)
@@ -705,14 +793,20 @@ class JsonAgent(ReactAgent):
         log_entry.agent_memory = agent_memory.copy()

         if self.verbose:
-            console.rule("[italic]Calling LLM engine with this last message:", align="left")
+            console.rule(
+                "[italic]Calling LLM engine with this last message:", align="left"
+            )
             console.print(self.prompt_messages[-1])
             console.rule()

         try:
-            additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+            additional_args = (
+                {"grammar": self.grammar} if self.grammar is not None else {}
+            )
             llm_output = self.llm_engine(
-                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages,
+                stop_sequences=["<end_action>", "Observation:"],
+                **additional_args,
             )
             log_entry.llm_output = llm_output
         except Exception as e:
@@ -721,9 +815,11 @@ class JsonAgent(ReactAgent):
         if self.verbose:
             console.rule("[italic]Output message of the LLM:")
             console.print(llm_output)

         # Parse
-        rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
+        rationale, action = self.extract_action(
+            llm_output=llm_output, split_token="Action:"
+        )

         try:
             tool_name, arguments = self.tool_parser(action)
@@ -807,12 +903,18 @@ class CodeAgent(ReactAgent):
         )

         self.python_evaluator = evaluate_python_code
-        self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
-        self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
-        self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
+        self.additional_authorized_imports = (
+            additional_authorized_imports if additional_authorized_imports else []
+        )
+        self.authorized_imports = list(
+            set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
+        )
+        self.system_prompt = self.system_prompt.replace(
+            "<<authorized_imports>>", str(self.authorized_imports)
+        )
         self.custom_tools = {}

-    def step(self, log_entry: Dict[str, Any]):
+    def step(self, log_entry: ActionStep):
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         The errors are raised here, they are caught and logged in the run() method.
@@ -825,14 +927,20 @@ class CodeAgent(ReactAgent):
         log_entry.agent_memory = agent_memory.copy()

         if self.verbose:
-            console.rule("[italic]Calling LLM engine with these last messages:", align="left")
+            console.rule(
+                "[italic]Calling LLM engine with these last messages:", align="left"
+            )
             console.print(self.prompt_messages[-2:])
             console.rule()

         try:
-            additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+            additional_args = (
+                {"grammar": self.grammar} if self.grammar is not None else {}
+            )
             llm_output = self.llm_engine(
-                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages,
+                stop_sequences=["<end_action>", "Observation:"],
+                **additional_args,
             )
             log_entry.llm_output = llm_output
         except Exception as e:
@@ -840,13 +948,19 @@ class CodeAgent(ReactAgent):

         if self.verbose:
             console.rule("[italic]Output message of the LLM:")
-            console.print(Syntax(llm_output, lexer='markdown', background_color='default'))
+            console.print(
+                Syntax(llm_output, lexer="markdown", background_color="default")
+            )

         # Parse
         try:
-            rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
+            rationale, raw_code_action = self.extract_action(
+                llm_output=llm_output, split_token="Code:"
+            )
         except Exception as e:
-            console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
+            console.print(
+                f"Error in extracting action, trying to parse the whole output. Error trace: {e}"
+            )
             rationale, raw_code_action = llm_output, llm_output

         try:
@@ -856,14 +970,17 @@ class CodeAgent(ReactAgent):
             raise AgentParsingError(error_msg)

         log_entry.rationale = rationale
-        log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action}
+        log_entry.tool_call = {
+            "tool_name": "code interpreter",
+            "tool_arguments": code_action,
+        }

         # Execute
         if self.verbose:
             console.rule("[italic]Agent thoughts")
             console.print(rationale)
             console.rule("[bold]Agent is executing the code below:", align="left")
-            console.print(Syntax(code_action, lexer='python', background_color='default'))
+            console.print(Syntax(code_action, lexer="python", background_color="default"))
             console.rule("", align="left")

         try:
@@ -886,7 +1003,9 @@ class CodeAgent(ReactAgent):
             if result is not None:
                 console.rule("Last output from code snippet:", align="left")
                 console.print(str(result))
-                observation += "Last output from code snippet:\n" + truncate_content(str(result))
+                observation += "Last output from code snippet:\n" + truncate_content(
+                    str(result)
+                )
             log_entry.observation = observation
         except Exception as e:
             error_msg = f"Code execution failed due to the following error:\n{str(e)}"
@@ -902,7 +1021,14 @@ class CodeAgent(ReactAgent):


 class ManagedAgent:
-    def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
+    def __init__(
+        self,
+        agent,
+        name,
+        description,
+        additional_prompting=None,
+        provide_run_summary=False,
+    ):
         self.agent = agent
         self.name = name
         self.description = description
@@ -925,18 +1051,22 @@ Your final_answer WILL HAVE to contain these parts:

 Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
 And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
-<<additional_prompting>>"""
+{{additional_prompting}}"""
         if self.additional_prompting:
-            full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip()
+            full_task = full_task.replace(
+                "\n{{additional_prompting}}", self.additional_prompting
+            ).strip()
         else:
-            full_task = full_task.replace("\n<<additional_prompting>>", "").strip()
+            full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
         return full_task

     def __call__(self, request, **kwargs):
         full_task = self.write_full_task(request)
         output = self.agent.run(full_task, **kwargs)
         if self.provide_run_summary:
-            answer = f"Here is the final answer from your managed agent '{self.name}':\n"
+            answer = (
+                f"Here is the final answer from your managed agent '{self.name}':\n"
+            )
             answer += str(output)
             answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
             for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
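The prompt-assembly helpers reformatted above all work by plain placeholder substitution. A short self-contained sketch with toy values (the real descriptions come from the Toolbox; the names here are illustrative):

prompt_template = (
    "You can use these tools:\n{{tool_descriptions}}\n"
    "Allowed tool names: {{tool_names}}"
)
tool_descriptions = "- search: looks things up\n- calculator: does math"
tool_names = ["search", "calculator"]

prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
prompt = prompt.replace(
    "{{tool_names}}", ", ".join(f"'{name}'" for name in tool_names)
)
print(prompt)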
@ -105,7 +105,9 @@ def get_remote_tools(logger, organization="huggingface-tools"):
|
|||
tools = {}
|
||||
for space_info in spaces:
|
||||
repo_id = space_info.id
|
||||
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
|
||||
resolved_config_file = hf_hub_download(
|
||||
repo_id, TOOL_CONFIG_FILE, repo_type="space"
|
||||
)
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
config = json.load(reader)
|
||||
task = repo_id.split("/")[-1]
|
||||
|
@ -131,7 +133,9 @@ class PythonInterpreterTool(Tool):
|
|||
if authorized_imports is None:
|
||||
self.authorized_imports = list(set(LIST_SAFE_MODULES))
|
||||
else:
|
||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
|
||||
self.authorized_imports = list(
|
||||
set(LIST_SAFE_MODULES) | set(authorized_imports)
|
||||
)
|
||||
self.inputs = {
|
||||
"code": {
|
||||
"type": "string",
|
||||
|
@ -145,7 +149,11 @@ class PythonInterpreterTool(Tool):
|
|||
|
||||
def forward(self, code):
|
||||
output = str(
|
||||
evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
|
||||
evaluate_python_code(
|
||||
code,
|
||||
static_tools=BASE_PYTHON_TOOLS,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
||||
|
@ -153,16 +161,21 @@ class PythonInterpreterTool(Tool):
|
|||
class FinalAnswerTool(Tool):
|
||||
name = "final_answer"
|
||||
description = "Provides a final answer to the given problem."
|
||||
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
||||
inputs = {
|
||||
"answer": {"type": "any", "description": "The final answer to the problem"}
|
||||
}
|
||||
output_type = "any"
|
||||
|
||||
def forward(self, answer):
|
||||
return answer
|
||||
|
||||
|
||||
class UserInputTool(Tool):
|
||||
name = "user_input"
|
||||
description = "Asks for user's input on a specific question"
|
||||
inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
|
||||
inputs = {
|
||||
"question": {"type": "string", "description": "The question to ask the user"}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, question):
|
||||
|
|
|
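To make the Tool declarations above concrete, here is a self-contained analogue of FinalAnswerTool, written without the real Tool base class (which this diff does not show), so the inputs/output_type contract is visible:

class FinalAnswerTool:
    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
    output_type = "any"

    def forward(self, answer):
        # The tool simply echoes the answer back; the agent loop treats this
        # call as the signal that the run is finished.
        return answer


print(FinalAnswerTool().forward("Paris"))  # -> Paris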
@@ -18,6 +18,7 @@ from .agent_types import AgentAudio, AgentImage, AgentText
 from .agents import BaseAgent, AgentStep, ActionStep
+import gradio as gr


 def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
     """Extract ChatMessage objects from agent steps"""
     if isinstance(step_log, ActionStep):
@@ -33,7 +34,9 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
             content=str(content),
         )
         if step_log.observation is not None:
-            yield gr.ChatMessage(role="assistant", content=f"```\n{step_log.observation}\n```")
+            yield gr.ChatMessage(
+                role="assistant", content=f"```\n{step_log.observation}\n```"
+            )
         if step_log.error is not None:
             yield gr.ChatMessage(
                 role="assistant",
@@ -42,7 +45,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
         )


-def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs):
+def stream_to_gradio(
+    agent,
+    task: str,
+    test_mode: bool = False,
+    reset_agent_memory: bool = False,
+    **kwargs,
+):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""

     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
@@ -52,7 +61,10 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
     final_answer = step_log  # Last log is the run's final_answer

     if isinstance(final_answer, AgentText):
-        yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
+        yield gr.ChatMessage(
+            role="assistant",
+            content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```",
+        )
     elif isinstance(final_answer, AgentImage):
         yield gr.ChatMessage(
             role="assistant",
@@ -67,10 +79,11 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
         yield gr.ChatMessage(role="assistant", content=str(final_answer))


-class GradioUI():
+class GradioUI:
     """A one-line interface to launch your agent in Gradio"""

     def __init__(self, agent: BaseAgent):
         self.agent = agent

     def interact_with_agent(self, prompt, messages):
         messages.append(gr.ChatMessage(role="user", content=prompt))
@@ -83,10 +96,17 @@ class GradioUI():
     def run(self):
         with gr.Blocks() as demo:
             stored_message = gr.State([])
-            chatbot = gr.Chatbot(label="Agent",
-                                 type="messages",
-                                 avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
+            chatbot = gr.Chatbot(
+                label="Agent",
+                type="messages",
+                avatar_images=(
+                    None,
+                    "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
+                ),
+            )
             text_input = gr.Textbox(lines=1, label="Chat Message")
-            text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
+            text_input.submit(
+                lambda s: (s, ""), [text_input], [stored_message, text_input]
+            ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])

         demo.launch()
@@ -39,7 +39,9 @@ class MessageRole(str, Enum):
         return [r.value for r in cls]


-def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
+def get_clean_message_list(
+    message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
+):
     """
     Subsequent messages with the same role will be concatenated to a single message.
@@ -54,12 +56,17 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:

         role = message["role"]
         if role not in MessageRole.roles():
-            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
+            raise ValueError(
+                f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
+            )

         if role in role_conversions:
             message["role"] = role_conversions[role]

-        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
+        if (
+            len(final_message_list) > 0
+            and message["role"] == final_message_list[-1]["role"]
+        ):
             final_message_list[-1]["content"] += "\n=======\n" + message["content"]
         else:
             final_message_list.append(message)
@@ -81,8 +88,12 @@ class HfEngine:
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         except Exception as e:
-            logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
-            self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
+            logger.warning(
+                f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
+            )
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+            )

     def get_token_counts(self):
         return {
@@ -91,12 +102,18 @@ class HfEngine:
         }

     def generate(
-        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
     ):
         raise NotImplementedError

     def __call__(
-        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
     ) -> str:
         """Process the input messages and return the model's response.
@@ -127,11 +144,15 @@ class HfEngine:
         ```
         """
         if not isinstance(messages, List):
-            raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
+            raise ValueError(
+                "Messages should be a list of dictionaries with 'role' and 'content' keys."
+            )
         if stop_sequences is None:
             stop_sequences = []
         response = self.generate(messages, stop_sequences, grammar)
-        self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
+        self.last_input_token_count = len(
+            self.tokenizer.apply_chat_template(messages, tokenize=True)
+        )
         self.last_output_token_count = len(self.tokenizer.encode(response))

         # Remove stop sequences from LLM output
@@ -175,18 +196,28 @@ class HfApiEngine(HfEngine):
         self.max_tokens = max_tokens

     def generate(
-        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
     ) -> str:
         # Get clean message list
-        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )

         # Send messages to the Hugging Face Inference API
         if grammar is not None:
             response = self.client.chat_completion(
-                messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
+                messages,
+                stop=stop_sequences,
+                max_tokens=self.max_tokens,
+                response_format=grammar,
             )
         else:
-            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
+            response = self.client.chat_completion(
+                messages, stop=stop_sequences, max_tokens=self.max_tokens
+            )

         response = response.choices[0].message.content
         return response
@@ -207,7 +238,9 @@ class TransformersEngine(HfEngine):
         max_length: int = 1500,
     ) -> str:
         # Get clean message list
-        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
+        messages = get_clean_message_list(
+            messages, role_conversions=llama_role_conversions
+        )

         # Get LLM output
         if stop_sequences is not None and len(stop_sequences) > 0:
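The role-merging behaviour of get_clean_message_list, reformatted above, is easy to check in isolation. Below is a simplified self-contained sketch (role validation and the MessageRole enum trimmed out of the original):

def get_clean_message_list(message_list, role_conversions={}):
    final_message_list = []
    for message in [dict(m) for m in message_list]:
        if message["role"] in role_conversions:
            message["role"] = role_conversions[message["role"]]
        # Consecutive messages with the same role are concatenated.
        if final_message_list and message["role"] == final_message_list[-1]["role"]:
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list


merged = get_clean_message_list(
    [{"role": "user", "content": "first"}, {"role": "user", "content": "second"}]
)
assert len(merged) == 1 and "=======" in merged[0]["content"]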
@@ -17,12 +17,14 @@
 from .utils import console


 class Monitor:
     def __init__(self, tracked_llm_engine):
         self.step_durations = []
         self.tracked_llm_engine = tracked_llm_engine
-        if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found":
+        if (
+            getattr(self.tracked_llm_engine, "last_input_token_count", "Not found")
+            != "Not found"
+        ):
             self.total_input_token_count = 0
             self.total_output_token_count = 0
@@ -33,103 +35,11 @@ class Monitor:
         console.print(f"- Time taken: {step_duration:.2f} seconds")

         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
-            self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
-            self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
+            self.total_input_token_count += (
+                self.tracked_llm_engine.last_input_token_count
+            )
+            self.total_output_token_count += (
+                self.tracked_llm_engine.last_output_token_count
+            )
             console.print(f"- Input tokens: {self.total_input_token_count:,}")
             console.print(f"- Output tokens: {self.total_output_token_count:,}")
-
-
-from typing import Optional, Union, List, Any
-import httpx
-import logging
-import os
-from langfuse.client import Langfuse, StatefulTraceClient, StatefulSpanClient, StateType
-
-
-class BaseTracker:
-    def __init__(self):
-        pass
-
-    @classmethod
-    def call(cls, *args, **kwargs):
-        pass
-
-class LangfuseTracker(BaseTracker):
-    log = logging.getLogger("langfuse")
-
-    def __init__(self, *, public_key: Optional[str] = None, secret_key: Optional[str] = None,
-                 host: Optional[str] = None, debug: bool = False, stateful_client: Optional[
-                     Union[StatefulTraceClient, StatefulSpanClient]
-                 ] = None, update_stateful_client: bool = False, version: Optional[str] = None,
-                 session_id: Optional[str] = None, user_id: Optional[str] = None, trace_name: Optional[str] = None,
-                 release: Optional[str] = None, metadata: Optional[Any] = None, tags: Optional[List[str]] = None,
-                 threads: Optional[int] = None, flush_at: Optional[int] = None, flush_interval: Optional[int] = None,
-                 max_retries: Optional[int] = None, timeout: Optional[int] = None, enabled: Optional[bool] = None,
-                 httpx_client: Optional[httpx.Client] = None, sdk_integration: str = "default") -> None:
-        super().__init__()
-        self.version = version
-        self.session_id = session_id
-        self.user_id = user_id
-        self.trace_name = trace_name
-        self.release = release
-        self.metadata = metadata
-        self.tags = tags
-
-        self.root_span = None
-        self.update_stateful_client = update_stateful_client
-        self.langfuse = None
-
-        prio_public_key = public_key or os.environ.get("LANGFUSE_PUBLIC_KEY")
-        prio_secret_key = secret_key or os.environ.get("LANGFUSE_SECRET_KEY")
-        prio_host = host or os.environ.get(
-            "LANGFUSE_HOST", "https://cloud.langfuse.com"
-        )
-
-        if stateful_client and isinstance(stateful_client, StatefulTraceClient):
-            self.trace = stateful_client
-            self._task_manager = stateful_client.task_manager
-            return
-
-        elif stateful_client and isinstance(stateful_client, StatefulSpanClient):
-            self.root_span = stateful_client
-            self.trace = StatefulTraceClient(
-                stateful_client.client,
-                stateful_client.trace_id,
-                StateType.TRACE,
-                stateful_client.trace_id,
-                stateful_client.task_manager,
-            )
-            self._task_manager = stateful_client.task_manager
-            return
-
-        args = {
-            "public_key": prio_public_key,
-            "secret_key": prio_secret_key,
-            "host": prio_host,
-            "debug": debug,
-        }
-
-        if release is not None:
-            args["release"] = release
-        if threads is not None:
-            args["threads"] = threads
-        if flush_at is not None:
-            args["flush_at"] = flush_at
-        if flush_interval is not None:
-            args["flush_interval"] = flush_interval
-        if max_retries is not None:
-            args["max_retries"] = max_retries
-        if timeout is not None:
-            args["timeout"] = timeout
-        if enabled is not None:
-            args["enabled"] = enabled
-        if httpx_client is not None:
-            args["httpx_client"] = httpx_client
-        args["sdk_integration"] = sdk_integration
-
-        self.langfuse = Langfuse(**args)
-        self.trace: Optional[StatefulTraceClient] = None
-        self._task_manager = self.langfuse.task_manager
-
-    def call(self, i, o, name=None, **kwargs):
-        self.langfuse.trace(input=i, output=o, name=name, metadata=kwargs)
@ -42,7 +42,10 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
|
|||
return prompt_or_repo_id
|
||||
|
||||
prompt_file = cached_file(
|
||||
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
|
||||
prompt_or_repo_id,
|
||||
PROMPT_FILES[mode],
|
||||
repo_type="dataset",
|
||||
user_agent={"agent": agent_name},
|
||||
)
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
|
|
@@ -26,6 +26,7 @@ import pandas as pd

 from .utils import truncate_content


 class InterpreterError(ValueError):
     """
     An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
@@ -38,7 +39,8 @@ class InterpreterError(ValueError):
 ERRORS = {
     name: getattr(builtins, name)
     for name in dir(builtins)
-    if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
+    if isinstance(getattr(builtins, name), type)
+    and issubclass(getattr(builtins, name), BaseException)
 }


@@ -92,7 +94,9 @@ def evaluate_unaryop(expression, state, static_tools, custom_tools):
     elif isinstance(expression.op, ast.Invert):
         return ~operand
     else:
-        raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
+        raise InterpreterError(
+            f"Unary operation {expression.op.__class__.__name__} is not supported."
+        )


 def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
@@ -102,7 +106,9 @@ def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
         new_state = state.copy()
         for arg, value in zip(args, values):
             new_state[arg] = value
-        return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
+        return evaluate_ast(
+            lambda_expression.body, new_state, static_tools, custom_tools
+        )

     return lambda_func

@@ -120,7 +126,9 @@ def evaluate_while(while_loop, state, static_tools, custom_tools):
             break
         iterations += 1
         if iterations > max_iterations:
-            raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
+            raise InterpreterError(
+                f"Maximum number of {max_iterations} iterations in While loop exceeded"
+            )
     return None


@@ -128,7 +136,10 @@ def create_function(func_def, state, static_tools, custom_tools):
     def new_func(*args, **kwargs):
         func_state = state.copy()
         arg_names = [arg.arg for arg in func_def.args.args]
-        default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
+        default_values = [
+            evaluate_ast(d, state, static_tools, custom_tools)
+            for d in func_def.args.defaults
+        ]

         # Apply default values
         defaults = dict(zip(arg_names[-len(default_values) :], default_values))
@@ -180,26 +191,39 @@ def create_class(class_name, class_bases, class_body):


 def evaluate_function_def(func_def, state, static_tools, custom_tools):
-    custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
+    custom_tools[func_def.name] = create_function(
+        func_def, state, static_tools, custom_tools
+    )
     return custom_tools[func_def.name]


 def evaluate_class_def(class_def, state, static_tools, custom_tools):
     class_name = class_def.name
-    bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
+    bases = [
+        evaluate_ast(base, state, static_tools, custom_tools)
+        for base in class_def.bases
+    ]
     class_dict = {}

     for stmt in class_def.body:
         if isinstance(stmt, ast.FunctionDef):
-            class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
+            class_dict[stmt.name] = evaluate_function_def(
+                stmt, state, static_tools, custom_tools
+            )
         elif isinstance(stmt, ast.Assign):
             for target in stmt.targets:
                 if isinstance(target, ast.Name):
-                    class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
+                    class_dict[target.id] = evaluate_ast(
+                        stmt.value, state, static_tools, custom_tools
+                    )
                 elif isinstance(target, ast.Attribute):
-                    class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
+                    class_dict[target.attr] = evaluate_ast(
+                        stmt.value, state, static_tools, custom_tools
+                    )
         else:
-            raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
+            raise InterpreterError(
+                f"Unsupported statement in class body: {stmt.__class__.__name__}"
+            )

     new_class = type(class_name, tuple(bases), class_dict)
     state[class_name] = new_class
@@ -223,7 +247,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
         elif isinstance(target, ast.List):
             return [get_current_value(elt) for elt in target.elts]
         else:
-            raise InterpreterError("AugAssign not supported for {type(target)} targets.")
+            raise InterpreterError(
+                "AugAssign not supported for {type(target)} targets."
+            )

     current_value = get_current_value(expression.target)
     value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
@@ -232,7 +258,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
     if isinstance(expression.op, ast.Add):
         if isinstance(current_value, list):
             if not isinstance(value_to_add, list):
-                raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
+                raise InterpreterError(
+                    f"Cannot add non-list value {value_to_add} to a list."
+                )
             updated_value = current_value + value_to_add
         else:
             updated_value = current_value + value_to_add
@@ -259,7 +287,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
     elif isinstance(expression.op, ast.RShift):
         updated_value = current_value >> value_to_add
     else:
-        raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
+        raise InterpreterError(
+            f"Operation {type(expression.op).__name__} is not supported."
+        )

     # Update the state
     set_value(expression.target, updated_value, state, static_tools, custom_tools)
@@ -311,7 +341,9 @@ def evaluate_binop(binop, state, static_tools, custom_tools):
     elif isinstance(binop.op, ast.RShift):
         return left_val >> right_val
     else:
-        raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
+        raise NotImplementedError(
+            f"Binary operation {type(binop.op).__name__} is not implemented."
+        )


 def evaluate_assign(assign, state, static_tools, custom_tools):
@@ -321,7 +353,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
         set_value(target, result, state, static_tools, custom_tools)
     else:
         if len(assign.targets) != len(result):
-            raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
+            raise InterpreterError(
+                f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
+            )
         expanded_values = []
         for tgt in assign.targets:
             if isinstance(tgt, ast.Starred):
@@ -336,7 +370,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
 def set_value(target, value, state, static_tools, custom_tools):
     if isinstance(target, ast.Name):
         if target.id in static_tools:
-            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
+            raise InterpreterError(
+                f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
+            )
         state[target.id] = value
     elif isinstance(target, ast.Tuple):
         if not isinstance(value, tuple):
@@ -399,9 +435,14 @@ def evaluate_call(call, state, static_tools, custom_tools):
         else:
             args.append(evaluate_ast(arg, state, static_tools, custom_tools))

-    kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
+    kwargs = {
+        keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools)
+        for keyword in call.keywords
+    }

-    if isinstance(func, type) and len(func.__module__.split(".")) > 1:  # Check for user-defined classes
+    if (
+        isinstance(func, type) and len(func.__module__.split(".")) > 1
+    ):  # Check for user-defined classes
         # Instantiate the class using its constructor
         obj = func.__new__(func)  # Create a new instance of the class
         if hasattr(obj, "__init__"):  # Check if the class has an __init__ method
@@ -441,7 +482,9 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
     value = evaluate_ast(subscript.value, state, static_tools, custom_tools)

     if isinstance(value, str) and isinstance(index, str):
-        raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
+        raise InterpreterError(
+            "You're trying to subscript a string with a string index, which is impossible"
+        )
     if isinstance(value, pd.core.indexing._LocIndexer):
         parent_object = value.obj
         return parent_object.loc[index]
@@ -453,11 +496,15 @@
@ -453,11 +496,15 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
|||
return value[index]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
if not (-len(value) <= index < len(value)):
|
||||
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
||||
raise InterpreterError(
|
||||
f"Index {index} out of bounds for list of length {len(value)}"
|
||||
)
|
||||
return value[int(index)]
|
||||
elif isinstance(value, str):
|
||||
if not (-len(value) <= index < len(value)):
|
||||
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
||||
raise InterpreterError(
|
||||
f"Index {index} out of bounds for string of length {len(value)}"
|
||||
)
|
||||
return value[index]
|
||||
elif index in value:
|
||||
return value[index]
|
||||
|
@ -483,7 +530,10 @@ def evaluate_name(name, state, static_tools, custom_tools):
|
|||
|
||||
def evaluate_condition(condition, state, static_tools, custom_tools):
|
||||
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
|
||||
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
|
||||
comparators = [
|
||||
evaluate_ast(c, state, static_tools, custom_tools)
|
||||
for c in condition.comparators
|
||||
]
|
||||
ops = [type(op) for op in condition.ops]
|
||||
|
||||
result = True
|
||||
|
@ -561,9 +611,13 @@ def evaluate_for(for_loop, state, static_tools, custom_tools):
|
|||
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
||||
def inner_evaluate(generators, index, current_state):
|
||||
if index >= len(generators):
|
||||
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
|
||||
return [
|
||||
evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)
|
||||
]
|
||||
generator = generators[index]
|
||||
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
|
||||
iter_value = evaluate_ast(
|
||||
generator.iter, current_state, static_tools, custom_tools
|
||||
)
|
||||
result = []
|
||||
for value in iter_value:
|
||||
new_state = current_state.copy()
|
||||
|
@ -572,7 +626,10 @@ def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
|||
new_state[elem.id] = value[idx]
|
||||
else:
|
||||
new_state[generator.target.id] = value
|
||||
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
|
||||
if all(
|
||||
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
|
||||
for if_clause in generator.ifs
|
||||
):
|
||||
result.extend(inner_evaluate(generators, index + 1, new_state))
|
||||
return result
|
||||
|
||||
|
@ -586,7 +643,9 @@ def evaluate_try(try_node, state, static_tools, custom_tools):
|
|||
except Exception as e:
|
||||
matched = False
|
||||
for handler in try_node.handlers:
|
||||
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
|
||||
if handler.type is None or isinstance(
|
||||
e, evaluate_ast(handler.type, state, static_tools, custom_tools)
|
||||
):
|
||||
matched = True
|
||||
if handler.name:
|
||||
state[handler.name] = e
|
||||
|
@ -638,7 +697,9 @@ def evaluate_assert(assert_node, state, static_tools, custom_tools):
|
|||
def evaluate_with(with_node, state, static_tools, custom_tools):
|
||||
contexts = []
|
||||
for item in with_node.items:
|
||||
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
|
||||
context_expr = evaluate_ast(
|
||||
item.context_expr, state, static_tools, custom_tools
|
||||
)
|
||||
if item.optional_vars:
|
||||
state[item.optional_vars.id] = context_expr.__enter__()
|
||||
contexts.append(state[item.optional_vars.id])
|
||||
|
@ -661,7 +722,9 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
|
|||
def import_modules(expression, state, authorized_imports):
|
||||
def check_module_authorized(module_name):
|
||||
module_path = module_name.split(".")
|
||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||
module_subpaths = [
|
||||
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
|
||||
]
|
||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||
|
||||
if isinstance(expression, ast.Import):
|
||||
|
@ -676,7 +739,9 @@ def import_modules(expression, state, authorized_imports):
|
|||
return None
|
||||
elif isinstance(expression, ast.ImportFrom):
|
||||
if check_module_authorized(expression.module):
|
||||
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||
module = __import__(
|
||||
expression.module, fromlist=[alias.name for alias in expression.names]
|
||||
)
|
||||
for alias in expression.names:
|
||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||
else:
|
||||
|
@ -691,9 +756,14 @@ def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
|
|||
for value in iter_value:
|
||||
new_state = state.copy()
|
||||
set_value(gen.target, value, new_state, static_tools, custom_tools)
|
||||
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
|
||||
if all(
|
||||
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
|
||||
for if_clause in gen.ifs
|
||||
):
|
||||
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
|
||||
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
|
||||
val = evaluate_ast(
|
||||
dictcomp.value, new_state, static_tools, custom_tools
|
||||
)
|
||||
result[key] = val
|
||||
return result
|
||||
|
||||
|
@ -744,7 +814,10 @@ def evaluate_ast(
|
|||
# Constant -> just return the value
|
||||
return expression.value
|
||||
elif isinstance(expression, ast.Tuple):
|
||||
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
|
||||
return tuple(
|
||||
evaluate_ast(elt, state, static_tools, custom_tools)
|
||||
for elt in expression.elts
|
||||
)
|
||||
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
||||
return evaluate_listcomp(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.UnaryOp):
|
||||
|
@ -770,8 +843,13 @@ def evaluate_ast(
|
|||
return evaluate_function_def(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Dict):
|
||||
# Dict -> evaluate all keys and values
|
||||
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
|
||||
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
|
||||
keys = [
|
||||
evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys
|
||||
]
|
||||
values = [
|
||||
evaluate_ast(v, state, static_tools, custom_tools)
|
||||
for v in expression.values
|
||||
]
|
||||
return dict(zip(keys, values))
|
||||
elif isinstance(expression, ast.Expr):
|
||||
# Expression -> evaluate the content
|
||||
|
@ -788,10 +866,18 @@ def evaluate_ast(
|
|||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.JoinedStr):
|
||||
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
|
||||
return "".join(
|
||||
[
|
||||
str(evaluate_ast(v, state, static_tools, custom_tools))
|
||||
for v in expression.values
|
||||
]
|
||||
)
|
||||
elif isinstance(expression, ast.List):
|
||||
# List -> evaluate all elements
|
||||
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
|
||||
return [
|
||||
evaluate_ast(elt, state, static_tools, custom_tools)
|
||||
for elt in expression.elts
|
||||
]
|
||||
elif isinstance(expression, ast.Name):
|
||||
# Name -> pick up the value in the state
|
||||
return evaluate_name(expression, state, static_tools, custom_tools)
|
||||
|
@ -815,7 +901,9 @@ def evaluate_ast(
|
|||
evaluate_ast(expression.upper, state, static_tools, custom_tools)
|
||||
if expression.upper is not None
|
||||
else None,
|
||||
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
|
||||
evaluate_ast(expression.step, state, static_tools, custom_tools)
|
||||
if expression.step is not None
|
||||
else None,
|
||||
)
|
||||
elif isinstance(expression, ast.DictComp):
|
||||
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
|
||||
|
@ -834,17 +922,24 @@ def evaluate_ast(
|
|||
elif isinstance(expression, ast.With):
|
||||
return evaluate_with(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Set):
|
||||
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
|
||||
return {
|
||||
evaluate_ast(elt, state, static_tools, custom_tools)
|
||||
for elt in expression.elts
|
||||
}
|
||||
elif isinstance(expression, ast.Return):
|
||||
raise ReturnException(
|
||||
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
|
||||
evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
if expression.value
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# For now we refuse anything else. Let's add things as we need them.
|
||||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
|
||||
def truncate_print_outputs(
|
||||
print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
|
||||
) -> str:
|
||||
if len(print_outputs) < max_len_outputs:
|
||||
return print_outputs
|
||||
else:
|
||||
|
@ -895,8 +990,12 @@ def evaluate_python_code(
|
|||
OPERATIONS_COUNT = 0
|
||||
try:
|
||||
for node in expression.body:
|
||||
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
||||
result = evaluate_ast(
|
||||
node, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
state["print_outputs"] = truncate_content(
|
||||
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
||||
)
|
||||
return result
|
||||
except InterpreterError as e:
|
||||
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
||||
|
|
|
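For orientation, here is a minimal sketch of how the interpreter entry point reformatted above is driven. The calls mirror the test diff later in this commit; the local module path `agents.python_interpreter` is an assumption based on the renamed imports elsewhere in the commit.

```python
# Sketch only: evaluate_python_code and InterpreterError come from the
# interpreter module shown above; the module path is assumed.
from agents.python_interpreter import InterpreterError, evaluate_python_code


def add_two(x):
    return x + 2


state = {"x": 3}
# The second argument is the mapping of static tools the evaluated code may call.
evaluate_python_code("y = add_two(x)", {"add_two": add_two}, state=state)
print(state["y"])  # 5

try:
    # Assigning over a static tool name is rejected, as set_value above shows.
    evaluate_python_code("add_two = 1", {"add_two": add_two}, state={})
except InterpreterError as e:
    print(e)
```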
@ -26,7 +26,9 @@ class DuckDuckGoSearchTool(Tool):
    name = "web_search"
    description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
    Each result has keys 'title', 'href' and 'body'."""
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."}
    }
    output_type = "any"

    def forward(self, query: str) -> str:
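A hedged usage sketch of this search tool: the class attributes above define a single string input and an "any" output, so a call should look like the following. The module path `agents.search` is an assumption (the diff does not show the file name), and real results depend on the live DuckDuckGo API.

```python
# Sketch only: module path assumed; results come from a live web search.
from agents.search import DuckDuckGoSearchTool

search_tool = DuckDuckGoSearchTool()
results = search_tool.forward(query="huggingface agents")
for result in results[:3]:
    # Per the docstring above, each result has 'title', 'href' and 'body'.
    print(result["title"], result["href"])
```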
141 agents/tools.py
@ -26,7 +26,13 @@ from functools import lru_cache, wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
from huggingface_hub import (
    create_repo,
    get_collection,
    hf_hub_download,
    metadata_update,
    upload_folder,
)
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version

@ -73,7 +79,9 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
        hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
        return "model"
    except RepositoryNotFoundError:
        raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
        raise EnvironmentError(
            f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
        )
    except Exception:
        return "model"
    except Exception:

@ -158,7 +166,15 @@ class Tool:
            "inputs": dict,
            "output_type": str,
        }
        authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
        authorized_types = [
            "string",
            "integer",
            "number",
            "image",
            "audio",
            "any",
            "boolean",
        ]

        for attr, expected_type in required_attributes.items():
            attr_value = getattr(self, attr, None)

@ -169,7 +185,9 @@ class Tool:
                    f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
                )
        for input_name, input_content in self.inputs.items():
            assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
            assert isinstance(
                input_content, dict
            ), f"Input '{input_name}' should be a dictionary."
            assert (
                "type" in input_content and "description" in input_content
            ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."

@ -251,7 +269,11 @@ class Tool:
        # Save app file
        app_file = os.path.join(output_dir, "app.py")
        with open(app_file, "w", encoding="utf-8") as f:
            f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
            f.write(
                APP_FILE_TEMPLATE.format(
                    module_name=last_module, class_name=self.__class__.__name__
                )
            )

        # Save requirements file
        requirements_file = os.path.join(output_dir, "requirements.txt")

@ -343,7 +365,9 @@ class Tool:
            custom_tool = config

        tool_class = custom_tool["tool_class"]
        tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs)
        tool_class = get_class_from_dynamic_module(
            tool_class, repo_id, token=token, **hub_kwargs
        )

        if len(tool_class.name) == 0:
            tool_class.name = custom_tool["name"]

@ -420,7 +444,9 @@ class Tool:
        with tempfile.TemporaryDirectory() as work_dir:
            # Save all files.
            self.save(work_dir)
            logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
            logger.info(
                f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
            )
            return upload_folder(
                repo_id=repo_id,
                commit_message=commit_message,

@ -432,7 +458,11 @@ class Tool:

    @staticmethod
    def from_space(
        space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None
        space_id: str,
        name: str,
        description: str,
        api_name: Optional[str] = None,
        token: Optional[str] = None,
    ):
        """
        Creates a [`Tool`] from a Space given its id on the Hub.

@ -485,7 +515,9 @@ class Tool:
                self.client = Client(space_id, hf_token=token)
                self.name = name
                self.description = description
                space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
                space_description = self.client.view_api(
                    return_format="dict", print_info=False
                )["named_endpoints"]

                # If api_name is not defined, take the first of the available APIs for this space
                if api_name is None:

@ -498,7 +530,9 @@ class Tool:
                try:
                    space_description_api = space_description[api_name]
                except KeyError:
                    raise KeyError(f"Could not find specified {api_name=} among available api names.")
                    raise KeyError(
                        f"Could not find specified {api_name=} among available api names."
                    )

                self.inputs = {}
                for parameter in space_description_api["parameters"]:

@ -523,9 +557,11 @@ class Tool:
                    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                    arg.save(temp_file.name)
                    arg = temp_file.name
                if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like(
                    arg
                ):
                if (
                    isinstance(arg, (str, Path))
                    and Path(arg).exists()
                    and Path(arg).is_file()
                ) or is_http_url_like(arg):
                    arg = handle_file(arg)
                return arg

@ -544,7 +580,9 @@ class Tool:
                    ]  # Sometime the space also returns the generation seed, in which case the result is at index 0
                return output

        return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token)
        return SpaceToolWrapper(
            space_id, name, description, api_name=api_name, token=token
        )

    @staticmethod
    def from_gradio(gradio_tool):

@ -561,7 +599,8 @@ class Tool:
                self._gradio_tool = _gradio_tool
                func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
                self.inputs = {
                    key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
                    key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
                    for key, value in func_args
                }
                self.forward = self._gradio_tool.run

@ -603,7 +642,9 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
"""


def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
def get_tool_description_with_args(
    tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
) -> str:
    compiled_template = compile_jinja_template(description_template)
    rendered = compiled_template.render(
        tool=tool,

@ -621,7 +662,10 @@ def compile_jinja_template(template):
        raise ImportError("template requires jinja2 to be installed.")

    if version.parse(jinja2.__version__) < version.parse("3.1.0"):
        raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
        raise ImportError(
            "template requires jinja2>=3.1.0 to be installed. Your version is "
            f"{jinja2.__version__}."
        )

    def raise_exception(message):
        raise TemplateError(message)
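The validation logic in the `Tool` hunks above expects four class attributes, with every input carrying a "type" and a "description" and the output type drawn from the `authorized_types` list. A minimal custom tool in that shape would look like this sketch (assuming the module path `agents.tools` and that `Tool.__call__` dispatches to `forward`, which the diff does not show):

```python
# Minimal custom Tool matching the validation rules above (sketch).
from agents.tools import Tool


class UpperCaseTool(Tool):
    name = "upper_case"
    description = "Returns the input text in upper case."
    inputs = {"text": {"type": "string", "description": "Text to convert."}}
    output_type = "string"

    def forward(self, text: str) -> str:
        return text.upper()


tool = UpperCaseTool()
print(tool("hello"))  # assumes Tool.__call__ dispatches to forward -> "HELLO"
```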
@ -697,7 +741,9 @@ class PipelineTool(Tool):

        if model is None:
            if self.default_checkpoint is None:
                raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
                raise ValueError(
                    "This tool does not implement a default checkpoint, you need to pass one."
                )
            model = self.default_checkpoint
        if pre_processor is None:
            pre_processor = model

@ -720,15 +766,21 @@ class PipelineTool(Tool):
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
            self.pre_processor = self.pre_processor_class.from_pretrained(
                self.pre_processor, **self.hub_kwargs
            )

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
            self.model = self.model_class.from_pretrained(
                self.model, **self.model_kwargs, **self.hub_kwargs
            )

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
            self.post_processor = self.post_processor_class.from_pretrained(
                self.post_processor, **self.hub_kwargs
            )

        if self.device is None:
            if self.device_map is not None:

@ -768,8 +820,12 @@ class PipelineTool(Tool):

        encoded_inputs = self.encode(*args, **kwargs)

        tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
        non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
        tensor_inputs = {
            k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
        }
        non_tensor_inputs = {
            k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
        }

        encoded_inputs = send_to_device(tensor_inputs, self.device)
        outputs = self.forward({**encoded_inputs, **non_tensor_inputs})

@ -790,7 +846,9 @@ def launch_gradio_demo(tool_class: Tool):
    try:
        import gradio as gr
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")
        raise ImportError(
            "Gradio should be installed in order to launch a gradio demo."
        )

    tool = tool_class()

@ -807,11 +865,15 @@ def launch_gradio_demo(tool_class: Tool):

    gradio_inputs = []
    for input_name, input_details in tool_class.inputs.items():
        input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
        input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
            input_details["type"]
        ]
        new_component = input_gradio_component_class(label=input_name)
        gradio_inputs.append(new_component)

    output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
    output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
        tool_class.output_type
    ]
    gradio_output = output_gradio_componentclass(label=input_name)

    gr.Interface(

@ -875,7 +937,9 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
            f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
            f"code that you have checked."
        )
        return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
        return Tool.from_hub(
            task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs
        )


def add_description(description):

@ -935,7 +999,9 @@ class EndpointClient:
            payload["parameters"] = params

        # Make API call
        response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
        response = get_session().post(
            self.endpoint_url, headers=self.headers, json=payload, data=data
        )

        # By default, parse the response for the user.
        if output_image:

@ -972,7 +1038,9 @@ class ToolCollection:

    def __init__(self, collection_slug: str, token: Optional[str] = None):
        self._collection = get_collection(collection_slug, token=token)
        self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
        self._hub_repo_ids = {
            item.item_id for item in self._collection.items if item.item_type == "space"
        }
        self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}

@ -986,7 +1054,9 @@ def tool(tool_function: Callable) -> Tool:
    """
    parameters = get_json_schema(tool_function)["function"]
    if "return" not in parameters:
        raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
        raise TypeHintParsingException(
            "Tool return type not found: make sure your function has a return type hint!"
        )
    class_name = f"{parameters['name'].capitalize()}Tool"

    class SpecificTool(Tool):

@ -1000,9 +1070,9 @@ def tool(tool_function: Callable) -> Tool:
            return tool_function(*args, **kwargs)

    original_signature = inspect.signature(tool_function)
    new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
        original_signature.parameters.values()
    )
    new_parameters = [
        inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ] + list(original_signature.parameters.values())
    new_signature = original_signature.replace(parameters=new_parameters)
    SpecificTool.forward.__signature__ = new_signature

@ -1049,7 +1119,10 @@ class Toolbox:
                The template to use to describe the tools. If not provided, the default template will be used.
        """
        return "\n".join(
            [get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
            [
                get_tool_description_with_args(tool, tool_description_template)
                for tool in self._tools.values()
            ]
        )

    def add_tool(self, tool: Tool):
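The `tool` decorator hunk above builds a `Tool` subclass from a plain function: `get_json_schema` reads the type hints and docstring, and a missing return annotation raises `TypeHintParsingException`. A hedged sketch of its use (module path `agents.tools` assumed; the exact attributes exposed on the returned object follow the schema, so treat the prints as illustrative):

```python
# Sketch of the @tool decorator usage; full type hints and an
# Args-style docstring are required by get_json_schema.
from agents.tools import tool


@tool
def character_count(text: str) -> int:
    """Counts the characters in a text.

    Args:
        text: The text to measure.
    """
    return len(text)


print(character_count.name, character_count.inputs)  # schema-derived metadata
print(character_count(text="hello"))  # 5
```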
@ -16,32 +16,29 @@
# limitations under the License.
import json
import re
from typing import Tuple, Dict
from typing import Tuple, Dict, Union

from transformers.utils.import_utils import _is_package_available

_pygments_available = _is_package_available("pygments")


def is_pygments_available():
    return _pygments_available


from rich.console import Console

console = Console()

LENGTH_TRUNCATE_REPORTS = 10000


def truncate_content(content: str, max_length: int = LENGTH_TRUNCATE_REPORTS):
    if len(content) < max_length:
        return content
    else:
        return content[:max_length//2] + "\n..._(Content was truncated because too long)_...\n---" + content[-max_length//2:]


def parse_json_blob(json_blob: str) -> Dict[str, str]:
    try:
        first_accolade_index = json_blob.find("{")
        last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
        json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
        json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
            '\\"', "'"
        )
        json_data = json.loads(json_blob, strict=False)
        return json_data
    except json.JSONDecodeError as e:

@ -63,7 +60,12 @@ def parse_code_blob(code_blob: str) -> str:
    try:
        pattern = r"```(?:py|python)?\n(.*?)\n```"
        match = re.search(pattern, code_blob, re.DOTALL)
        if match is None:
            raise ValueError(
                f"No match ground for regex pattern {pattern} in {code_blob=}."
            )
        return match.group(1).strip()

    except Exception as e:
        raise ValueError(
            f"""

@ -77,7 +79,7 @@ Code:
        )


def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
    json_blob = json_blob.replace("```json", "").replace("```", "")
    tool_call = parse_json_blob(json_blob)
    if "action" in tool_call and "action_input" in tool_call:

@ -85,7 +87,25 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
    elif "action" in tool_call:
        return tool_call["action"], None
    else:
        missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call]
        missing_keys = [
            key for key in ["action", "action_input"] if key not in tool_call
        ]
        error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
        console.print(f"[bold red]{error_msg}[/bold red]")
        raise ValueError(error_msg)


MAX_LENGTH_TRUNCATE_CONTENT = 20000


def truncate_content(
    content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
) -> str:
    if len(content) <= max_length:
        return content
    else:
        return (
            content[: MAX_LENGTH_TRUNCATE_CONTENT // 2]
            + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
            + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
        )
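For reference, a short sketch of the parsing and truncation helpers in the hunks above (module path `agents.utils` assumed). `parse_json_tool_call` strips any ```json fences, then returns the `action` name together with `action_input` (or `None` when the input is missing), and raises `ValueError` listing the missing keys otherwise:

```python
# Sketch only: module path assumed; behavior mirrors the hunks above.
from agents.utils import parse_json_tool_call, truncate_content

blob = '{"action": "final_answer", "action_input": "done"}'
tool_name, tool_input = parse_json_tool_call(blob)
print(tool_name, tool_input)  # final_answer done

# Long strings are clipped around the middle with a truncation notice.
print(len(truncate_content("x" * 50000)) < 50000)  # True
```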
@ -120,4 +120,4 @@ def get_weather_api(location (str), date_time: str) -> str:
        raise ValueError("Conversion of `date_time` to datetime format failed, make sure to provide a string in format '%m/%d/%y %H:%M:%S'. Full trace:" + str(e))
    temperature_celsius, risk_of_rain, wave_height = get_weather_report_at_coordinates((lon, lat), date_time)
    return f"Weather report for {location}, {date_time}: Temperature will be {temperature_celsius}°C, risk of rain is {risk_of_rain*100:.0f}%, wave height is {wave_height}m."
```
```
File diff suppressed because it is too large
@ -0,0 +1,271 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile requirements.in -o requirements.txt
aiofiles==23.2.1
    # via gradio
annotated-types==0.7.0
    # via pydantic
anyio==4.7.0
    # via
    #   gradio
    #   httpx
    #   starlette
appnope==0.1.4
    # via ipykernel
asttokens==3.0.0
    # via stack-data
beautifulsoup4==4.12.3
    # via markdownify
certifi==2024.8.30
    # via
    #   httpcore
    #   httpx
    #   requests
charset-normalizer==3.4.0
    # via requests
click==8.1.7
    # via
    #   duckduckgo-search
    #   typer
    #   uvicorn
comm==0.2.2
    # via ipykernel
debugpy==1.8.10
    # via ipykernel
decorator==5.1.1
    # via ipython
diskcache==5.6.3
    # via llama-cpp-python
duckduckgo-search==6.4.1
    # via -r requirements.in
executing==2.1.0
    # via stack-data
fastapi==0.115.6
    # via gradio
ffmpy==0.4.0
    # via gradio
filelock==3.16.1
    # via
    #   huggingface-hub
    #   transformers
fsspec==2024.10.0
    # via
    #   gradio-client
    #   huggingface-hub
gradio==5.8.0
    # via -r requirements.in
gradio-client==1.5.1
    # via gradio
h11==0.14.0
    # via
    #   httpcore
    #   uvicorn
httpcore==1.0.7
    # via httpx
httpx==0.28.1
    # via
    #   gradio
    #   gradio-client
    #   safehttpx
huggingface-hub==0.26.5
    # via
    #   gradio
    #   gradio-client
    #   tokenizers
    #   transformers
idna==3.10
    # via
    #   anyio
    #   httpx
    #   requests
iniconfig==2.0.0
    # via pytest
ipykernel==6.29.5
    # via -r requirements.in
ipython==8.30.0
    # via ipykernel
jedi==0.19.2
    # via ipython
jinja2==3.1.4
    # via
    #   -r requirements.in
    #   gradio
    #   llama-cpp-python
jupyter-client==8.6.3
    # via ipykernel
jupyter-core==5.7.2
    # via
    #   ipykernel
    #   jupyter-client
llama-cpp-python==0.3.5
    # via -r requirements.in
markdown-it-py==3.0.0
    # via rich
markdownify==0.14.1
    # via -r requirements.in
markupsafe==2.1.5
    # via
    #   gradio
    #   jinja2
matplotlib-inline==0.1.7
    # via
    #   ipykernel
    #   ipython
mdurl==0.1.2
    # via markdown-it-py
nest-asyncio==1.6.0
    # via ipykernel
numpy==2.2.0
    # via
    #   gradio
    #   llama-cpp-python
    #   pandas
    #   transformers
orjson==3.10.12
    # via gradio
packaging==24.2
    # via
    #   gradio
    #   gradio-client
    #   huggingface-hub
    #   ipykernel
    #   pytest
    #   transformers
pandas==2.2.3
    # via
    #   -r requirements.in
    #   gradio
parso==0.8.4
    # via jedi
pexpect==4.9.0
    # via ipython
pillow==11.0.0
    # via
    #   -r requirements.in
    #   gradio
platformdirs==4.3.6
    # via jupyter-core
pluggy==1.5.0
    # via pytest
primp==0.8.3
    # via duckduckgo-search
prompt-toolkit==3.0.48
    # via ipython
psutil==6.1.0
    # via ipykernel
ptyprocess==0.7.0
    # via pexpect
pure-eval==0.2.3
    # via stack-data
pydantic==2.10.3
    # via
    #   fastapi
    #   gradio
pydantic-core==2.27.1
    # via pydantic
pydub==0.25.1
    # via gradio
pygments==2.18.0
    # via
    #   ipython
    #   rich
pytest==8.3.4
    # via -r requirements.in
python-dateutil==2.9.0.post0
    # via
    #   jupyter-client
    #   pandas
python-dotenv==1.0.1
    # via -r requirements.in
python-multipart==0.0.19
    # via gradio
pytz==2024.2
    # via pandas
pyyaml==6.0.2
    # via
    #   gradio
    #   huggingface-hub
    #   transformers
pyzmq==26.2.0
    # via
    #   ipykernel
    #   jupyter-client
regex==2024.11.6
    # via transformers
requests==2.32.3
    # via
    #   -r requirements.in
    #   huggingface-hub
    #   transformers
rich==13.9.4
    # via
    #   -r requirements.in
    #   typer
ruff==0.8.2
    # via gradio
safehttpx==0.1.6
    # via gradio
safetensors==0.4.5
    # via transformers
semantic-version==2.10.0
    # via gradio
shellingham==1.5.4
    # via typer
six==1.17.0
    # via
    #   markdownify
    #   python-dateutil
sniffio==1.3.1
    # via anyio
soupsieve==2.6
    # via beautifulsoup4
stack-data==0.6.3
    # via ipython
starlette==0.41.3
    # via
    #   fastapi
    #   gradio
tokenizers==0.21.0
    # via transformers
tomlkit==0.13.2
    # via gradio
tornado==6.4.2
    # via
    #   ipykernel
    #   jupyter-client
tqdm==4.67.1
    # via
    #   huggingface-hub
    #   transformers
traitlets==5.14.3
    # via
    #   comm
    #   ipykernel
    #   ipython
    #   jupyter-client
    #   jupyter-core
    #   matplotlib-inline
transformers==4.47.0
    # via -r requirements.in
typer==0.15.1
    # via gradio
typing-extensions==4.12.2
    # via
    #   anyio
    #   fastapi
    #   gradio
    #   gradio-client
    #   huggingface-hub
    #   llama-cpp-python
    #   pydantic
    #   pydantic-core
    #   typer
tzdata==2024.2
    # via pandas
urllib3==2.2.3
    # via requests
uvicorn==0.32.1
    # via gradio
wcwidth==0.2.13
    # via prompt-toolkit
websockets==14.1
    # via gradio-client
setup.py
122
setup.py
|
@ -1,122 +0,0 @@
|
|||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
extras = {}
|
||||
extras["quality"] = [
|
||||
"black ~= 23.1", # hf-doc-builder has a hidden dependency on `black`
|
||||
"hf-doc-builder >= 0.3.0",
|
||||
"ruff ~= 0.6.4",
|
||||
]
|
||||
extras["docs"] = []
|
||||
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
|
||||
extras["test_dev"] = [
|
||||
"datasets",
|
||||
"diffusers",
|
||||
"evaluate",
|
||||
"torchdata>=0.8.0",
|
||||
"torchpippy>=0.2.0",
|
||||
"transformers",
|
||||
"scipy",
|
||||
"scikit-learn",
|
||||
"tqdm",
|
||||
"bitsandbytes",
|
||||
"timm",
|
||||
]
|
||||
extras["testing"] = extras["test_prod"] + extras["test_dev"]
|
||||
extras["deepspeed"] = ["deepspeed"]
|
||||
extras["rich"] = ["rich"]
|
||||
|
||||
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"]
|
||||
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
|
||||
|
||||
extras["sagemaker"] = [
|
||||
"sagemaker", # boto3 is a required package in sagemaker
|
||||
]
|
||||
|
||||
setup(
|
||||
name="accelerate",
|
||||
version="1.2.0.dev0",
|
||||
description="Accelerate",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords="deep learning",
|
||||
license="Apache",
|
||||
author="The HuggingFace team",
|
||||
author_email="zach.mueller@huggingface.co",
|
||||
url="https://github.com/huggingface/accelerate",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"accelerate=accelerate.commands.accelerate_cli:main",
|
||||
"accelerate-config=accelerate.commands.config:main",
|
||||
"accelerate-estimate-memory=accelerate.commands.estimate:main",
|
||||
"accelerate-launch=accelerate.commands.launch:main",
|
||||
"accelerate-merge-weights=accelerate.commands.merge:main",
|
||||
]
|
||||
},
|
||||
python_requires=">=3.9.0",
|
||||
install_requires=[
|
||||
"numpy>=1.17,<3.0.0",
|
||||
"packaging>=20.0",
|
||||
"psutil",
|
||||
"pyyaml",
|
||||
"torch>=1.10.0",
|
||||
"huggingface_hub>=0.21.0",
|
||||
"safetensors>=0.4.3",
|
||||
],
|
||||
extras_require=extras,
|
||||
classifiers=[
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
)
|
||||
|
||||
# Release checklist
|
||||
# 1. Checkout the release branch (for a patch the current release branch, for a new minor version, create one):
|
||||
# git checkout -b vXX.xx-release
|
||||
# The -b is only necessary for creation (so remove it when doing a patch)
|
||||
# 2. Change the version in __init__.py and setup.py to the proper value.
|
||||
# 3. Commit these changes with the message: "Release: v<VERSION>"
|
||||
# 4. Add a tag in git to mark the release:
|
||||
# git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi'
|
||||
# Push the tag and release commit to git: git push --tags origin vXX.xx-release
|
||||
# 5. Run the following commands in the top-level directory:
|
||||
# rm -rf dist
|
||||
# rm -rf build
|
||||
# python setup.py bdist_wheel
|
||||
# python setup.py sdist
|
||||
# 6. Upload the package to the pypi test server first:
|
||||
# twine upload dist/* -r testpypi
|
||||
# 7. Check that you can install it in a virtualenv by running:
|
||||
# pip install accelerate
|
||||
# pip uninstall accelerate
|
||||
# pip install -i https://testpypi.python.org/pypi accelerate
|
||||
# accelerate env
|
||||
# accelerate test
|
||||
# 8. Upload the final version to actual pypi:
|
||||
# twine upload dist/* -r pypi
|
||||
# 9. Add release notes to the tag in github once everything is looking hunky-dory.
|
||||
# 10. Go back to the main branch and update the version in __init__.py, setup.py to the new version ".dev" and push to
|
||||
# main.
|
|
@ -19,19 +19,25 @@ import uuid
from pathlib import Path

from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
from transformers.testing_utils import (
    get_tests_dir,
    require_soundfile,
    require_torch,
    require_vision,
)
from transformers.utils import (
    is_soundfile_availble,
    is_torch_available,
    is_vision_available,
)

import torch
from PIL import Image

if is_torch_available():
    import torch

if is_soundfile_availble():
    import soundfile as sf

if is_vision_available():
    from PIL import Image


def get_new_path(suffix="") -> str:
    directory = tempfile.mkdtemp()
@ -19,17 +19,16 @@ import uuid

import pytest

from transformers.agents.agent_types import AgentText
from transformers.agents.agents import (
from agents.agent_types import AgentText
from agents.agents import (
    AgentMaxIterationsError,
    CodeAgent,
    ManagedAgent,
    ReactCodeAgent,
    ReactJsonAgent,
    CodeAgent,
    JsonAgent,
    Toolbox,
)
from transformers.agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import require_torch
from agents.default_tools import PythonInterpreterTool


def get_new_path(suffix="") -> str:

@ -149,19 +148,26 @@ print(result)

class AgentTests(unittest.TestCase):
    def test_fake_code_agent(self):
        agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
        agent = CodeAgent(
            tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
        )
        output = agent.run("What is 2 multiplied by 3.6452?")
        assert isinstance(output, str)
        assert output == "7.2904"

    def test_fake_react_json_agent(self):
        agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
        agent = JsonAgent(
            tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm
        )
        output = agent.run("What is 2 multiplied by 3.6452?")
        assert isinstance(output, str)
        assert output == "7.2904"
        assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
        assert agent.logs[1]["observation"] == "7.2904"
        assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
        assert (
            agent.logs[1]["rationale"].strip()
            == "Thought: I should multiply 2 by 3.6452. special_marker"
        )
        assert (
            agent.logs[2]["llm_output"]
            == """

@ -175,7 +181,9 @@ Action:
        )

    def test_fake_react_code_agent(self):
        agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
        agent = CodeAgent(
            tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm
        )
        output = agent.run("What is 2 multiplied by 3.6452?")
        assert isinstance(output, float)
        assert output == 7.2904

@ -186,17 +194,19 @@ Action:
        }

    def test_react_code_agent_code_errors_show_offending_lines(self):
        agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error)
        agent = CodeAgent(
            tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error
        )
        output = agent.run("What is 2 multiplied by 3.6452?")
        assert isinstance(output, AgentText)
        assert output == "got an error"
        assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)

    def test_setup_agent_with_empty_toolbox(self):
        ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
        JsonAgent(llm_engine=fake_react_json_llm, tools=[])

    def test_react_fails_max_iterations(self):
        agent = ReactCodeAgent(
        agent = CodeAgent(
            tools=[PythonInterpreterTool()],
            llm_engine=fake_code_llm_no_return,  # use this callable because it never ends
            max_iterations=5,

@ -208,51 +218,62 @@ Action:
    @require_torch
    def test_init_agent_with_different_toolsets(self):
        toolset_1 = []
        agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
        agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
        assert (
            len(agent.toolbox.tools) == 1
        )  # when no tools are provided, only the final_answer tool is added by default

        toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
        agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
        agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
        assert (
            len(agent.toolbox.tools) == 2
        )  # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer

        toolset_3 = Toolbox(toolset_2)
        agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
        agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
        assert (
            len(agent.toolbox.tools) == 2
        )  # same as previous one, where toolset_3 is an instantiation of previous one

        # check that add_base_tools will not interfere with existing tools
        with pytest.raises(KeyError) as e:
            agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
            agent = JsonAgent(
                tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True
            )
        assert "already exists in the toolbox" in str(e)

        # check that python_interpreter base tool does not get added to code agents
        agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
        assert len(agent.toolbox.tools) == 7  # added final_answer tool + 6 base tools (excluding interpreter)
        agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
        assert (
            len(agent.toolbox.tools) == 7
        )  # added final_answer tool + 6 base tools (excluding interpreter)

    def test_function_persistence_across_steps(self):
        agent = ReactCodeAgent(
            tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
        agent = CodeAgent(
            tools=[],
            llm_engine=fake_react_code_functiondef,
            max_iterations=2,
            additional_authorized_imports=["numpy"],
        )
        res = agent.run("ok")
        assert res[0] == 0.5

    def test_init_managed_agent(self):
        agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
        agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
        managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
        assert managed_agent.name == "managed_agent"
        assert managed_agent.description == "Empty"

    def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
        agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
        agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
        managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
        manager_agent = ReactCodeAgent(
            tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent]
        manager_agent = CodeAgent(
            tools=[],
            llm_engine=fake_react_code_functiondef,
            managed_agents=[managed_agent],
        )
        assert "You can also give requests to team members." not in agent.system_prompt
        assert "<<managed_agents_descriptions>>" not in agent.system_prompt
        assert "You can also give requests to team members." in manager_agent.system_prompt
        assert (
            "You can also give requests to team members." in manager_agent.system_prompt
        )
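The tests above drive `CodeAgent` with fake LLM engines that always return a fixed code action. A hedged sketch in the same style (the exact code-block protocol expected by `CodeAgent` follows the `fake_*` engines defined earlier in the test file, which this diff only shows in part, so treat this as illustrative):

```python
# Sketch of a fake llm_engine in the style of these tests.
from agents.agents import CodeAgent
from agents.default_tools import PythonInterpreterTool

FENCE = "`" * 3  # avoids nesting literal triple backticks in this example


def fake_llm_engine(messages, stop_sequences=None) -> str:
    # Always answer with one code action that computes and returns the result.
    return f"Thought: multiply the numbers.\nCode:\n{FENCE}py\nresult = 2 * 3.6452\nfinal_answer(result)\n{FENCE}"


agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_llm_engine)
output = agent.run("What is 2 multiplied by 3.6452?")
print(output)  # 7.2904, per the analogous assertions above
```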
@ -1,41 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datasets import load_dataset

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("document_question_answering")
        self.tool.setup()

    def test_exact_match_arg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        document = dataset[0]["image"]

        result = self.tool(document, "When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_kwarg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        document = dataset[0]["image"]

        self.tool(document=document, question="When is the coffee break?")
@ -87,7 +87,11 @@ class ExampleDifferenceTests(unittest.TestCase):
    examples_path = Path("examples").resolve()

    def one_complete_example(
        self, complete_file_name: str, parser_only: bool, secondary_filename: str = None, special_strings: list = None
        self,
        complete_file_name: str,
        parser_only: bool,
        secondary_filename: str = None,
        special_strings: list = None,
    ):
        """
        Tests a single `complete` example against all of the implemented `by_feature` scripts

@ -112,10 +116,15 @@ class ExampleDifferenceTests(unittest.TestCase):
                with self.subTest(
                    tested_script=complete_file_name,
                    feature_script=item,
                    tested_section="main()" if parser_only else "training_function()",
                    tested_section="main()"
                    if parser_only
                    else "training_function()",
                ):
                    diff = compare_against_test(
                        self.examples_path / complete_file_name, item_path, parser_only, secondary_filename
                        self.examples_path / complete_file_name,
                        item_path,
                        parser_only,
                        secondary_filename,
                    )
                    diff = "\n".join(diff)
                    if special_strings is not None:

@ -140,8 +149,12 @@ class ExampleDifferenceTests(unittest.TestCase):
            " " * 12,
            " " * 8 + "for step, batch in enumerate(active_dataloader):\n",
        ]
        self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings)
        self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings)
        self.one_complete_example(
            "complete_cv_example.py", True, cv_path, special_strings
        )
        self.one_complete_example(
            "complete_cv_example.py", False, cv_path, special_strings
        )

    @mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})
@ -47,9 +47,9 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
    def create_inputs(self):
        inputs_text = {"answer": "Text input"}
        inputs_image = {
            "answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
                (512, 512)
            )
            "answer": Image.open(
                Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
            ).resize((512, 512))
        }
        inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
        return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
@ -1,42 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir

from .test_tools_common import ToolTesterMixin


if is_vision_available():
    from PIL import Image


class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("image_question_answering")
        self.tool.setup()

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image, "How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image=image, question="How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")
@ -129,7 +129,9 @@ final_answer('This is the final answer.')

    def test_streaming_agent_image_output(self):
        def dummy_llm_engine(prompt, **kwargs):
            return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
            return (
                'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
            )

        agent = ReactJsonAgent(
            tools=[],

@ -138,7 +140,14 @@ final_answer('This is the final answer.')
        )

        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))
        outputs = list(
            stream_to_gradio(
                agent,
                task="Test task",
                image=AgentImage(value="path.png"),
                test_mode=True,
            )
        )

        self.assertEqual(len(outputs), 2)
        final_message = outputs[-1]

@ -21,7 +21,10 @@ import pytest
from transformers import load_tool
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
from transformers.agents.python_interpreter import (
    InterpreterError,
    evaluate_python_code,
)

from .test_tools_common import ToolTesterMixin

@ -57,7 +60,12 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
        for _input, expected_input in zip(inputs, self.tool.inputs.values()):
            input_type = expected_input["type"]
            if isinstance(input_type, list):
                _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
                _inputs.append(
                    [
                        AGENT_TYPE_MAPPING[_input_type](_input)
                        for _input_type in input_type
                    ]
                )
            else:
                _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))

@ -91,7 +99,10 @@ class PythonInterpreterTester(unittest.TestCase):
        code = "print = '3'"
        with pytest.raises(InterpreterError) as e:
            evaluate_python_code(code, {"print": print}, state={})
        assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
        assert (
            "Cannot assign to name 'print': doing this would erase the existing tool!"
            in str(e)
        )

    def test_evaluate_call(self):
        code = "y = add_two(x)"

@ -117,7 +128,9 @@ class PythonInterpreterTester(unittest.TestCase):
        state = {"x": 3}
        result = evaluate_python_code(code, {"add_two": add_two}, state=state)
        self.assertDictEqual(result, {"x": 3, "y": 5})
        self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
        self.assertDictEqual(
            state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
        )

    def test_evaluate_expression(self):
        code = "x = 3\ny = 5"

@ -133,7 +146,9 @@ class PythonInterpreterTester(unittest.TestCase):
        result = evaluate_python_code(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == "This is x: 3."
        self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
        self.assertDictEqual(
            state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
        )

    def test_evaluate_if(self):
        code = "if x <= 3:\n y = 2\nelse:\n y = 5"

@ -174,11 +189,15 @@ class PythonInterpreterTester(unittest.TestCase):
        state = {"x": 3}
        result = evaluate_python_code(code, {"add_two": add_two}, state=state)
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
        self.assertDictEqual(
            state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
        )

        code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
        state = {}
        evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
        evaluate_python_code(
            code, {"min": min, "print": print, "round": round}, state=state
        )
        assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}

    def test_subscript_string_with_string_index_raises_appropriate_error(self):

@ -292,7 +311,16 @@ print(check_digits)
"""
        state = {}
        evaluate_python_code(
            code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
            code,
            {
                "range": range,
                "print": print,
                "sum": sum,
                "enumerate": enumerate,
                "int": int,
                "str": str,
            },
            state,
        )

    def test_listcomp(self):

@ -325,7 +353,9 @@ print(check_digits)
        assert result == {0: 0, 1: 1, 2: 4}

        code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
        result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
        result = evaluate_python_code(
            code, {"print": print}, state={}, authorized_imports=["pandas"]
        )
        assert result == {102: "b"}

        code = """

@ -373,7 +403,9 @@ else:
    best_city = "Manhattan"
best_city
"""
        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
        result = evaluate_python_code(
            code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
        )
        assert result == "Brooklyn"

        code = """if d > e and a < b:

@ -384,7 +416,9 @@ else:
    best_city = "Manhattan"
best_city
"""
        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
        result = evaluate_python_code(
            code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
        )
        assert result == "Sacramento"

    def test_if_conditions(self):

@ -400,7 +434,9 @@ if char.isalpha():
        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
        assert result == 2.0

        code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
        code = (
            "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
        )
        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
        assert result == "lose"

@ -434,10 +470,14 @@ if char.isalpha():

        # Test submodules are handled properly, thus not raising error
        code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
        result = evaluate_python_code(
            code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
        )

        code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
        result = evaluate_python_code(
            code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
        )

    def test_additional_imports(self):
        code = "import numpy as np"

@ -554,7 +594,11 @@ cat_sound = cat.sound()
cat_str = str(cat)
"""
        state = {}
        evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
        evaluate_python_code(
            code,
            {"print": print, "len": len, "super": super, "str": str, "sum": sum},
            state=state,
        )

        # Assert results
        assert state["dog1_sound"] == "The dog barks."

@ -588,7 +632,11 @@ except ValueError as e:
    exception_message = str(e)
"""
        state = {}
        evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
        evaluate_python_code(
            code,
            {"print": print, "len": len, "super": super, "str": str, "sum": sum},
            state=state,
        )
        assert state["exception_message"] == "An error occurred"

    def test_print(self):

@ -600,7 +648,9 @@ except ValueError as e:
    def test_types_as_objects(self):
        code = "type_a = float(2); type_b = str; type_c = int"
        state = {}
        result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
        result = evaluate_python_code(
            code, {"float": float, "str": str, "int": int}, state=state
        )
        assert result is int

    def test_tuple_id(self):

@ -731,7 +781,9 @@ def add_one(n, shift):
add_one(1, 1)
"""
        state = {}
        result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
        result = evaluate_python_code(
            code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
        )
        assert result == 2

        # test returning None

@ -742,7 +794,9 @@ def returns_none(a):
returns_none(1)
"""
        state = {}
        result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
        result = evaluate_python_code(
            code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
        )
        assert result is None

    def test_nested_for_loop(self):

@ -758,7 +812,9 @@ out = [i for sublist in all_res for i in sublist]
out[:10]
"""
        state = {}
        result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
        result = evaluate_python_code(
            code, {"print": print, "range": range}, state=state
        )
        assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]

    def test_pandas(self):

@ -773,7 +829,9 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
        state = {}
        result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
        result = evaluate_python_code(
            code, {}, state=state, authorized_imports=["pandas"]
        )
        assert np.array_equal(result, [-1, 5])

        code = """

@ -785,7 +843,9 @@ print("HH0")
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
        result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
        result = evaluate_python_code(
            code, {"print": print}, state={}, authorized_imports=["pandas"]
        )
        assert np.array_equal(result.values[0], [104, 1])

        code = """import pandas as pd

@ -818,7 +878,9 @@ coords_barcelona = (41.3869, 2.1660)

distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
        result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
        result = evaluate_python_code(
            code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
        )
        assert round(result, 1) == 622395.4

    def test_for(self):

@ -1,36 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("speech_to_text")
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool(np.ones(3000))
        self.assertEqual(result, " Thank you.")

    def test_exact_match_kwarg(self):
        result = self.tool(audio=np.ones(3000))
        self.assertEqual(result, " Thank you.")

@ -1,50 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool
from transformers.utils import is_torch_available


if is_torch_available():
    import torch

from transformers.testing_utils import require_torch

from .test_tools_common import ToolTesterMixin


@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("text_to_speech")
        self.tool.setup()

    def test_exact_match_arg(self):
        # SpeechT5 isn't deterministic
        torch.manual_seed(0)
        result = self.tool("hey")
        resulting_tensor = result.to_raw()
        self.assertTrue(len(resulting_tensor.detach().shape) == 1)
        self.assertTrue(resulting_tensor.detach().shape[0] > 1000)

    def test_exact_match_kwarg(self):
        # SpeechT5 isn't deterministic
        torch.manual_seed(0)
        result = self.tool("hey")
        resulting_tensor = result.to_raw()
        self.assertTrue(len(resulting_tensor.detach().shape) == 1)
        self.assertTrue(resulting_tensor.detach().shape[0] > 1000)

@ -20,7 +20,12 @@ import numpy as np
import pytest

from transformers import is_torch_available, is_vision_available
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
from transformers.agents.agent_types import (
    AGENT_TYPE_MAPPING,
    AgentAudio,
    AgentImage,
    AgentText,
)
from transformers.agents.tools import Tool, tool
from transformers.testing_utils import get_tests_dir

@ -32,7 +32,9 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_kwarg(self):
        result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
        result = self.tool(
            text="Hey, what's up?", src_lang="English", tgt_lang="French"
        )
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_call(self):