Add LiteLLM engine

commit 1e357cee7f
parent 762ae9cfae
@@ -48,9 +48,9 @@ pip install agents
 ```
 Then define your agent, give it the tools it needs and run it!
 ```py
-from smolagents import CodeAgent, WebSearchTool
+from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiEngine

-agent = CodeAgent(tools=[WebSearchTool()])
+agent = CodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=HfApiEngine())

 agent.run("What time would the world's fastest car take to travel from New York to San Francisco?")
 ```
@@ -68,6 +68,6 @@ Especially, since code execution can be a security concern (arbitrary code execution!)

 ## How lightweight is it?

-We strived to keep abstractions to a strict minimum, with the main code in `agents.py` being roughly 1,000 lines of code, and still being quite complete, with several types of agents implemented: `CodeAgent` writing its actions in code snippets, `JsonAgent`, `ToolCallingAgent`...
+We strived to keep abstractions to a strict minimum, with the main code in `agents.py` being roughly 1,000 lines of code, and still being quite complete, with several types of agents implemented: `CodeAgent`, which writes its actions in code snippets, and the more classic `ToolCallingAgent`, which leverages built-in tool-calling methods.

-Many people ask: why use a framework at all? Well, because a big part of this stuff is non-trivial. For instance, the code agent has to keep a consistent format for code throughout its system prompt, its parser, the execution. Its variables have to be properly handled throughout. So our framework handles this complexity for you.
+Many people ask: why use a framework at all? Well, because a big part of this stuff is non-trivial. For instance, the code agent has to keep a consistent format for code throughout its system prompt, its parser, and the execution. So our framework handles this complexity for you. But of course, we still encourage you to hack into the source code and use only the bits that you need, to the exclusion of everything else!
@@ -39,9 +39,9 @@ Here is a video overview of how that works:

 

-We implement two versions of JsonAgent:
-- [`JsonAgent`] generates tool calls as a JSON in its output.
-- [`CodeAgent`] is a new type of JsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
+We implement two versions of ToolCallingAgent:
+- [`ToolCallingAgent`] generates tool calls as a JSON in its output.
+- [`CodeAgent`] is a new type of ToolCallingAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.

 > [!TIP]
 > We also provide an option to run agents in one-shot: just pass `single_step=True` when launching the agent, like `agent.run(your_task, single_step=True)`
@@ -69,19 +69,6 @@ agent.run(
 )
 ```

-You can even leave the argument `llm_engine` undefined, and an [`HfApiEngine`] will be created by default.
-
-```python
-from smolagents import CodeAgent
-
-agent = CodeAgent(tools=[], add_base_tools=True)
-
-agent.run(
-    "Could you give me the 118th number in the Fibonacci sequence?",
-    additional_detail="We adopt the convention where the first two numbers are 0 and 1."
-)
-```
-
 Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`; they will be baked into the prompt as text.

 You can use this to indicate the path to local or remote files for the model to use:
@@ -89,7 +76,7 @@ You can use this to indicate the path to local or remote files for the model to use:
 ```py
 from smolagents import CodeAgent, Tool, SpeechToTextTool

-agent = CodeAgent(tools=[SpeechToTextTool()], add_base_tools=True)
+agent = CodeAgent(tools=[SpeechToTextTool()], llm_engine=llm_engine, add_base_tools=True)

 agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
 ```
@@ -109,7 +96,7 @@ You can authorize additional imports by passing the authorized modules as a list
 ```py
 from smolagents import CodeAgent

-agent = CodeAgent(tools=[], additional_authorized_imports=['requests', 'bs4'])
+agent = CodeAgent(tools=[], llm_engine=llm_engine, additional_authorized_imports=['requests', 'bs4'])
 agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
 ```
 This gives you at the end of the agent run:
@@ -178,11 +165,11 @@ You could improve the system prompt, for example, by adding an explanation of the
 For maximum flexibility, you can overwrite the whole system prompt template by passing your custom prompt as an argument to the `system_prompt` parameter.

 ```python
-from smolagents import JsonAgent, PythonInterpreterTool, JSON_SYSTEM_PROMPT
+from smolagents import ToolCallingAgent, PythonInterpreterTool, JSON_SYSTEM_PROMPT

 modified_prompt = JSON_SYSTEM_PROMPT

-agent = JsonAgent(tools=[PythonInterpreterTool()], system_prompt=modified_prompt)
+agent = ToolCallingAgent(tools=[PythonInterpreterTool()], llm_engine=llm_engine, system_prompt=modified_prompt)
 ```

 > [!WARNING]
@@ -209,7 +196,7 @@ When the agent is initialized, the tool attributes are used to generate a tool description
 Transformers comes with a default toolbox for empowering agents, which you can add to your agent upon initialization with the argument `add_base_tools = True`:

 - **DuckDuckGo web search**: performs a web search using DuckDuckGo.
-- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`JsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
+- **Python code interpreter**: runs the LLM-generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since a code-based agent can already natively execute Python code
 - **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes audio to text.

 You can manually use a tool by calling the [`load_tool`] function with a task to perform.
@@ -31,7 +31,7 @@ Our agents inherit from [`MultiStepAgent`], which means they can act in multiple steps.

 We provide two types of agents, based on the main [`Agent`] class.
 - [`CodeAgent`] is the default agent; it writes its tool calls in Python code.
-- [`JsonAgent`] writes its tool calls in JSON.
+- [`ToolCallingAgent`] writes its tool calls in JSON.


 ### Classes of agents
@@ -40,7 +40,7 @@ We provide two types of agents, based on the main [`Agent`] class.

 [[autodoc]] CodeAgent

-[[autodoc]] JsonAgent
+[[autodoc]] ToolCallingAgent


 ### ManagedAgent
@@ -68,9 +68,10 @@ To set the code executor to E2B, simply pass the flag `use_e2b_executor=True` when initializing the agent.
 Note that you should add all the tool's dependencies in `additional_authorized_imports`, so that the executor installs them.

 ```py
-from smolagents import CodeAgent, VisitWebpageTool
+from smolagents import CodeAgent, VisitWebpageTool, HfApiEngine
 agent = CodeAgent(
     tools = [VisitWebpageTool()],
+    llm_engine=HfApiEngine(),
     additional_authorized_imports=["requests", "markdownify"],
     use_e2b_executor=True
 )
@@ -1,5 +1,5 @@
-from smolagents import Tool, CodeAgent
-from smolagents.default_tools.search import VisitWebpageTool
+from smolagents import Tool, CodeAgent, HfApiEngine
+from smolagents.default_tools import VisitWebpageTool
 from dotenv import load_dotenv

 load_dotenv()
@@ -30,6 +30,7 @@ get_cat_image = GetCatImageTool()

 agent = CodeAgent(
     tools = [get_cat_image, VisitWebpageTool()],
+    llm_engine=HfApiEngine(),
     additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
     use_e2b_executor=False
 )
@@ -1,10 +1,12 @@
 from smolagents.agents import ToolCallingAgent
-from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine
+from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, TransformersEngine, LiteLLMEngine

 # Choose which LLM engine to use!
-llm_engine = OpenAIEngine("gpt-4o")
-llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
-llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
+# llm_engine = OpenAIEngine("gpt-4o")
+# llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
+# llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
+# llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
+llm_engine = LiteLLMEngine()

 @tool
 def get_weather(location: str) -> str:
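The hunk above makes `LiteLLMEngine()` the live choice in the example, with the other engines left as commented alternatives. As a minimal usage sketch (not part of the commit; the model id below is an illustrative assumption), LiteLLM selects a provider through `provider/model` ids, and since the new engine only implements `get_tool_call`, it pairs with `ToolCallingAgent` rather than `CodeAgent`:

```py
from smolagents.agents import ToolCallingAgent
from smolagents import LiteLLMEngine, tool

# LiteLLM routes on "provider/model" ids; LiteLLMEngine defaults to
# "anthropic/claude-3-5-sonnet-20240620" when no model_id is passed.
llm_engine = LiteLLMEngine(model_id="openai/gpt-4o")  # illustrative model id

@tool
def get_weather(location: str) -> str:
    """Gets the weather at the given location.

    Args:
        location: the place to check the weather for.
    """
    return "Sunny, 22°C"  # stub for illustration

agent = ToolCallingAgent(tools=[get_weather], llm_engine=llm_engine)
print(agent.run("What's the weather like in Paris?"))
```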
@@ -23,10 +23,20 @@ from rich.panel import Panel
 from rich.rule import Rule
 from rich.text import Text

-from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
+from .utils import (
+    console,
+    parse_code_blob,
+    parse_json_tool_call,
+    truncate_content,
+    AgentError,
+    AgentParsingError,
+    AgentExecutionError,
+    AgentGenerationError,
+    AgentMaxIterationsError,
+)
 from .types import AgentAudio, AgentImage
 from .default_tools import FinalAnswerTool
-from .llm_engines import HfApiEngine, MessageRole
+from .llm_engines import MessageRole
 from .monitoring import Monitor
 from .prompts import (
     CODE_SYSTEM_PROMPT,
@@ -52,39 +62,6 @@ from .tools import (
 )


-class AgentError(Exception):
-    """Base class for other agent-related exceptions"""
-
-    def __init__(self, message):
-        super().__init__(message)
-        self.message = message
-        console.print(f"[bold red]{message}[/bold red]")
-
-
-class AgentParsingError(AgentError):
-    """Exception raised for errors in parsing in the agent"""
-
-    pass
-
-
-class AgentExecutionError(AgentError):
-    """Exception raised for errors in execution in the agent"""
-
-    pass
-
-
-class AgentMaxIterationsError(AgentError):
-    """Exception raised for errors in execution in the agent"""
-
-    pass
-
-
-class AgentGenerationError(AgentError):
-    """Exception raised for errors in generation in the agent"""
-
-    pass
-
-
 @dataclass
 class ToolCall:
     name: str
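The `AgentError` hierarchy deleted here is not dropped from the package: the import hunk above now pulls `AgentError`, `AgentParsingError`, `AgentExecutionError`, `AgentGenerationError`, and `AgentMaxIterationsError` from `.utils`, so the classes were presumably moved there as-is. A hedged sketch of catching them from their new home (the `smolagents.utils` import path is assumed from that hunk):

```py
from smolagents.utils import AgentError, AgentMaxIterationsError

try:
    result = agent.run("Some task")  # `agent` built as in the docs examples
except AgentMaxIterationsError as e:
    # Per its docstring, raised for errors in execution in the agent
    # (the agent ran out of iterations without a final answer).
    print(f"Agent hit the iteration limit: {e.message}")
except AgentError as e:
    # Base class: parsing, execution, and generation errors all derive from it,
    # and each carries its message on `.message`.
    print(f"Agent failed: {e.message}")
```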
@@ -171,8 +148,10 @@ def format_prompt_with_managed_agents_descriptions(
     else:
         return prompt_template.replace(agent_descriptions_placeholder, "")
+
+
 YELLOW_HEX = "#ffdd00"


 class MultiStepAgent:
     """
     Agent class that solves the given task step by step, using the ReAct framework:
@@ -182,7 +161,7 @@ class MultiStepAgent:
     def __init__(
         self,
         tools: Union[List[Tool], Toolbox],
-        llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None,
+        llm_engine: Callable[[List[Dict[str, str]]], str],
         system_prompt: Optional[str] = None,
         tool_description_template: Optional[str] = None,
         max_iterations: int = 6,
@@ -195,8 +174,6 @@ class MultiStepAgent:
         monitor_metrics: bool = True,
         planning_interval: Optional[int] = None,
     ):
-        if llm_engine is None:
-            llm_engine = HfApiEngine()
         if system_prompt is None:
             system_prompt = CODE_SYSTEM_PROMPT
         if tool_parser is None:
@@ -222,14 +199,14 @@ class MultiStepAgent:
             self._toolbox = tools
             if add_base_tools:
                 self._toolbox.add_base_tools(
-                    add_python_interpreter=(self.__class__ == JsonAgent)
+                    add_python_interpreter=(self.__class__ == ToolCallingAgent)
                 )
         else:
             self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
         self._toolbox.add_tool(FinalAnswerTool())

         self.system_prompt = self.initialize_system_prompt()
-        self.prompt_messages = None
+        self.input_messages = None
         self.logs = []
         self.task = None
         self.verbose = verbose
@@ -384,23 +361,23 @@ class MultiStepAgent:
         """
         This method provides a final answer to the task, based on the logs of the agent's interactions.
         """
-        self.prompt_messages = [
+        self.input_messages = [
             {
                 "role": MessageRole.SYSTEM,
                 "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
             }
         ]
-        self.prompt_messages += self.write_inner_memory_from_logs()[1:]
-        self.prompt_messages += [
+        self.input_messages += self.write_inner_memory_from_logs()[1:]
+        self.input_messages += [
             {
                 "role": MessageRole.USER,
                 "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
             }
         ]
         try:
-            return self.llm_engine(self.prompt_messages)
+            return self.llm_engine(self.input_messages)
         except Exception as e:
-            error_msg = f"Error in generating final LLM output: {e}."
+            error_msg = f"Error in generating final LLM output:\n{e}"
             console.print(f"[bold red]{error_msg}[/bold red]")
             return error_msg
@@ -490,16 +467,20 @@ class MultiStepAgent:
         system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)

         if reset:
-            self.token_count = 0
             self.logs = []
             self.logs.append(system_prompt_step)
+            self.monitor.reset()
         else:
             if len(self.logs) > 0:
                 self.logs[0] = system_prompt_step
             else:
                 self.logs.append(system_prompt_step)

-        console.print(Group(Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task)))
+        console.print(
+            Group(
+                Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task)
+            )
+        )
         self.logs.append(TaskStep(task=self.task))

         if single_step:
@@ -534,7 +515,9 @@ class MultiStepAgent:
                 self.planning_step(
                     task, is_first_step=(iteration == 0), iteration=iteration
                 )
-            console.print(Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX))
+            console.print(
+                Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX)
+            )

             # Run one step!
             final_answer = self.step(step_log)
@@ -580,7 +563,9 @@ class MultiStepAgent:
                 self.planning_step(
                     task, is_first_step=(iteration == 0), iteration=iteration
                 )
-            console.print(Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX))
+            console.print(
+                Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX)
+            )

             # Run one step!
             final_answer = self.step(step_log)
@@ -727,138 +712,20 @@ Now begin!""",
         )


-class JsonAgent(MultiStepAgent):
-    """
-    In this agent, the tool calls will be formulated by the LLM in JSON format, then parsed and executed.
-    """
-
-    def __init__(
-        self,
-        tools: List[Tool],
-        llm_engine: Optional[Callable] = None,
-        system_prompt: Optional[str] = None,
-        grammar: Optional[Dict[str, str]] = None,
-        planning_interval: Optional[int] = None,
-        **kwargs,
-    ):
-        if llm_engine is None:
-            llm_engine = HfApiEngine()
-        if system_prompt is None:
-            system_prompt = JSON_SYSTEM_PROMPT
-        super().__init__(
-            tools=tools,
-            llm_engine=llm_engine,
-            system_prompt=system_prompt,
-            grammar=grammar,
-            planning_interval=planning_interval,
-            **kwargs,
-        )
-
-    def step(self, log_entry: ActionStep) -> Union[None, Any]:
-        """
-        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
-        Returns None if the step is not final.
-        """
-        agent_memory = self.write_inner_memory_from_logs()
-
-        self.prompt_messages = agent_memory
-
-        # Add new step in logs
-        log_entry.agent_memory = agent_memory.copy()
-
-        try:
-            additional_args = (
-                {"grammar": self.grammar} if self.grammar is not None else {}
-            )
-            llm_output = self.llm_engine(
-                self.prompt_messages,
-                stop_sequences=["<end_action>", "Observation:"],
-                **additional_args,
-            )
-            log_entry.llm_output = llm_output
-        except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
-
-        if self.verbose:
-            console.print(
-                Group(
-                    Rule(
-                        "[italic]Output message of the LLM:",
-                        align="left",
-                        style="orange",
-                    ),
-                    Text(llm_output),
-                )
-            )
-
-        # Parse
-        rationale, action = self.extract_action(
-            llm_output=llm_output, split_token="Action:"
-        )
-
-        try:
-            tool_name, tool_arguments = self.tool_parser(action)
-        except Exception as e:
-            raise AgentParsingError(f"Could not parse the given action: {e}.")
-
-        log_entry.tool_call = ToolCall(
-            tool_name=tool_name, tool_arguments=tool_arguments
-        )
-
-        # Execute
-        console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
-        console.print(
-            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
-        )
-        if tool_name == "final_answer":
-            if isinstance(tool_arguments, dict):
-                if "answer" in tool_arguments:
-                    answer = tool_arguments["answer"]
-                else:
-                    answer = tool_arguments
-            else:
-                answer = tool_arguments
-            if (
-                isinstance(answer, str) and answer in self.state.keys()
-            ):  # if the answer is a state variable, return the value
-                answer = self.state[answer]
-            log_entry.action_output = answer
-            return answer
-        else:
-            if tool_arguments is None:
-                tool_arguments = {}
-            observation = self.execute_tool_call(tool_name, tool_arguments)
-            observation_type = type(observation)
-            if observation_type in [AgentImage, AgentAudio]:
-                if observation_type == AgentImage:
-                    observation_name = "image.png"
-                elif observation_type == AgentAudio:
-                    observation_name = "audio.mp3"
-                # TODO: observation naming could allow for different names of same type
-
-                self.state[observation_name] = observation
-                updated_information = f"Stored '{observation_name}' in memory."
-            else:
-                updated_information = str(observation).strip()
-            log_entry.observations = updated_information
-            return None
-
-
 class ToolCallingAgent(MultiStepAgent):
     """
-    This agent uses JSON-like tool calls, but to the difference of JsonAgents, it leverages the underlying librarie's tool calling facilities.
+    This agent uses JSON-like tool calls, using method `llm_engine.get_tool_call` to leverage the LLM engine's tool calling capabilities.
     """

     def __init__(
         self,
         tools: List[Tool],
-        llm_engine: Optional[Callable] = None,
+        llm_engine: Callable,
         system_prompt: Optional[str] = None,
         planning_interval: Optional[int] = None,
         **kwargs,
     ):
-        if llm_engine is None:
-            llm_engine = HfApiEngine()
         if system_prompt is None:
             system_prompt = TOOL_CALLING_SYSTEM_PROMPT
         super().__init__(
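With `JsonAgent` deleted, `ToolCallingAgent` is the remaining JSON-style agent, and every agent constructor in this commit now requires an explicit `llm_engine` instead of silently defaulting to `HfApiEngine()`. A minimal migration sketch, consistent with the docs changes earlier in this commit:

```py
from smolagents import ToolCallingAgent, PythonInterpreterTool, HfApiEngine

# Before this commit:
#   agent = JsonAgent(tools=[PythonInterpreterTool()])  # engine defaulted to HfApiEngine()
# After this commit, the engine must be passed explicitly:
agent = ToolCallingAgent(
    tools=[PythonInterpreterTool()],
    llm_engine=HfApiEngine(),
)
agent.run("What is the 10th Fibonacci number?")
```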
@@ -876,17 +743,19 @@ class ToolCallingAgent(MultiStepAgent):
         """
         agent_memory = self.write_inner_memory_from_logs()

-        self.prompt_messages = agent_memory
+        self.input_messages = agent_memory

         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()

         try:
             tool_name, tool_arguments, tool_call_id = self.llm_engine.get_tool_call(
-                self.prompt_messages, available_tools=list(self.toolbox._tools.values())
+                self.input_messages,
+                available_tools=list(self.toolbox._tools.values()),
+                stop_sequences=["Observation:"],
             )
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
+            raise AgentGenerationError(f"Error in generating tool call with llm_engine:\n{e}")

         log_entry.tool_call = ToolCall(
             name=tool_name, arguments=tool_arguments, id=tool_call_id
@@ -938,7 +807,7 @@ class CodeAgent(MultiStepAgent):
     def __init__(
         self,
         tools: List[Tool],
-        llm_engine: Optional[Callable] = None,
+        llm_engine: Callable,
         system_prompt: Optional[str] = None,
         grammar: Optional[Dict[str, str]] = None,
         additional_authorized_imports: Optional[List[str]] = None,
@@ -946,8 +815,6 @@ class CodeAgent(MultiStepAgent):
         use_e2b_executor: bool = False,
         **kwargs,
     ):
-        if llm_engine is None:
-            llm_engine = HfApiEngine()
         if system_prompt is None:
             system_prompt = CODE_SYSTEM_PROMPT
         super().__init__(
@@ -994,7 +861,7 @@ class CodeAgent(MultiStepAgent):
         """
         agent_memory = self.write_inner_memory_from_logs()

-        self.prompt_messages = agent_memory.copy()
+        self.input_messages = agent_memory.copy()

         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()
@@ -1004,13 +871,13 @@ class CodeAgent(MultiStepAgent):
                 {"grammar": self.grammar} if self.grammar is not None else {}
             )
             llm_output = self.llm_engine(
-                self.prompt_messages,
+                self.input_messages,
                 stop_sequences=["<end_action>", "Observation:"],
                 **additional_args,
             )
             log_entry.llm_output = llm_output
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
+            raise AgentGenerationError(f"Error in generating llm_engine output:\n{e}")

         if self.verbose:
             console.print(
@@ -1026,17 +893,7 @@ class CodeAgent(MultiStepAgent):

         # Parse
         try:
-            rationale, raw_code_action = self.extract_action(
-                llm_output=llm_output, split_token="Code:"
-            )
-        except Exception as e:
-            console.print(
-                f"Error in extracting action, trying to parse the whole output. Error trace: {e}"
-            )
-            rationale, raw_code_action = llm_output, llm_output
-
-        try:
-            code_action = parse_code_blob(raw_code_action)
+            code_action = parse_code_blob(llm_output)
         except Exception as e:
             error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
             raise AgentParsingError(error_msg)
@@ -1048,11 +905,6 @@ class CodeAgent(MultiStepAgent):
         )

         # Execute
-        if self.verbose:
-            console.print(
-                Group(Rule("[italic]Agent thoughts", align="left"), Text(rationale))
-            )
-
         console.print(
             Panel(
                 Syntax(code_action, lexer="python", theme="github-dark"),
@@ -1148,10 +1000,8 @@ class ManagedAgent:


 __all__ = [
-    "AgentError",
     "ManagedAgent",
     "MultiStepAgent",
     "CodeAgent",
-    "JsonAgent",
     "Toolbox",
 ]
@@ -17,7 +17,7 @@
 import json
 import re
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, Optional
 from huggingface_hub import hf_hub_download, list_spaces

 from transformers.utils import is_offline_mode
@@ -139,7 +139,7 @@ class UserInputTool(Tool):
         return user_input


-class WebSearchTool(Tool):
+class DuckDuckGoSearchTool(Tool):
     name = "web_search"
     description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
    Each result has keys 'title', 'href' and 'body'."""
@@ -148,7 +148,7 @@ class WebSearchTool(Tool):
     }
     output_type = "any"

-    def forward(self, query: str) -> str:
+    def forward(self, query: str) -> list[dict[str, str]]:
         try:
             from duckduckgo_search import DDGS
         except ImportError:
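Beyond the rename, `forward` now advertises the tool's real return type. A small usage sketch (requires the `duckduckgo_search` package; the result keys are taken from the tool description above):

```py
from smolagents import DuckDuckGoSearchTool

search = DuckDuckGoSearchTool()
# Per the tool description, each result is a dict with 'title', 'href' and 'body'.
for hit in search.forward("huggingface smolagents"):
    print(hit["title"], "->", hit["href"])
```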
@@ -159,6 +159,85 @@ class WebSearchTool(Tool):
         return results


+class GoogleSearchTool(Tool):
+    name = "web_search"
+    description = """Performs a google web search for your query then returns a string of the top search results."""
+    inputs = {
+        "query": {"type": "string", "description": "The search query to perform."},
+        "filter_year": {
+            "type": "integer",
+            "description": "Optionally restrict results to a certain year",
+        },
+    }
+    output_type = "string"
+
+    def __init__(self):
+        super().__init__(self)
+        import os
+
+        self.serpapi_key = os.getenv("SERPAPI_API_KEY")
+
+    def forward(self, query: str, filter_year: Optional[int] = None) -> str:
+        import requests
+
+        if self.serpapi_key is None:
+            raise ValueError(
+                "Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables."
+            )
+
+        params = {
+            "engine": "google",
+            "q": query,
+            "api_key": self.serpapi_key,
+            "google_domain": "google.com",
+        }
+        if filter_year is not None:
+            params["tbs"] = (
+                f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
+            )
+
+        response = requests.get("https://serpapi.com/search.json", params=params)
+
+        if response.status_code == 200:
+            results = response.json()
+        else:
+            raise ValueError(response.json())
+
+        if "organic_results" not in results.keys():
+            raise Exception(
+                f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
+            )
+        if len(results["organic_results"]) == 0:
+            year_filter_message = (
+                f" with filter year={filter_year}" if filter_year is not None else ""
+            )
+            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
+
+        web_snippets = []
+        if "organic_results" in results:
+            for idx, page in enumerate(results["organic_results"]):
+                date_published = ""
+                if "date" in page:
+                    date_published = "\nDate published: " + page["date"]
+
+                source = ""
+                if "source" in page:
+                    source = "\nSource: " + page["source"]
+
+                snippet = ""
+                if "snippet" in page:
+                    snippet = "\n" + page["snippet"]
+
+                redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
+
+                redacted_version = redacted_version.replace(
+                    "Your browser can't play this video.", ""
+                )
+                web_snippets.append(redacted_version)
+
+        return "## Web Results\n" + "\n\n".join(web_snippets)
+
+
 class VisitWebpageTool(Tool):
     name = "visit_webpage"
     description = "Visits a webpage at the given url and returns its content as a markdown string."
@@ -223,7 +302,8 @@ __all__ = [
     "PythonInterpreterTool",
     "FinalAnswerTool",
     "UserInputTool",
-    "WebSearchTool",
+    "DuckDuckGoSearchTool",
+    "GoogleSearchTool",
     "VisitWebpageTool",
     "SpeechToTextTool",
 ]
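The new `GoogleSearchTool` resolves its SerpAPI key from the environment at construction time and raises at call time if the key is missing. A hedged usage sketch (assuming the tool is re-exported at the package root like the other default tools):

```py
import os
from smolagents import GoogleSearchTool

os.environ["SERPAPI_API_KEY"] = "..."  # placeholder; set your real SerpAPI key

search = GoogleSearchTool()
# filter_year maps to SerpAPI's tbs custom date range, restricting hits to one year.
print(search.forward("Guangzhou population", filter_year=2021))
```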
@@ -328,7 +328,7 @@ if __name__ == '__main__':
         if self.socket:
             try:
                 self.socket.close()
-            except:
+            except Exception:
                 pass

         if self.container:
@@ -17,13 +17,17 @@
 from copy import deepcopy
 from enum import Enum
 from typing import Dict, List, Optional, Tuple
-from transformers import AutoTokenizer, Pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList

 import logging
 import os
+import random

 from openai import OpenAI
 from huggingface_hub import InferenceClient

-from smolagents import Tool
+from .tools import Tool
+from .utils import parse_json_tool_call

 logger = logging.getLogger(__name__)
@@ -64,7 +68,7 @@ def get_json_schema(tool: Tool) -> Dict:
             "description": tool.description,
             "parameters": {
                 "type": "object",
-                "properties": tool.inputs,
+                "properties": {k: {k2: v2.replace("any", "object") for k2, v2 in v.items()} for k, v in tool.inputs.items()},
                 "required": list(tool.inputs.keys()),
             },
         },
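This `get_json_schema` change rewrites `"any"` to `"object"` inside each tool input spec, since JSON Schema as consumed by OpenAI-style tool APIs has no `"any"` type. A sketch of the effect; note that the comprehension substitutes inside every string value, descriptions included:

```py
inputs = {"query": {"type": "any", "description": "The search query to perform."}}
properties = {
    k: {k2: v2.replace("any", "object") for k2, v2 in v.items()}
    for k, v in inputs.items()
}
print(properties)
# {'query': {'type': 'object', 'description': 'The search query to perform.'}}
```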
@@ -91,7 +95,8 @@ def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:


 def get_clean_message_list(
-    message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
+    message_list: List[Dict[str, str]],
+    role_conversions: Dict[MessageRole, MessageRole] = {},
 ) -> List[Dict[str, str]]:
     """
     Subsequent messages with the same role will be concatenated to a single message.
@@ -274,25 +279,40 @@ class HfApiEngine(HfEngine):


 class TransformersEngine(HfEngine):
-    """This engine uses a pre-initialized local text-generation pipeline."""
+    """This engine initializes a model and tokenizer from the given `model_id`."""

-    def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
+    def __init__(self, model_id: Optional[str] = None):
         super().__init__()
+        default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
-            model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+            model_id = default_model_id
             logger.warning(
                 f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(model_id)
         except Exception as e:
             logger.warning(
-                f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
+                f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
             )
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                "HuggingFaceTB/SmolLM2-1.7B-Instruct"
-            )
-        self.pipeline = pipeline
+            self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
+
+    def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
+        class StopOnTokens(StoppingCriteria):
+            def __init__(self, stop_token_ids):
+                self.stop_token_ids = stop_token_ids
+
+            def __call__(self, input_ids, scores):
+                for stop_ids in self.stop_token_ids:
+                    if input_ids[0][-len(stop_ids):].tolist() == stop_ids:
+                        return True
+                return False
+
+        stop_token_ids = [self.tokenizer.encode("Observation:")[1:]]  # Remove BOS token
+        stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])
+        return stopping_criteria

     def generate(
         self,
@@ -306,45 +326,93 @@ class TransformersEngine(HfEngine):
         )

         # Get LLM output
-        if stop_sequences is not None and len(stop_sequences) > 0:
-            stop_strings = stop_sequences
-        else:
-            stop_strings = None
-
-        output = self.pipeline(
+        prompt = self.tokenizer.apply_chat_template(
             messages,
-            stop_strings=stop_strings,
-            max_length=max_tokens,
-            tokenizer=self.pipeline.tokenizer,
+            return_tensors="pt",
+            return_dict=True,
         )
+        prompt = prompt.to(self.model.device)
+        count_prompt_tokens = prompt["input_ids"].shape[1]

-        response = output[0]["generated_text"][-1]["content"]
-        self.last_input_token_count = len(
-            self.tokenizer.apply_chat_template(messages, tokenize=True)
-        )
-        self.last_output_token_count = len(self.tokenizer.encode(response))
-        return response
+        out = self.model.generate(
+            **prompt,
+            max_new_tokens=max_tokens,
+            stopping_criteria=(
+                self.make_stopping_criteria(stop_sequences) if stop_sequences else None
+            ),
+        )
+        generated_tokens = out[0, count_prompt_tokens:]
+        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+        self.last_input_token_count = count_prompt_tokens
+        self.last_output_token_count = len(generated_tokens)
+
+        if stop_sequences is not None:
+            response = remove_stop_sequences(response, stop_sequences)
+        return response
+
+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+        stop_sequences: Optional[List[str]] = None,
+        max_tokens: int = 500,
+    ) -> str:
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            return_tensors="pt",
+            return_dict=True,
+            add_generation_prompt=True,
+        )
+        prompt = prompt.to(self.model.device)
+        count_prompt_tokens = prompt["input_ids"].shape[1]
+
+        out = self.model.generate(
+            **prompt,
+            max_new_tokens=max_tokens,
+            stopping_criteria=(
+                self.make_stopping_criteria(stop_sequences) if stop_sequences else None
+            ),
+        )
+        generated_tokens = out[0, count_prompt_tokens:]
+        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+        self.last_input_token_count = count_prompt_tokens
+        self.last_output_token_count = len(generated_tokens)
+
+        if stop_sequences is not None:
+            response = remove_stop_sequences(response, stop_sequences)
+
+        tool_name, tool_input = parse_json_tool_call(response)
+        call_id = "".join(random.choices("0123456789", k=5))
+
+        return tool_name, tool_input, call_id


 class OpenAIEngine:
     def __init__(
         self,
-        model_name: Optional[str] = None,
+        model_id: Optional[str] = None,
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
     ):
         """Creates a LLM Engine that follows OpenAI format.

         Args:
-            model_name (`str`, *optional*): the model name to use.
+            model_id (`str`, *optional*): the model name to use.
             api_key (`str`, *optional*): your API key.
             base_url (`str`, *optional*): the URL to use if using a different inference service than OpenAI, for instance "https://api-inference.huggingface.co/v1/".
         """
-        if model_name is None:
-            model_name = "gpt-4o"
+        if model_id is None:
+            model_id = "gpt-4o"
         if api_key is None:
             api_key = os.getenv("OPENAI_API_KEY")
-        self.model_name = model_name
+        self.model_id = model_id
         self.client = OpenAI(
             base_url=base_url,
             api_key=api_key,
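The reworked `TransformersEngine` above now owns both the tokenizer and an `AutoModelForCausalLM`, generating directly with `model.generate` and stopping on sequences like "Observation:" via the custom `StoppingCriteria`. A construction sketch, assuming `generate` keeps the `(messages, stop_sequences, max_tokens)` shape shown in this hunk:

```py
from smolagents import TransformersEngine

# Omitting model_id falls back to "HuggingFaceTB/SmolLM2-1.7B-Instruct".
engine = TransformersEngine("HuggingFaceTB/SmolLM2-1.7B-Instruct")

messages = [{"role": "user", "content": "Say hello in one short sentence."}]
print(engine.generate(messages, stop_sequences=["Observation:"], max_tokens=50))
```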
@@ -364,7 +432,7 @@ class OpenAIEngine:
         )

         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.model_id,
             messages=messages,
             stop=stop_sequences,
             temperature=0.5,
@@ -378,13 +446,14 @@ class OpenAIEngine:
         self,
         messages: List[Dict[str, str]],
         available_tools: List[Tool],
+        stop_sequences: Optional[List[str]] = None,
     ):
         """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.model_id,
             messages=messages,
             tools=[get_json_schema(tool) for tool in available_tools],
             tool_choice="required",
||||||
|
|
||||||
|
|
||||||
class AnthropicEngine:
|
class AnthropicEngine:
|
||||||
def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):
|
def __init__(self, model_id="claude-3-5-sonnet-20240620", use_bedrock=False):
|
||||||
from anthropic import Anthropic, AnthropicBedrock
|
from anthropic import Anthropic, AnthropicBedrock
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_id = model_id
|
||||||
if use_bedrock:
|
if use_bedrock:
|
||||||
self.model_name = "anthropic.claude-3-5-sonnet-20240620-v1:0"
|
self.model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||||
self.client = AnthropicBedrock(
|
self.client = AnthropicBedrock(
|
||||||
aws_access_key=os.getenv("AWS_BEDROCK_ID"),
|
aws_access_key=os.getenv("AWS_BEDROCK_ID"),
|
||||||
aws_secret_key=os.getenv("AWS_BEDROCK_KEY"),
|
aws_secret_key=os.getenv("AWS_BEDROCK_KEY"),
|
||||||
|
@@ -454,7 +523,7 @@ class AnthropicEngine:
             messages
         )
         response = self.client.messages.create(
-            model=self.model_name,
+            model=self.model_id,
             system=system_prompt,
             messages=filtered_messages,
             stop_sequences=stop_sequences,
@@ -471,6 +540,7 @@ class AnthropicEngine:
         self,
         messages: List[Dict[str, str]],
         available_tools: List[Tool],
+        stop_sequences: Optional[List[str]] = None,
         max_tokens: int = 1500,
     ):
         """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
@@ -481,7 +551,7 @@ class AnthropicEngine:
             messages
         )
         response = self.client.messages.create(
-            model=self.model_name,
+            model=self.model_id,
             system=system_prompt,
             messages=filtered_messages,
             tools=[get_json_schema_anthropic(tool) for tool in available_tools],
@@ -493,6 +563,36 @@
         self.last_output_token_count = response.usage.output_tokens
         return tool_call.name, tool_call.input, tool_call.id

+
+class LiteLLMEngine():
+    def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
+        self.model_id = model_id
+        import os, litellm
+
+        # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
+        litellm.add_function_to_prompt = True
+
+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+        stop_sequences: Optional[List[str]] = None,
+        max_tokens: int = 1500,
+    ):
+        from litellm import completion
+
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+        response = completion(
+            model=self.model_id,
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="required",
+            max_tokens=max_tokens,
+            stop=stop_sequences,
+        )
+        tool_calls = response.choices[0].message.tool_calls[0]
+        return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id
+
+
 __all__ = [
     "MessageRole",
@@ -503,4 +603,5 @@ __all__ = [
     "HfApiEngine",
     "OpenAIEngine",
     "AnthropicEngine",
+    "LiteLLMEngine",
 ]
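At the engine level, `LiteLLMEngine.get_tool_call` returns a `(tool_name, tool_arguments, call_id)` triple, mirroring the other engines. One caveat worth sketching: LiteLLM typically hands back `function.arguments` as a JSON string rather than a parsed dict, so callers may need to decode it (the query below is illustrative):

```py
import json
from smolagents import LiteLLMEngine, DuckDuckGoSearchTool

engine = LiteLLMEngine()  # defaults to "anthropic/claude-3-5-sonnet-20240620"
messages = [{"role": "user", "content": "Find the smolagents repository."}]

name, args, call_id = engine.get_tool_call(
    messages, available_tools=[DuckDuckGoSearchTool()]
)
arguments = json.loads(args) if isinstance(args, str) else args
print(name, arguments, call_id)
```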
@@ -16,7 +16,6 @@
 # limitations under the License.
 from .utils import console
 from rich.text import Text
-from rich.console import Group


 class Monitor:
@@ -38,7 +37,9 @@ class Monitor:
     def update_metrics(self, step_log):
         step_duration = step_log.duration
         self.step_durations.append(step_duration)
-        console_outputs = f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
+        console_outputs = (
+            f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
+        )

         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
             self.total_input_token_count += (
@@ -152,8 +152,8 @@ Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).

 The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
 {
-  "action": $TOOL_NAME,
-  "action_input": $INPUT
+  "tool_name": $TOOL_NAME,
+  "tool_arguments": $INPUT
 }<end_action>

 Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
@@ -175,15 +175,15 @@ Observation: "image_1.jpg"
 Thought: I need to transform the image that I received in the previous observation to make it green.
 Action:
 {
-  "action": "image_transformer",
-  "action_input": {"image": "image_1.jpg"}
+  "tool_name": "image_transformer",
+  "tool_arguments": {"image": "image_1.jpg"}
 }<end_action>

-To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
+To provide the final answer to the task, use an action blob with "tool_name": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
 Action:
 {
-  "action": "final_answer",
-  "action_input": {"answer": "insert your final answer here"}
+  "tool_name": "final_answer",
+  "tool_arguments": {"answer": "insert your final answer here"}
 }<end_action>

@@ -194,8 +194,8 @@ Task: "Generate an image of the oldest person in this document."
 Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
 Action:
 {
-  "action": "document_qa",
-  "action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
+  "tool_name": "document_qa",
+  "tool_arguments": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
 }<end_action>
 Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
@@ -203,16 +203,16 @@ Observation: "The oldest person in the document is John Doe, a 55 year old lumbe
 Thought: I will now generate an image showcasing the oldest person.
 Action:
 {
-  "action": "image_generator",
-  "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
+  "tool_name": "image_generator",
+  "tool_arguments": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
 }<end_action>
 Observation: "image.png"

 Thought: I will now return the generated image.
 Action:
 {
-  "action": "final_answer",
-  "action_input": "image.png"
+  "tool_name": "final_answer",
+  "tool_arguments": "image.png"
 }<end_action>

 ---
@@ -221,16 +221,16 @@ Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
 Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool
 Action:
 {
-  "action": "python_interpreter",
-  "action_input": {"code": "5 + 3 + 1294.678"}
+  "tool_name": "python_interpreter",
+  "tool_arguments": {"code": "5 + 3 + 1294.678"}
 }<end_action>
 Observation: 1302.678

 Thought: Now that I know the result, I will now return it.
 Action:
 {
-  "action": "final_answer",
-  "action_input": "1302.678"
+  "tool_name": "final_answer",
+  "tool_arguments": "1302.678"
 }<end_action>

 ---
@@ -239,8 +239,8 @@ Task: "Which city has the highest population , Guangzhou or Shanghai?"
 Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
 Action:
 {
-  "action": "search",
-  "action_input": "Population Guangzhou"
+  "tool_name": "search",
+  "tool_arguments": "Population Guangzhou"
 }<end_action>
 Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']

@@ -248,16 +248,16 @@ Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'
 Thought: Now let's get the population of Shanghai using the tool 'search'.
 Action:
 {
-  "action": "search",
-  "action_input": "Population Shanghai"
+  "tool_name": "search",
+  "tool_arguments": "Population Shanghai"
 }
 Observation: '26 million (2019)'

 Thought: Now I know that Shanghai has a larger population. Let's return the result.
 Action:
 {
-  "action": "final_answer",
-  "action_input": "Shanghai"
+  "tool_name": "final_answer",
+  "tool_arguments": "Shanghai"
 }<end_action>

@@ -291,15 +291,15 @@ Observation: "image_1.jpg"

 Action:
 {
-  "action": "image_transformer",
-  "action_input": {"image": "image_1.jpg"}
+  "tool_name": "image_transformer",
+  "tool_arguments": {"image": "image_1.jpg"}
 }

-To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
+To provide the final answer to the task, use an action blob with "tool_name": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
 Action:
 {
-  "action": "final_answer",
-  "action_input": {"answer": "insert your final answer here"}
+  "tool_name": "final_answer",
+  "tool_arguments": {"answer": "insert your final answer here"}
 }

@@ -309,22 +309,22 @@ Task: "Generate an image of the oldest person in this document."

 Action:
 {
-  "action": "document_qa",
-  "action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
+  "tool_name": "document_qa",
+  "tool_arguments": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
 }
 Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."

 Action:
 {
-  "action": "image_generator",
-  "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
+  "tool_name": "image_generator",
+  "tool_arguments": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
 }
 Observation: "image.png"

 Action:
 {
-  "action": "final_answer",
-  "action_input": "image.png"
+  "tool_name": "final_answer",
+  "tool_arguments": "image.png"
 }

 ---
@@ -332,15 +332,15 @@ Task: "What is the result of the following operation: 5 + 3 + 1294.678?"

 Action:
 {
-  "action": "python_interpreter",
-  "action_input": {"code": "5 + 3 + 1294.678"}
+  "tool_name": "python_interpreter",
+  "tool_arguments": {"code": "5 + 3 + 1294.678"}
 }
 Observation: 1302.678

 Action:
 {
-  "action": "final_answer",
-  "action_input": "1302.678"
+  "tool_name": "final_answer",
+  "tool_arguments": "1302.678"
 }

 ---
@@ -348,23 +348,23 @@ Task: "Which city has the highest population , Guangzhou or Shanghai?"

 Action:
 {
-  "action": "search",
-  "action_input": "Population Guangzhou"
+  "tool_name": "search",
+  "tool_arguments": "Population Guangzhou"
 }
 Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']


 Action:
 {
-  "action": "search",
-  "action_input": "Population Shanghai"
+  "tool_name": "search",
+  "tool_arguments": "Population Shanghai"
 }
 Observation: '26 million (2019)'

 Action:
 {
-  "action": "final_answer",
-  "action_input": "Shanghai"
+  "tool_name": "final_answer",
+  "tool_arguments": "Shanghai"
 }

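All the few-shot examples above terminate the blob with `<end_action>`, which suggests the agent treats it as a stop marker when reading the model's output. A rough sketch of that extraction step, assuming the raw completion is split on the marker before JSON parsing (the exact mechanism is not shown in this diff):

```py
# Hypothetical raw completion, modeled on the few-shot examples above.
raw_output = (
    "Thought: Now I know that Shanghai has a larger population.\n"
    "Action:\n"
    '{"tool_name": "final_answer", "tool_arguments": "Shanghai"}<end_action>'
)

# Keep only the action part and drop the stop marker before parsing.
action_text = raw_output.split("Action:", 1)[1]
action_text = action_text.split("<end_action>", 1)[0].strip()
print(action_text)  # clean JSON, ready for parse_json_tool_call below
```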
@@ -47,6 +47,39 @@ BASE_BUILTIN_MODULES = [
 ]


+class AgentError(Exception):
+    """Base class for other agent-related exceptions"""
+
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
+        console.print(f"[bold red]{message}[/bold red]")
+
+
+class AgentParsingError(AgentError):
+    """Exception raised for errors in parsing in the agent"""
+
+    pass
+
+
+class AgentExecutionError(AgentError):
+    """Exception raised for errors in execution in the agent"""
+
+    pass
+
+
+class AgentMaxIterationsError(AgentError):
+    """Exception raised when the agent exceeds the maximum number of iterations"""
+
+    pass
+
+
+class AgentGenerationError(AgentError):
+    """Exception raised for errors in generation in the agent"""
+
+    pass
+
+
 def parse_json_blob(json_blob: str) -> Dict[str, str]:
     try:
         first_accolade_index = json_blob.find("{")
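Because every new exception derives from `AgentError` (which the diff also exports via `__all__` further down), callers can catch one specific failure mode or the whole family at once. A sketch of one way the hierarchy could be consumed; `run_step` and its retry policy are hypothetical:

```py
# Hypothetical step runner showing how the exception hierarchy can be used.
def run_step(step_fn):
    try:
        return step_fn()
    except AgentParsingError as e:
        # Malformed tool call: feed the message back so the LLM can retry.
        return f"Parsing failed, please fix your tool call: {e.message}"
    except AgentError as e:
        # Any other agent failure: execution, generation, max iterations...
        return f"Agent error: {e.message}"
```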
@@ -97,17 +130,25 @@ Code:
 def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
     json_blob = json_blob.replace("```json", "").replace("```", "")
     tool_call = parse_json_blob(json_blob)
-    if "action" in tool_call and "action_input" in tool_call:
-        return tool_call["action"], tool_call["action_input"]
-    elif "action" in tool_call:
-        return tool_call["action"], None
+    tool_name_key, tool_arguments_key = None, None
+    for possible_tool_name_key in ["action", "tool_name", "tool", "name", "function"]:
+        if possible_tool_name_key in tool_call:
+            tool_name_key = possible_tool_name_key
+    for possible_tool_arguments_key in [
+        "action_input",
+        "tool_arguments",
+        "tool_args",
+        "parameters",
+    ]:
+        if possible_tool_arguments_key in tool_call:
+            tool_arguments_key = possible_tool_arguments_key
+    if tool_name_key is not None:
+        if tool_arguments_key is not None:
+            return tool_call[tool_name_key], tool_call[tool_arguments_key]
-    else:
-        missing_keys = [
-            key for key in ["action", "action_input"] if key not in tool_call
-        ]
-        error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
-        console.print(f"[bold red]{error_msg}[/bold red]")
-        raise ValueError(error_msg)
+        else:
+            return tool_call[tool_name_key], None
+    error_msg = "No tool name key found in tool call!" + f" Tool call: {json_blob}"
+    raise AgentParsingError(error_msg)


 MAX_LENGTH_TRUNCATE_CONTENT = 20000
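The practical effect of this rewrite is that the parser now tolerates several key-naming conventions instead of hard-failing on anything but `action`/`action_input`. Assuming `parse_json_blob` simply decodes the outermost JSON object, all of the following would resolve to the same `(name, arguments)` pair:

```py
for blob in [
    '{"action": "search", "action_input": "Population Guangzhou"}',       # legacy keys
    '{"tool_name": "search", "tool_arguments": "Population Guangzhou"}',  # new prompt keys
    '{"name": "search", "parameters": "Population Guangzhou"}',           # OpenAI-style keys
]:
    assert parse_json_tool_call(blob) == ("search", "Population Guangzhou")
```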
@@ -266,4 +307,4 @@ def instance_to_source(instance, base_cls=None):
     return "\n".join(final_lines)


-__all__ = []
+__all__ = ["AgentError"]
@@ -25,7 +25,7 @@ from smolagents.agents import (
     AgentMaxIterationsError,
     ManagedAgent,
     CodeAgent,
-    JsonAgent,
+    ToolCallingAgent,
     Toolbox,
     ToolCall,
 )
@@ -182,7 +182,7 @@ class AgentTests(unittest.TestCase):
         assert output == "7.2904"

     def test_fake_json_agent(self):
-        agent = JsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm)
+        agent = ToolCallingAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm)
         output = agent.run("What is 2 multiplied by 3.6452?")
         assert isinstance(output, str)
         assert output == "7.2904"
@@ -212,7 +212,7 @@ Action:
         """
         return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")

-    agent = JsonAgent(
+    agent = ToolCallingAgent(
         tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
     )
     output = agent.run("Make me an image.")
@@ -262,7 +262,7 @@ Action:
         assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)

     def test_setup_agent_with_empty_toolbox(self):
-        JsonAgent(llm_engine=fake_json_llm, tools=[])
+        ToolCallingAgent(llm_engine=fake_json_llm, tools=[])

     def test_fails_max_iterations(self):
         agent = CodeAgent(

@@ -295,7 +295,7 @@ Action:

         # check that add_base_tools will not interfere with existing tools
         with pytest.raises(KeyError) as e:
-            agent = JsonAgent(
+            agent = ToolCallingAgent(
                 tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True
             )
         assert "already exists in the toolbox" in str(e)
@@ -15,7 +15,7 @@

 import unittest

-from smolagents import AgentImage, AgentError, CodeAgent, JsonAgent, stream_to_gradio
+from smolagents import AgentImage, AgentError, CodeAgent, ToolCallingAgent, stream_to_gradio


 class MonitoringTester(unittest.TestCase):
@@ -52,7 +52,7 @@ final_answer('This is the final answer.')
         def __call__(self, prompt, **kwargs):
             return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'

-        agent = JsonAgent(
+        agent = ToolCallingAgent(
             tools=[],
             llm_engine=FakeLLMEngine(),
             max_iterations=1,
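Note that this fake engine still answers with the legacy `action`/`action_input` keys even though the agent is now a `ToolCallingAgent`; the test can keep passing precisely because the key-tolerant `parse_json_tool_call` above maps the old names onto the new ones. A quick check of that assumption:

```py
# The legacy-format reply from the fake engine still parses under the new scheme.
legacy = 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
name, args = parse_json_tool_call(legacy.split("Action:", 1)[1])
assert (name, args) == ("final_answer", {"answer": "image"})
```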
@@ -131,7 +131,7 @@ final_answer('This is the final answer.')
             'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
         )

-        agent = JsonAgent(
+        agent = ToolCallingAgent(
             tools=[],
             llm_engine=dummy_llm_engine,
             max_iterations=1,