Add LiteLLM engine

This commit is contained in:
Aymeric 2024-12-24 17:34:14 +01:00
parent 762ae9cfae
commit 1e357cee7f
16 changed files with 403 additions and 339 deletions

View File

@ -48,9 +48,9 @@ pip install agents
``` ```
Then define your agent, give it the tools it needs and run it! Then define your agent, give it the tools it needs and run it!
```py ```py
from smolagents import CodeAgent, WebSearchTool from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiEngine
agent = CodeAgent(tools=[WebSearchTool()]) agent = CodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=HfApiEngine())
agent.run("What time would the world's fastest car take to travel from New York to San Francisco?") agent.run("What time would the world's fastest car take to travel from New York to San Francisco?")
``` ```
@ -68,6 +68,6 @@ Especially, since code execution can be a security concern (arbitrary code execu
## How lightweight is it? ## How lightweight is it?
We strived to keep abstractions to a strict minimum, with the main code in `agents.py` being roughly 1,000 lines of code, and still being quite complete, with several types of agents implemented: `CodeAgent` writing its actions in code snippets, `JsonAgent`, `ToolCallingAgent`... We strived to keep abstractions to a strict minimum, with the main code in `agents.py` being roughly 1,000 lines of code, and still being quite complete, with several types of agents implemented: `CodeAgent` writing its actions in code snippets, and the more classic `ToolCallingAgent` that leverages built-in tool calling methods.
Many people ask: why use a framework at all? Well, because a big part of this stuff is non-trivial. For instance, the code agent has to keep a consistent format for code throughout its system prompt, its parser, the execution. Its variables have to be properly handled throughout. So our framework handles this complexity for you. Many people ask: why use a framework at all? Well, because a big part of this stuff is non-trivial. For instance, the code agent has to keep a consistent format for code throughout its system prompt, its parser, the execution. So our framework handles this complexity for you. But of course we still encourage you to hack into the source code and use only the bits that you need, to the exclusion of everything else!

View File

@ -39,9 +39,9 @@ Here is a video overview of how that works:
![Framework of a React Agent](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/open-source-llms-as-agents/ReAct.png) ![Framework of a React Agent](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/open-source-llms-as-agents/ReAct.png)
We implement two versions of JsonAgent: We implement two versions of ToolCallingAgent:
- [`JsonAgent`] generates tool calls as a JSON in its output. - [`ToolCallingAgent`] generates tool calls as a JSON in its output.
- [`CodeAgent`] is a new type of JsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance. - [`CodeAgent`] is a new type of ToolCallingAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
> [!TIP] > [!TIP]
> We also provide an option to run agents in one-shot: just pass `single_step=True` when launching the agent, like `agent.run(your_task, single_step=True)` > We also provide an option to run agents in one-shot: just pass `single_step=True` when launching the agent, like `agent.run(your_task, single_step=True)`

View File

@ -69,19 +69,6 @@ agent.run(
) )
``` ```
You can even leave the argument `llm_engine` undefined, and an [`HfApiEngine`] will be created by default.
```python
from smolagents import CodeAgent
agent = CodeAgent(tools=[], add_base_tools=True)
agent.run(
"Could you give me the 118th number in the Fibonacci sequence?",
additional_detail="We adopt the convention where the first two numbers are 0 and 1."
)
```
Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`, and they will be baked into the prompt as text. Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`, and they will be baked into the prompt as text.
You can use this to indicate the path to local or remote files for the model to use: You can use this to indicate the path to local or remote files for the model to use:
@ -89,7 +76,7 @@ You can use this to indicate the path to local or remote files for the model to
```py ```py
from smolagents import CodeAgent, Tool, SpeechToTextTool from smolagents import CodeAgent, Tool, SpeechToTextTool
agent = CodeAgent(tools=[SpeechToTextTool()], add_base_tools=True) agent = CodeAgent(tools=[SpeechToTextTool()], llm_engine=llm_engine, add_base_tools=True)
agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3") agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
``` ```
@ -109,7 +96,7 @@ You can authorize additional imports by passing the authorized modules as a list
```py ```py
from smolagents import CodeAgent from smolagents import CodeAgent
agent = CodeAgent(tools=[], additional_authorized_imports=['requests', 'bs4']) agent = CodeAgent(tools=[], llm_engine=llm_engine, additional_authorized_imports=['requests', 'bs4'])
agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?") agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
``` ```
This gives you at the end of the agent run: This gives you at the end of the agent run:
@ -178,11 +165,11 @@ You could improve the system prompt, for example, by adding an explanation of th
For maximum flexibility, you can overwrite the whole system prompt template by passing your custom prompt as an argument to the `system_prompt` parameter. For maximum flexibility, you can overwrite the whole system prompt template by passing your custom prompt as an argument to the `system_prompt` parameter.
```python ```python
from smolagents import JsonAgent, PythonInterpreterTool, JSON_SYSTEM_PROMPT from smolagents import ToolCallingAgent, PythonInterpreterTool, JSON_SYSTEM_PROMPT
modified_prompt = JSON_SYSTEM_PROMPT modified_prompt = JSON_SYSTEM_PROMPT
agent = JsonAgent(tools=[PythonInterpreterTool()], system_prompt=modified_prompt) agent = ToolCallingAgent(tools=[PythonInterpreterTool()], llm_engine=llm_engine, system_prompt=modified_prompt)
``` ```
> [!WARNING] > [!WARNING]
@ -209,7 +196,7 @@ When the agent is initialized, the tool attributes are used to generate a tool d
Transformers comes with a default toolbox for empowering agents, that you can add to your agent upon initialization with argument `add_base_tools = True`: Transformers comes with a default toolbox for empowering agents, that you can add to your agent upon initialization with argument `add_base_tools = True`:
- **DuckDuckGo web search**: performs a web search using the DuckDuckGo search engine. - **DuckDuckGo web search**: performs a web search using the DuckDuckGo search engine.
- **Python code interpreter**: runs your LLM-generated Python code in a secure environment. This tool will only be added to [`JsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agents can already natively execute Python code. - **Python code interpreter**: runs your LLM-generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since code-based agents can already natively execute Python code.
- **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes an audio to text. - **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes an audio to text.
You can manually use a tool by calling the [`load_tool`] function and a task to perform. You can manually use a tool by calling the [`load_tool`] function and a task to perform.

View File

@ -31,7 +31,7 @@ Our agents inherit from [`MultiStepAgent`], which means they can act in multiple
We provide two types of agents, based on the main [`Agent`] class. We provide two types of agents, based on the main [`Agent`] class.
- [`CodeAgent`] is the default agent, it writes its tool calls in Python code. - [`CodeAgent`] is the default agent, it writes its tool calls in Python code.
- [`JsonAgent`] writes its tool calls in JSON. - [`ToolCallingAgent`] writes its tool calls in JSON.
### Classes of agents ### Classes of agents
@ -40,7 +40,7 @@ We provide two types of agents, based on the main [`Agent`] class.
[[autodoc]] CodeAgent [[autodoc]] CodeAgent
[[autodoc]] JsonAgent [[autodoc]] ToolCallingAgent
### ManagedAgent ### ManagedAgent

View File

@ -68,9 +68,10 @@ To set the code executor to E2B, simply pass the flag `use_e2b_executor=True` wh
Note that you should add all the tool's dependencies in `additional_authorized_imports`, so that the executor installs them. Note that you should add all the tool's dependencies in `additional_authorized_imports`, so that the executor installs them.
```py ```py
from smolagents import CodeAgent, VisitWebpageTool from smolagents import CodeAgent, VisitWebpageTool, HfApiEngine
agent = CodeAgent( agent = CodeAgent(
tools = [VisitWebpageTool()], tools = [VisitWebpageTool()],
llm_engine=HfApiEngine(),
additional_authorized_imports=["requests", "markdownify"], additional_authorized_imports=["requests", "markdownify"],
use_e2b_executor=True use_e2b_executor=True
) )

View File

@ -1,5 +1,5 @@
from smolagents import Tool, CodeAgent from smolagents import Tool, CodeAgent, HfApiEngine
from smolagents.default_tools.search import VisitWebpageTool from smolagents.default_tools import VisitWebpageTool
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -30,6 +30,7 @@ get_cat_image = GetCatImageTool()
agent = CodeAgent( agent = CodeAgent(
tools = [get_cat_image, VisitWebpageTool()], tools = [get_cat_image, VisitWebpageTool()],
llm_engine=HfApiEngine(),
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search", additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
use_e2b_executor=False use_e2b_executor=False
) )

View File

@ -1,10 +1,12 @@
from smolagents.agents import ToolCallingAgent from smolagents.agents import ToolCallingAgent
from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, TransformersEngine, LiteLLMEngine
# Choose which LLM engine to use! # Choose which LLM engine to use!
llm_engine = OpenAIEngine("gpt-4o") # llm_engine = OpenAIEngine("gpt-4o")
llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620") # llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct") # llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
# llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
llm_engine = LiteLLMEngine()
@tool @tool
def get_weather(location: str) -> str: def get_weather(location: str) -> str:

View File

@ -23,10 +23,20 @@ from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.text import Text from rich.text import Text
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content from .utils import (
console,
parse_code_blob,
parse_json_tool_call,
truncate_content,
AgentError,
AgentParsingError,
AgentExecutionError,
AgentGenerationError,
AgentMaxIterationsError,
)
from .types import AgentAudio, AgentImage from .types import AgentAudio, AgentImage
from .default_tools import FinalAnswerTool from .default_tools import FinalAnswerTool
from .llm_engines import HfApiEngine, MessageRole from .llm_engines import MessageRole
from .monitoring import Monitor from .monitoring import Monitor
from .prompts import ( from .prompts import (
CODE_SYSTEM_PROMPT, CODE_SYSTEM_PROMPT,
@ -52,39 +62,6 @@ from .tools import (
) )
class AgentError(Exception):
    """Root of the agent exception hierarchy.

    The message is kept on ``self.message`` and is echoed to the console in
    bold red as soon as the exception object is created (not when it is
    raised or caught).
    """

    def __init__(self, message):
        self.message = message
        super().__init__(message)
        # Printing at construction time makes the error visible even if a
        # caller later swallows the exception.
        console.print(f"[bold red]{message}[/bold red]")
class AgentParsingError(AgentError):
    """Agent error signalling that parsing failed."""
class AgentExecutionError(AgentError):
    """Agent error signalling that execution failed."""
class AgentMaxIterationsError(AgentError):
    """Agent error for hitting the iteration limit.

    NOTE(review): purpose inferred from the class name; the raising site is
    not visible here -- confirm against the run loop.
    """
class AgentGenerationError(AgentError):
    """Agent error signalling that generation failed."""
@dataclass @dataclass
class ToolCall: class ToolCall:
name: str name: str
@ -171,8 +148,10 @@ def format_prompt_with_managed_agents_descriptions(
else: else:
return prompt_template.replace(agent_descriptions_placeholder, "") return prompt_template.replace(agent_descriptions_placeholder, "")
YELLOW_HEX = "#ffdd00" YELLOW_HEX = "#ffdd00"
class MultiStepAgent: class MultiStepAgent:
""" """
Agent class that solves the given task step by step, using the ReAct framework: Agent class that solves the given task step by step, using the ReAct framework:
@ -182,7 +161,7 @@ class MultiStepAgent:
def __init__( def __init__(
self, self,
tools: Union[List[Tool], Toolbox], tools: Union[List[Tool], Toolbox],
llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None, llm_engine: Callable[[List[Dict[str, str]]], str],
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None, tool_description_template: Optional[str] = None,
max_iterations: int = 6, max_iterations: int = 6,
@ -195,8 +174,6 @@ class MultiStepAgent:
monitor_metrics: bool = True, monitor_metrics: bool = True,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
): ):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_parser is None: if tool_parser is None:
@ -222,14 +199,14 @@ class MultiStepAgent:
self._toolbox = tools self._toolbox = tools
if add_base_tools: if add_base_tools:
self._toolbox.add_base_tools( self._toolbox.add_base_tools(
add_python_interpreter=(self.__class__ == JsonAgent) add_python_interpreter=(self.__class__ == ToolCallingAgent)
) )
else: else:
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
self._toolbox.add_tool(FinalAnswerTool()) self._toolbox.add_tool(FinalAnswerTool())
self.system_prompt = self.initialize_system_prompt() self.system_prompt = self.initialize_system_prompt()
self.prompt_messages = None self.input_messages = None
self.logs = [] self.logs = []
self.task = None self.task = None
self.verbose = verbose self.verbose = verbose
@ -384,23 +361,23 @@ class MultiStepAgent:
""" """
This method provides a final answer to the task, based on the logs of the agent's interactions. This method provides a final answer to the task, based on the logs of the agent's interactions.
""" """
self.prompt_messages = [ self.input_messages = [
{ {
"role": MessageRole.SYSTEM, "role": MessageRole.SYSTEM,
"content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
} }
] ]
self.prompt_messages += self.write_inner_memory_from_logs()[1:] self.input_messages += self.write_inner_memory_from_logs()[1:]
self.prompt_messages += [ self.input_messages += [
{ {
"role": MessageRole.USER, "role": MessageRole.USER,
"content": f"Based on the above, please provide an answer to the following user request:\n{task}", "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
} }
] ]
try: try:
return self.llm_engine(self.prompt_messages) return self.llm_engine(self.input_messages)
except Exception as e: except Exception as e:
error_msg = f"Error in generating final LLM output: {e}." error_msg = f"Error in generating final LLM output:\n{e}"
console.print(f"[bold red]{error_msg}[/bold red]") console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg return error_msg
@ -490,16 +467,20 @@ class MultiStepAgent:
system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
if reset: if reset:
self.token_count = 0
self.logs = [] self.logs = []
self.logs.append(system_prompt_step) self.logs.append(system_prompt_step)
self.monitor.reset()
else: else:
if len(self.logs) > 0: if len(self.logs) > 0:
self.logs[0] = system_prompt_step self.logs[0] = system_prompt_step
else: else:
self.logs.append(system_prompt_step) self.logs.append(system_prompt_step)
console.print(Group(Rule("[bold]New run", characters="", style=YELLOW_HEX), Text(self.task))) console.print(
Group(
Rule("[bold]New run", characters="", style=YELLOW_HEX), Text(self.task)
)
)
self.logs.append(TaskStep(task=self.task)) self.logs.append(TaskStep(task=self.task))
if single_step: if single_step:
@ -534,7 +515,9 @@ class MultiStepAgent:
self.planning_step( self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration task, is_first_step=(iteration == 0), iteration=iteration
) )
console.print(Rule(f"[bold]Step {iteration}", characters="", style=YELLOW_HEX)) console.print(
Rule(f"[bold]Step {iteration}", characters="", style=YELLOW_HEX)
)
# Run one step! # Run one step!
final_answer = self.step(step_log) final_answer = self.step(step_log)
@ -580,7 +563,9 @@ class MultiStepAgent:
self.planning_step( self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration task, is_first_step=(iteration == 0), iteration=iteration
) )
console.print(Rule(f"[bold]Step {iteration}", characters="", style=YELLOW_HEX)) console.print(
Rule(f"[bold]Step {iteration}", characters="", style=YELLOW_HEX)
)
# Run one step! # Run one step!
final_answer = self.step(step_log) final_answer = self.step(step_log)
@ -727,138 +712,20 @@ Now begin!""",
) )
class JsonAgent(MultiStepAgent):
    """
    In this agent, the tool calls will be formulated by the LLM in JSON format, then parsed and executed.
    """

    def __init__(
        self,
        tools: List[Tool],
        llm_engine: Optional[Callable] = None,
        system_prompt: Optional[str] = None,
        grammar: Optional[Dict[str, str]] = None,
        planning_interval: Optional[int] = None,
        **kwargs,
    ):
        """Build the agent, defaulting the engine and the JSON system prompt.

        Args:
            tools: Tools the agent may call.
            llm_engine: Callable producing the LLM output; an `HfApiEngine`
                is created when omitted.
            system_prompt: Prompt override; `JSON_SYSTEM_PROMPT` when omitted.
            grammar: Optional grammar forwarded to the engine on each call.
            planning_interval: Forwarded unchanged to the parent class.
            **kwargs: Forwarded unchanged to the parent class.
        """
        if llm_engine is None:
            llm_engine = HfApiEngine()
        if system_prompt is None:
            system_prompt = JSON_SYSTEM_PROMPT
        super().__init__(
            tools=tools,
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            grammar=grammar,
            planning_interval=planning_interval,
            **kwargs,
        )

    def step(self, log_entry: ActionStep) -> Union[None, Any]:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Returns None if the step is not final.

        Raises:
            AgentGenerationError: if the engine call fails.
            AgentParsingError: if the action cannot be parsed into a tool call.
        """
        agent_memory = self.write_inner_memory_from_logs()

        self.prompt_messages = agent_memory

        # Add new step in logs
        log_entry.agent_memory = agent_memory.copy()

        try:
            additional_args = (
                {"grammar": self.grammar} if self.grammar is not None else {}
            )
            llm_output = self.llm_engine(
                self.prompt_messages,
                stop_sequences=["<end_action>", "Observation:"],
                **additional_args,
            )
            log_entry.llm_output = llm_output
        except Exception as e:
            # Chain the cause so the engine traceback is not lost.
            raise AgentGenerationError(
                f"Error in generating llm_engine output: {e}."
            ) from e

        if self.verbose:
            console.print(
                Group(
                    Rule(
                        "[italic]Output message of the LLM:",
                        align="left",
                        style="orange",
                    ),
                    Text(llm_output),
                )
            )

        # Parse
        rationale, action = self.extract_action(
            llm_output=llm_output, split_token="Action:"
        )

        try:
            tool_name, tool_arguments = self.tool_parser(action)
        except Exception as e:
            raise AgentParsingError(f"Could not parse the given action: {e}.") from e

        # ToolCall's fields are `name` / `arguments` / `id`; the previous
        # `tool_name=` / `tool_arguments=` keywords do not match them.
        # A text-parsed action carries no engine-issued call id, so a
        # synthetic one is derived from the current log length.
        log_entry.tool_call = ToolCall(
            name=tool_name,
            arguments=tool_arguments,
            id=f"call_{len(self.logs)}",
        )

        # Execute
        console.print(Rule("Agent thoughts:", align="left"), Text(rationale))
        console.print(
            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
        )
        if tool_name == "final_answer":
            # Accept both {"answer": ...} and a bare value as the final answer.
            if isinstance(tool_arguments, dict) and "answer" in tool_arguments:
                answer = tool_arguments["answer"]
            else:
                answer = tool_arguments
            if (
                isinstance(answer, str) and answer in self.state
            ):  # if the answer is a state variable, return the value
                answer = self.state[answer]
            log_entry.action_output = answer
            return answer

        if tool_arguments is None:
            tool_arguments = {}
        observation = self.execute_tool_call(tool_name, tool_arguments)
        observation_type = type(observation)
        if observation_type in [AgentImage, AgentAudio]:
            if observation_type == AgentImage:
                observation_name = "image.png"
            elif observation_type == AgentAudio:
                observation_name = "audio.mp3"
            # TODO: observation naming could allow for different names of same type

            # Media outputs are stored in state; only a pointer goes in the log.
            self.state[observation_name] = observation
            updated_information = f"Stored '{observation_name}' in memory."
        else:
            updated_information = str(observation).strip()
        log_entry.observations = updated_information
        return None
class ToolCallingAgent(MultiStepAgent): class ToolCallingAgent(MultiStepAgent):
""" """
This agent uses JSON-like tool calls, but to the difference of JsonAgents, it leverages the underlying librarie's tool calling facilities. This agent uses JSON-like tool calls, using method `llm_engine.get_tool_call` to leverage the LLM engine's tool calling capabilities.
""" """
def __init__( def __init__(
self, self,
tools: List[Tool], tools: List[Tool],
llm_engine: Optional[Callable] = None, llm_engine: Callable,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = TOOL_CALLING_SYSTEM_PROMPT system_prompt = TOOL_CALLING_SYSTEM_PROMPT
super().__init__( super().__init__(
@ -876,17 +743,19 @@ class ToolCallingAgent(MultiStepAgent):
""" """
agent_memory = self.write_inner_memory_from_logs() agent_memory = self.write_inner_memory_from_logs()
self.prompt_messages = agent_memory self.input_messages = agent_memory
# Add new step in logs # Add new step in logs
log_entry.agent_memory = agent_memory.copy() log_entry.agent_memory = agent_memory.copy()
try: try:
tool_name, tool_arguments, tool_call_id = self.llm_engine.get_tool_call( tool_name, tool_arguments, tool_call_id = self.llm_engine.get_tool_call(
self.prompt_messages, available_tools=list(self.toolbox._tools.values()) self.input_messages,
available_tools=list(self.toolbox._tools.values()),
stop_sequences=["Observation:"],
) )
except Exception as e: except Exception as e:
raise AgentGenerationError(f"Error in generating llm_engine output: {e}.") raise AgentGenerationError(f"Error in generating tool call with llm_engine:\n{e}")
log_entry.tool_call = ToolCall( log_entry.tool_call = ToolCall(
name=tool_name, arguments=tool_arguments, id=tool_call_id name=tool_name, arguments=tool_arguments, id=tool_call_id
@ -938,7 +807,7 @@ class CodeAgent(MultiStepAgent):
def __init__( def __init__(
self, self,
tools: List[Tool], tools: List[Tool],
llm_engine: Optional[Callable] = None, llm_engine: Callable,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
@ -946,8 +815,6 @@ class CodeAgent(MultiStepAgent):
use_e2b_executor: bool = False, use_e2b_executor: bool = False,
**kwargs, **kwargs,
): ):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
super().__init__( super().__init__(
@ -994,7 +861,7 @@ class CodeAgent(MultiStepAgent):
""" """
agent_memory = self.write_inner_memory_from_logs() agent_memory = self.write_inner_memory_from_logs()
self.prompt_messages = agent_memory.copy() self.input_messages = agent_memory.copy()
# Add new step in logs # Add new step in logs
log_entry.agent_memory = agent_memory.copy() log_entry.agent_memory = agent_memory.copy()
@ -1004,13 +871,13 @@ class CodeAgent(MultiStepAgent):
{"grammar": self.grammar} if self.grammar is not None else {} {"grammar": self.grammar} if self.grammar is not None else {}
) )
llm_output = self.llm_engine( llm_output = self.llm_engine(
self.prompt_messages, self.input_messages,
stop_sequences=["<end_action>", "Observation:"], stop_sequences=["<end_action>", "Observation:"],
**additional_args, **additional_args,
) )
log_entry.llm_output = llm_output log_entry.llm_output = llm_output
except Exception as e: except Exception as e:
raise AgentGenerationError(f"Error in generating llm_engine output: {e}.") raise AgentGenerationError(f"Error in generating llm_engine output:\n{e}")
if self.verbose: if self.verbose:
console.print( console.print(
@ -1026,17 +893,7 @@ class CodeAgent(MultiStepAgent):
# Parse # Parse
try: try:
rationale, raw_code_action = self.extract_action( code_action = parse_code_blob(llm_output)
llm_output=llm_output, split_token="Code:"
)
except Exception as e:
console.print(
f"Error in extracting action, trying to parse the whole output. Error trace: {e}"
)
rationale, raw_code_action = llm_output, llm_output
try:
code_action = parse_code_blob(raw_code_action)
except Exception as e: except Exception as e:
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
@ -1048,11 +905,6 @@ class CodeAgent(MultiStepAgent):
) )
# Execute # Execute
if self.verbose:
console.print(
Group(Rule("[italic]Agent thoughts", align="left"), Text(rationale))
)
console.print( console.print(
Panel( Panel(
Syntax(code_action, lexer="python", theme="github-dark"), Syntax(code_action, lexer="python", theme="github-dark"),
@ -1148,10 +1000,8 @@ class ManagedAgent:
__all__ = [ __all__ = [
"AgentError",
"ManagedAgent", "ManagedAgent",
"MultiStepAgent", "MultiStepAgent",
"CodeAgent", "CodeAgent",
"JsonAgent",
"Toolbox", "Toolbox",
] ]

View File

@ -17,7 +17,7 @@
import json import json
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict from typing import Dict, Optional
from huggingface_hub import hf_hub_download, list_spaces from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode from transformers.utils import is_offline_mode
@ -139,7 +139,7 @@ class UserInputTool(Tool):
return user_input return user_input
class WebSearchTool(Tool): class DuckDuckGoSearchTool(Tool):
name = "web_search" name = "web_search"
description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results as a list of dict elements. description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'.""" Each result has keys 'title', 'href' and 'body'."""
@ -148,7 +148,7 @@ class WebSearchTool(Tool):
} }
output_type = "any" output_type = "any"
def forward(self, query: str) -> str: def forward(self, query: str) -> list[dict[str, str]]:
try: try:
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
except ImportError: except ImportError:
@ -159,6 +159,85 @@ class WebSearchTool(Tool):
return results return results
class GoogleSearchTool(Tool):
    """Web search tool that queries Google through the SerpAPI HTTP endpoint.

    The API key is read from the ``SERPAPI_API_KEY`` environment variable when
    the tool is instantiated; a missing key is only reported once ``forward``
    is called.
    """

    name = "web_search"
    description = """Performs a google web search for your query then returns a string of the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."},
        "filter_year": {
            "type": "integer",
            "description": "Optionally restrict results to a certain year",
        },
    }
    output_type = "string"

    def __init__(self):
        # `super()` already binds the instance; passing `self` again would
        # hand it to the parent as an extra positional argument.
        super().__init__()
        import os

        self.serpapi_key = os.getenv("SERPAPI_API_KEY")

    def forward(self, query: str, filter_year: Optional[int] = None) -> str:
        """Run the search and return the organic results as a markdown string.

        Args:
            query: The search query.
            filter_year: When given, results are restricted to that calendar
                year via the ``tbs`` custom-date-range parameter.

        Returns:
            A "## Web Results" section with one entry per organic result, or
            a "No results found" message when the result list is empty.

        Raises:
            ValueError: if the API key is missing or the HTTP status is not 200.
            Exception: if the response has no 'organic_results' key.
        """
        import requests

        if self.serpapi_key is None:
            raise ValueError(
                "Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables."
            )

        params = {
            "engine": "google",
            "q": query,
            "api_key": self.serpapi_key,
            "google_domain": "google.com",
        }
        if filter_year is not None:
            params["tbs"] = (
                f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
            )

        response = requests.get("https://serpapi.com/search.json", params=params)

        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to the raw
            # text so the HTTP failure is not hidden behind a decode error.
            try:
                details = response.json()
            except ValueError:
                details = response.text
            raise ValueError(details)
        results = response.json()

        if "organic_results" not in results:
            raise Exception(
                f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
            )
        organic_results = results["organic_results"]
        if len(organic_results) == 0:
            year_filter_message = (
                f" with filter year={filter_year}" if filter_year is not None else ""
            )
            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."

        web_snippets = []
        for idx, page in enumerate(organic_results):
            # Each optional field contributes its own leading newline so that
            # absent fields leave no blank label behind.
            date_published = ""
            if "date" in page:
                date_published = "\nDate published: " + page["date"]

            source = ""
            if "source" in page:
                source = "\nSource: " + page["source"]

            snippet = ""
            if "snippet" in page:
                snippet = "\n" + page["snippet"]

            redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"

            redacted_version = redacted_version.replace(
                "Your browser can't play this video.", ""
            )
            web_snippets.append(redacted_version)

        return "## Web Results\n" + "\n\n".join(web_snippets)
class VisitWebpageTool(Tool): class VisitWebpageTool(Tool):
name = "visit_webpage" name = "visit_webpage"
description = "Visits a webpage at the given url and returns its content as a markdown string." description = "Visits a webpage at the given url and returns its content as a markdown string."
@ -223,7 +302,8 @@ __all__ = [
"PythonInterpreterTool", "PythonInterpreterTool",
"FinalAnswerTool", "FinalAnswerTool",
"UserInputTool", "UserInputTool",
"WebSearchTool", "DuckDuckGoSearchTool",
"GoogleSearchTool",
"VisitWebpageTool", "VisitWebpageTool",
"SpeechToTextTool", "SpeechToTextTool",
] ]

View File

@ -328,7 +328,7 @@ if __name__ == '__main__':
if self.socket: if self.socket:
try: try:
self.socket.close() self.socket.close()
except: except Exception:
pass pass
if self.container: if self.container:

View File

@ -17,13 +17,17 @@
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from transformers import AutoTokenizer, Pipeline from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import logging import logging
import os import os
import random
from openai import OpenAI from openai import OpenAI
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
from smolagents import Tool from .tools import Tool
from .utils import parse_json_tool_call
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,7 +68,7 @@ def get_json_schema(tool: Tool) -> Dict:
"description": tool.description, "description": tool.description,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": tool.inputs, "properties": {k: {k2: v2.replace("any", "object") for k2, v2 in v.items()} for k, v in tool.inputs.items()},
"required": list(tool.inputs.keys()), "required": list(tool.inputs.keys()),
}, },
}, },
@ -91,7 +95,8 @@ def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
def get_clean_message_list( def get_clean_message_list(
message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} message_list: List[Dict[str, str]],
role_conversions: Dict[MessageRole, MessageRole] = {},
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
""" """
Subsequent messages with the same role will be concatenated to a single message. Subsequent messages with the same role will be concatenated to a single message.
@ -274,25 +279,40 @@ class HfApiEngine(HfEngine):
class TransformersEngine(HfEngine): class TransformersEngine(HfEngine):
"""This engine uses a pre-initialized local text-generation pipeline.""" """This engine initializes a model and tokenizer from the given `model_id`."""
def __init__(self, model_id: Optional[str] = None):
    """Load the tokenizer and causal language model for `model_id`.

    Args:
        model_id (`str`, *optional*): identifier passed to `from_pretrained`
            for both the tokenizer and the model. When `None`, the default
            checkpoint "HuggingFaceTB/SmolLM2-1.7B-Instruct" is used and a
            warning is logged.

    If loading `model_id` raises, a warning is logged and the default
    checkpoint is loaded instead.
    """
    super().__init__()
    default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    if model_id is None:
        model_id = default_model_id
        logger.warning(
            f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
        )
    try:
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
    except Exception as e:
        # Name the checkpoint that is actually loaded as a fallback; the
        # message previously repeated the identifier that had just failed.
        logger.warning(
            f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
        )
        self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
        self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
    """Build stopping criteria that halt generation once a stop sequence is produced.

    Args:
        stop_sequences (`List[str]`): strings that end generation as soon as
            their token ids appear at the end of the generated ids.

    Returns:
        `StoppingCriteriaList`: a list holding one criterion that checks
        every stop sequence.
    """

    class StopOnTokens(StoppingCriteria):
        """Stop when the tail of the first sequence equals one of the stop id lists."""

        def __init__(self, stop_token_ids):
            self.stop_token_ids = stop_token_ids

        def __call__(self, input_ids, scores, **kwargs):
            # Only the first row of `input_ids` is inspected.
            return any(
                input_ids[0][-len(stop_ids):].tolist() == stop_ids
                for stop_ids in self.stop_token_ids
            )

    # Encode every requested stop sequence (previously only the hard-coded
    # "Observation:" was used). `add_special_tokens=False` avoids having to
    # guess whether the tokenizer prepends a BOS token.
    encoded_sequences = [
        self.tokenizer.encode(stop_sequence, add_special_tokens=False)
        for stop_sequence in stop_sequences
    ]
    # Drop empty encodings: `-len([])` is 0, so the tail slice would span
    # the whole sequence instead of an empty suffix.
    stop_token_ids = [ids for ids in encoded_sequences if ids]
    return StoppingCriteriaList([StopOnTokens(stop_token_ids)])
def generate( def generate(
self, self,
@ -306,45 +326,93 @@ class TransformersEngine(HfEngine):
) )
# Get LLM output # Get LLM output
if stop_sequences is not None and len(stop_sequences) > 0: prompt = self.tokenizer.apply_chat_template(
stop_strings = stop_sequences
else:
stop_strings = None
output = self.pipeline(
messages, messages,
stop_strings=stop_strings, return_tensors="pt",
max_length=max_tokens, return_dict=True,
tokenizer=self.pipeline.tokenizer, )
prompt = prompt.to(self.model.device)
count_prompt_tokens = prompt["input_ids"].shape[1]
out = self.model.generate(
**prompt,
max_new_tokens=max_tokens,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
)
generated_tokens = out[0, count_prompt_tokens:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.last_input_token_count = count_prompt_tokens
self.last_output_token_count = len(generated_tokens)
if stop_sequences is not None:
response = remove_stop_sequences(response, stop_sequences)
return response
def get_tool_call(
    self,
    messages: List[Dict[str, str]],
    available_tools: List[Tool],
    stop_sequences: Optional[List[str]] = None,
    max_tokens: int = 500,
) -> Tuple:
    """Generate a tool call with the local model and parse it from the output text.

    Args:
        messages (`List[Dict[str, str]]`): conversation so far; cleaned with
            `get_clean_message_list` using `tool_role_conversions`.
        available_tools (`List[Tool]`): tools whose JSON schemas are passed to
            the chat template.
        stop_sequences (`List[str]`, *optional*): sequences that stop
            generation and are stripped from the decoded response.
        max_tokens (`int`): maximum number of newly generated tokens.

    Returns:
        A `(tool_name, tool_input, call_id)` tuple, where the first two items
        come from `parse_json_tool_call` and `call_id` is a random string of
        five digits.
    """
    messages = get_clean_message_list(
        messages, role_conversions=tool_role_conversions
    )

    prompt = self.tokenizer.apply_chat_template(
        messages,
        tools=[get_json_schema(tool) for tool in available_tools],
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    prompt = prompt.to(self.model.device)
    count_prompt_tokens = prompt["input_ids"].shape[1]

    out = self.model.generate(
        **prompt,
        max_new_tokens=max_tokens,
        stopping_criteria=(
            self.make_stopping_criteria(stop_sequences) if stop_sequences else None
        ),
    )
    # Keep only the tokens produced after the prompt.
    generated_tokens = out[0, count_prompt_tokens:]
    response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    self.last_input_token_count = count_prompt_tokens
    self.last_output_token_count = len(generated_tokens)

    if stop_sequences is not None:
        response = remove_stop_sequences(response, stop_sequences)

    tool_name, tool_input = parse_json_tool_call(response)
    # The local model provides no call identifier, so one is made up.
    call_id = "".join(random.choices("0123456789", k=5))

    return tool_name, tool_input, call_id
class OpenAIEngine: class OpenAIEngine:
def __init__( def __init__(
self, self,
model_name: Optional[str] = None, model_id: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
): ):
"""Creates a LLM Engine that follows OpenAI format. """Creates a LLM Engine that follows OpenAI format.
Args: Args:
model_name (`str`, *optional*): the model name to use. model_id (`str`, *optional*): the model name to use.
api_key (`str`, *optional*): your API key. api_key (`str`, *optional*): your API key.
base_url (`str`, *optional*): the URL to use if using a different inference service than OpenAI, for instance "https://api-inference.huggingface.co/v1/". base_url (`str`, *optional*): the URL to use if using a different inference service than OpenAI, for instance "https://api-inference.huggingface.co/v1/".
""" """
if model_name is None: if model_id is None:
model_name = "gpt-4o" model_id = "gpt-4o"
if api_key is None: if api_key is None:
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")
self.model_name = model_name self.model_id = model_id
self.client = OpenAI( self.client = OpenAI(
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
@ -364,7 +432,7 @@ class OpenAIEngine:
) )
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_id,
messages=messages, messages=messages,
stop=stop_sequences, stop=stop_sequences,
temperature=0.5, temperature=0.5,
@ -378,13 +446,14 @@ class OpenAIEngine:
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
available_tools: List[Tool], available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
): ):
"""Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`.""" """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions messages, role_conversions=tool_role_conversions
) )
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_id,
messages=messages, messages=messages,
tools=[get_json_schema(tool) for tool in available_tools], tools=[get_json_schema(tool) for tool in available_tools],
tool_choice="required", tool_choice="required",
@ -396,12 +465,12 @@ class OpenAIEngine:
class AnthropicEngine: class AnthropicEngine:
def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False): def __init__(self, model_id="claude-3-5-sonnet-20240620", use_bedrock=False):
from anthropic import Anthropic, AnthropicBedrock from anthropic import Anthropic, AnthropicBedrock
self.model_name = model_name self.model_id = model_id
if use_bedrock: if use_bedrock:
self.model_name = "anthropic.claude-3-5-sonnet-20240620-v1:0" self.model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0"
self.client = AnthropicBedrock( self.client = AnthropicBedrock(
aws_access_key=os.getenv("AWS_BEDROCK_ID"), aws_access_key=os.getenv("AWS_BEDROCK_ID"),
aws_secret_key=os.getenv("AWS_BEDROCK_KEY"), aws_secret_key=os.getenv("AWS_BEDROCK_KEY"),
@ -454,7 +523,7 @@ class AnthropicEngine:
messages messages
) )
response = self.client.messages.create( response = self.client.messages.create(
model=self.model_name, model=self.model_id,
system=system_prompt, system=system_prompt,
messages=filtered_messages, messages=filtered_messages,
stop_sequences=stop_sequences, stop_sequences=stop_sequences,
@ -471,6 +540,7 @@ class AnthropicEngine:
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
available_tools: List[Tool], available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
max_tokens: int = 1500, max_tokens: int = 1500,
): ):
"""Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`.""" """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
@ -481,7 +551,7 @@ class AnthropicEngine:
messages messages
) )
response = self.client.messages.create( response = self.client.messages.create(
model=self.model_name, model=self.model_id,
system=system_prompt, system=system_prompt,
messages=filtered_messages, messages=filtered_messages,
tools=[get_json_schema_anthropic(tool) for tool in available_tools], tools=[get_json_schema_anthropic(tool) for tool in available_tools],
@ -493,6 +563,36 @@ class AnthropicEngine:
self.last_output_token_count = response.usage.output_tokens self.last_output_token_count = response.usage.output_tokens
return tool_call.name, tool_call.input, tool_call.id return tool_call.name, tool_call.input, tool_call.id
class LiteLLMEngine:
    """LLM engine that requests tool calls through `litellm.completion`."""

    def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
        """Store the model identifier and configure litellm.

        Args:
            model_id (`str`): model identifier forwarded to `litellm.completion`.
        """
        import litellm

        self.model_id = model_id
        # Token counters, filled in after each completion so that monitoring
        # can read them like it does for the other engines.
        self.last_input_token_count = None
        self.last_output_token_count = None
        # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
        litellm.add_function_to_prompt = True

    def get_tool_call(
        self,
        messages: List[Dict[str, str]],
        available_tools: List[Tool],
        stop_sequences: Optional[List[str]] = None,
        max_tokens: int = 1500,
    ):
        """Generate a tool call for the given message list.

        Args:
            messages (`List[Dict[str, str]]`): conversation so far; cleaned with
                `get_clean_message_list` using `tool_role_conversions`.
            available_tools (`List[Tool]`): tools offered to the model, as JSON schemas.
            stop_sequences (`List[str]`, *optional*): forwarded as `stop`.
            max_tokens (`int`): forwarded as `max_tokens`.

        Returns:
            A `(name, arguments, id)` tuple taken from the first tool call of
            the first choice.
        """
        from litellm import completion

        messages = get_clean_message_list(
            messages, role_conversions=tool_role_conversions
        )
        response = completion(
            model=self.model_id,
            messages=messages,
            tools=[get_json_schema(tool) for tool in available_tools],
            tool_choice="required",
            max_tokens=max_tokens,
            stop=stop_sequences,
        )
        tool_call = response.choices[0].message.tool_calls[0]

        # Record usage when the response carries it; a missing `usage` must
        # not break the tool call itself.
        usage = getattr(response, "usage", None)
        if usage is not None:
            self.last_input_token_count = usage.prompt_tokens
            self.last_output_token_count = usage.completion_tokens

        # NOTE(review): `function.arguments` is returned exactly as litellm
        # provides it; presumably a JSON-encoded string in the OpenAI format.
        # Confirm whether callers expect it decoded into a dict.
        return tool_call.function.name, tool_call.function.arguments, tool_call.id
__all__ = [ __all__ = [
"MessageRole", "MessageRole",
@ -503,4 +603,5 @@ __all__ = [
"HfApiEngine", "HfApiEngine",
"OpenAIEngine", "OpenAIEngine",
"AnthropicEngine", "AnthropicEngine",
"LiteLLMEngine",
] ]

View File

@ -16,7 +16,6 @@
# limitations under the License. # limitations under the License.
from .utils import console from .utils import console
from rich.text import Text from rich.text import Text
from rich.console import Group
class Monitor: class Monitor:
@ -38,7 +37,9 @@ class Monitor:
def update_metrics(self, step_log): def update_metrics(self, step_log):
step_duration = step_log.duration step_duration = step_log.duration
self.step_durations.append(step_duration) self.step_durations.append(step_duration)
console_outputs = f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds" console_outputs = (
f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
)
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += ( self.total_input_token_count += (

View File

@ -152,8 +152,8 @@ Specifically, this json should have an `action` key (name of the tool to use) an
The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB: The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
{ {
"action": $TOOL_NAME, "tool_name": $TOOL_NAME,
"action_input": $INPUT "tool_arguments": $INPUT
}<end_action> }<end_action>
Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values. Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
@ -175,15 +175,15 @@ Observation: "image_1.jpg"
Thought: I need to transform the image that I received in the previous observation to make it green. Thought: I need to transform the image that I received in the previous observation to make it green.
Action: Action:
{ {
"action": "image_transformer", "tool_name": "image_transformer",
"action_input": {"image": "image_1.jpg"} "tool_arguments": {"image": "image_1.jpg"}
}<end_action> }<end_action>
To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: To provide the final answer to the task, use an action blob with "tool_name": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": {"answer": "insert your final answer here"} "tool_arguments": {"answer": "insert your final answer here"}
}<end_action> }<end_action>
@ -194,8 +194,8 @@ Task: "Generate an image of the oldest person in this document."
Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer. Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
Action: Action:
{ {
"action": "document_qa", "tool_name": "document_qa",
"action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} "tool_arguments": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
}<end_action> }<end_action>
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
@ -203,16 +203,16 @@ Observation: "The oldest person in the document is John Doe, a 55 year old lumbe
Thought: I will now generate an image showcasing the oldest person. Thought: I will now generate an image showcasing the oldest person.
Action: Action:
{ {
"action": "image_generator", "tool_name": "image_generator",
"action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} "tool_arguments": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
}<end_action> }<end_action>
Observation: "image.png" Observation: "image.png"
Thought: I will now return the generated image. Thought: I will now return the generated image.
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": "image.png" "tool_arguments": "image.png"
}<end_action> }<end_action>
--- ---
@ -221,16 +221,16 @@ Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool
Action: Action:
{ {
"action": "python_interpreter", "tool_name": "python_interpreter",
"action_input": {"code": "5 + 3 + 1294.678"} "tool_arguments": {"code": "5 + 3 + 1294.678"}
}<end_action> }<end_action>
Observation: 1302.678 Observation: 1302.678
Thought: Now that I know the result, I will now return it. Thought: Now that I know the result, I will now return it.
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": "1302.678" "tool_arguments": "1302.678"
}<end_action> }<end_action>
--- ---
@ -239,8 +239,8 @@ Task: "Which city has the highest population , Guangzhou or Shanghai?"
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
Action: Action:
{ {
"action": "search", "tool_name": "search",
"action_input": "Population Guangzhou" "tool_arguments": "Population Guangzhou"
}<end_action> }<end_action>
Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
@ -248,16 +248,16 @@ Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'
Thought: Now let's get the population of Shanghai using the tool 'search'. Thought: Now let's get the population of Shanghai using the tool 'search'.
Action: Action:
{ {
"action": "search", "tool_name": "search",
"action_input": "Population Shanghai" "tool_arguments": "Population Shanghai"
} }
Observation: '26 million (2019)' Observation: '26 million (2019)'
Thought: Now I know that Shanghai has a larger population. Let's return the result. Thought: Now I know that Shanghai has a larger population. Let's return the result.
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": "Shanghai" "tool_arguments": "Shanghai"
}<end_action> }<end_action>
@ -291,15 +291,15 @@ Observation: "image_1.jpg"
Action: Action:
{ {
"action": "image_transformer", "tool_name": "image_transformer",
"action_input": {"image": "image_1.jpg"} "tool_arguments": {"image": "image_1.jpg"}
} }
To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: To provide the final answer to the task, use an action blob with "tool_name": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": {"answer": "insert your final answer here"} "tool_arguments": {"answer": "insert your final answer here"}
} }
@ -309,22 +309,22 @@ Task: "Generate an image of the oldest person in this document."
Action: Action:
{ {
"action": "document_qa", "tool_name": "document_qa",
"action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} "tool_arguments": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
} }
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
Action: Action:
{ {
"action": "image_generator", "tool_name": "image_generator",
"action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} "tool_arguments": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
} }
Observation: "image.png" Observation: "image.png"
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": "image.png" "tool_arguments": "image.png"
} }
--- ---
@ -332,15 +332,15 @@ Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
Action: Action:
{ {
"action": "python_interpreter", "tool_name": "python_interpreter",
"action_input": {"code": "5 + 3 + 1294.678"} "tool_arguments": {"code": "5 + 3 + 1294.678"}
} }
Observation: 1302.678 Observation: 1302.678
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": "1302.678" "tool_arguments": "1302.678"
} }
--- ---
@ -348,23 +348,23 @@ Task: "Which city has the highest population , Guangzhou or Shanghai?"
Action: Action:
{ {
"action": "search", "tool_name": "search",
"action_input": "Population Guangzhou" "tool_arguments": "Population Guangzhou"
} }
Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
Action: Action:
{ {
"action": "search", "tool_name": "search",
"action_input": "Population Shanghai" "tool_arguments": "Population Shanghai"
} }
Observation: '26 million (2019)' Observation: '26 million (2019)'
Action: Action:
{ {
"action": "final_answer", "tool_name": "final_answer",
"action_input": "Shanghai" "tool_arguments": "Shanghai"
} }

View File

@ -47,6 +47,39 @@ BASE_BUILTIN_MODULES = [
] ]
class AgentError(Exception):
    """Root of the agent exception hierarchy.

    Creating an instance stores `message` on the exception and prints it to
    the console in bold red.
    """

    def __init__(self, message):
        self.message = message
        super().__init__(message)
        console.print(f"[bold red]{message}[/bold red]")
class AgentParsingError(AgentError):
    """Raised when the agent fails to parse something, e.g. a tool call."""
class AgentExecutionError(AgentError):
    """Raised when an error occurs during execution in the agent."""
class AgentMaxIterationsError(AgentError):
    """Exception raised for errors related to the agent's maximum number of iterations"""

    # NOTE(review): the previous docstring duplicated AgentExecutionError's
    # ("errors in execution"). The wording above is inferred from the class
    # name; confirm against the code that raises this exception.
    pass
class AgentGenerationError(AgentError):
    """Raised when an error occurs during generation in the agent."""
def parse_json_blob(json_blob: str) -> Dict[str, str]: def parse_json_blob(json_blob: str) -> Dict[str, str]:
try: try:
first_accolade_index = json_blob.find("{") first_accolade_index = json_blob.find("{")
@ -97,17 +130,25 @@ Code:
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
    """Extract the tool name and tool arguments from a JSON tool-call blob.

    Markdown code fences are stripped before the blob is handed to
    `parse_json_blob`. Several key spellings are accepted for the tool name
    and for its arguments; when more than one accepted key is present, the
    one listed last takes precedence.

    Returns:
        `(tool_name, tool_arguments)`, with `tool_arguments` set to `None`
        when no arguments key is present.

    Raises:
        AgentParsingError: if none of the accepted tool-name keys is present.
    """
    cleaned_blob = json_blob.replace("```json", "").replace("```", "")
    tool_call = parse_json_blob(cleaned_blob)

    name_aliases = ("action", "tool_name", "tool", "name", "function")
    argument_aliases = ("action_input", "tool_arguments", "tool_args", "parameters")

    # Scan the aliases back to front so that the last listed alias found in
    # the blob is the one selected.
    name_key = next(
        (alias for alias in reversed(name_aliases) if alias in tool_call), None
    )
    if name_key is None:
        raise AgentParsingError(
            f"No tool name key found in tool call! Tool call: {cleaned_blob}"
        )

    arguments_key = next(
        (alias for alias in reversed(argument_aliases) if alias in tool_call), None
    )
    tool_arguments = None if arguments_key is None else tool_call[arguments_key]
    return tool_call[name_key], tool_arguments
MAX_LENGTH_TRUNCATE_CONTENT = 20000 MAX_LENGTH_TRUNCATE_CONTENT = 20000
@ -266,4 +307,4 @@ def instance_to_source(instance, base_cls=None):
return "\n".join(final_lines) return "\n".join(final_lines)
__all__ = [] __all__ = ["AgentError"]

View File

@ -25,7 +25,7 @@ from smolagents.agents import (
AgentMaxIterationsError, AgentMaxIterationsError,
ManagedAgent, ManagedAgent,
CodeAgent, CodeAgent,
JsonAgent, ToolCallingAgent,
Toolbox, Toolbox,
ToolCall, ToolCall,
) )
@ -182,7 +182,7 @@ class AgentTests(unittest.TestCase):
assert output == "7.2904" assert output == "7.2904"
def test_fake_json_agent(self): def test_fake_json_agent(self):
agent = JsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm) agent = ToolCallingAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str) assert isinstance(output, str)
assert output == "7.2904" assert output == "7.2904"
@ -212,7 +212,7 @@ Action:
""" """
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png") return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
agent = JsonAgent( agent = ToolCallingAgent(
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
) )
output = agent.run("Make me an image.") output = agent.run("Make me an image.")
@ -262,7 +262,7 @@ Action:
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self):
    """Constructing a ToolCallingAgent with no tools must not raise."""
    empty_toolbox = []
    ToolCallingAgent(llm_engine=fake_json_llm, tools=empty_toolbox)
def test_fails_max_iterations(self): def test_fails_max_iterations(self):
agent = CodeAgent( agent = CodeAgent(
@ -295,7 +295,7 @@ Action:
# check that add_base_tools will not interfere with existing tools # check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e: with pytest.raises(KeyError) as e:
agent = JsonAgent( agent = ToolCallingAgent(
tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True
) )
assert "already exists in the toolbox" in str(e) assert "already exists in the toolbox" in str(e)

View File

@ -15,7 +15,7 @@
import unittest import unittest
from smolagents import AgentImage, AgentError, CodeAgent, JsonAgent, stream_to_gradio from smolagents import AgentImage, AgentError, CodeAgent, ToolCallingAgent, stream_to_gradio
class MonitoringTester(unittest.TestCase): class MonitoringTester(unittest.TestCase):
@ -52,7 +52,7 @@ final_answer('This is the final answer.')
def __call__(self, prompt, **kwargs): def __call__(self, prompt, **kwargs):
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
agent = JsonAgent( agent = ToolCallingAgent(
tools=[], tools=[],
llm_engine=FakeLLMEngine(), llm_engine=FakeLLMEngine(),
max_iterations=1, max_iterations=1,
@ -131,7 +131,7 @@ final_answer('This is the final answer.')
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
) )
agent = JsonAgent( agent = ToolCallingAgent(
tools=[], tools=[],
llm_engine=dummy_llm_engine, llm_engine=dummy_llm_engine,
max_iterations=1, max_iterations=1,