Change name 'llm_engine' to 'model'

Aymeric 2024-12-25 21:51:43 +01:00
parent bfbe704793
commit 8005d6f21d
16 changed files with 200 additions and 215 deletions

View File

@ -34,7 +34,7 @@ This library offers:
**Simplicity**: the logic for agents fits in ~1,000 lines of code. We kept abstractions to their minimal shape above raw code!
🌐 **Support for any LLM**: it supports models hosted on the Hub loaded in their `transformers` version or through our inference API, but also models from OpenAI, Anthropic... it's really easy to power an agent with any LLM.
🌐 **Support for any LLM**: it supports models hosted on the Hub loaded in their `transformers` version or through our inference API, but also models from OpenAI, Anthropic, and many more through our LiteLLM integration.
🧑‍💻 **First-class support for Code Agents**, i.e. agents that write their actions in code (as opposed to "agents being used to write code"), [read more here](tutorials/secure_code_execution).
@ -48,9 +48,9 @@ pip install agents
```
Then define your agent, give it the tools it needs and run it!
```py
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiEngine
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel
agent = CodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=HfApiEngine())
agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=HfApiModel())
agent.run("What time would the world's fastest car take to travel from New York to San Francisco?")
```

View File

@ -124,14 +124,14 @@ Now let us create an agent that leverages this tool.
We use the CodeAgent, which is transformers.agents' main agent class: an agent that writes its actions in code and can iterate on previous outputs according to the ReAct framework.
The llm_engine is the LLM that powers the agent system. HfEngine allows you to call LLMs using HFs Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
The model is the LLM that powers the agent system. HfApiModel allows you to call LLMs using HF's Inference API, either via a serverless or a dedicated endpoint, but you could also use any proprietary API.
```py
from smolagents import CodeAgent, HfApiEngine
from smolagents import CodeAgent, HfApiModel
agent = CodeAgent(
tools=[sql_engine],
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct"),
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")
```
@ -187,7 +187,7 @@ sql_engine.description = updated_description
agent = CodeAgent(
tools=[sql_engine],
llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
model=HfApiModel("Qwen/Qwen2.5-72B-Instruct"),
)
agent.run("Which waiter got more total money from tips?")

View File

@ -26,7 +26,7 @@ To initialize a minimal agent, you need at least these two arguments:
- An LLM to power your agent - because the agent is different from a simple LLM, it is a system that uses an LLM as its engine.
- A list of tools from which the agent can pick tools to execute
For defining your llm, you can make a `llm_engine` method which accepts a list of [messages](./chat_templating) and returns text. This callable also needs to accept a `stop_sequences` argument that indicates when to stop generating.
To define your LLM, you can make a `custom_model` callable which accepts a list of [messages](./chat_templating) and returns text. This callable also needs to accept a `stop_sequences` argument that indicates when to stop generating.
```python
from huggingface_hub import login, InferenceClient
@ -37,32 +37,32 @@ model_id = "Qwen/Qwen2.5-72B-Instruct"
client = InferenceClient(model=model_id)
def llm_engine(messages, stop_sequences=["Task"]) -> str:
def custom_model(messages, stop_sequences=["Task"]) -> str:
response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1000)
answer = response.choices[0].message.content
return answer
```
You could use any `llm_engine` method as long as:
You can use any `custom_model` callable as long as:
1. it follows the [messages format](./chat_templating) (`List[Dict[str, str]]`) for its input `messages`, and it returns a `str`.
2. it stops generating outputs at the sequences passed in the argument `stop_sequences`
Additionally, `llm_engine` can also take a `grammar` argument. In the case where you specify a `grammar` upon agent initialization, this argument will be passed to the calls to llm_engine, with the `grammar` that you defined upon initialization, to allow [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) in order to force properly-formatted agent outputs.
Additionally, `custom_model` can also take a `grammar` argument. If you specify a `grammar` upon agent initialization, it will be passed along to every call to the model to allow [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) and force properly formatted agent outputs.
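For instance, here is a hedged sketch extending the `custom_model` above so that it also accepts `grammar`. It assumes a recent `huggingface_hub` whose `chat_completion` exposes a `response_format` parameter; adapt the forwarding to whatever your backend expects.
```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="Qwen/Qwen2.5-72B-Instruct")

def custom_model(messages, stop_sequences=["Task"], grammar=None) -> str:
    # Forward the agent's grammar, if any, to the backend for constrained generation.
    # Passing it as `response_format` is an assumption about the client; adjust as needed.
    extra = {"response_format": grammar} if grammar is not None else {}
    response = client.chat_completion(
        messages, stop=stop_sequences, max_tokens=1000, **extra
    )
    return response.choices[0].message.content
```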
For convenience, we provide pre-built classes for your llm engine:
- [`TransformersEngine`] takes a pre-initialized `transformers` pipeline to run inference on your local machine using `transformers`.
- [`HfApiEngine`] leverages a `huggingface_hub.InferenceClient` under the hood.
- We also provide [`LiteLLMEngine`], which lets you call 100+ different models through [LiteLLM](https://docs.litellm.ai/)!
For convenience, we provide pre-built classes for your model:
- [`TransformersModel`] takes a pre-initialized `transformers` pipeline to run inference on your local machine using `transformers`.
- [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood.
- We also provide [`LiteLLMModel`], which lets you call 100+ different models through [LiteLLM](https://docs.litellm.ai/)!
You will also need a `tools` argument, which accepts a list of `Tool` objects; it can be an empty list. You can also add the default toolbox on top of your `tools` list by setting the optional argument `add_base_tools=True`.
Once you have these two arguments, `tools` and `llm_engine`, you can create an agent and run it.
Once you have these two arguments, `tools` and `model`, you can create an agent and run it.
```python
from smolagents import CodeAgent, HfApiEngine
from smolagents import CodeAgent, HfApiModel
llm_engine = HfApiEngine(model=model_id)
agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
model = HfApiModel(model=model_id)
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
agent.run(
"Could you give me the 118th number in the Fibonacci sequence?",
@ -76,7 +76,7 @@ You can use this to indicate the path to local or remote files for the model to
```py
from smolagents import CodeAgent, Tool, SpeechToTextTool
agent = CodeAgent(tools=[SpeechToTextTool()], llm_engine=llm_engine, add_base_tools=True)
agent = CodeAgent(tools=[SpeechToTextTool()], model=model, add_base_tools=True)
agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
```
@ -96,7 +96,7 @@ You can authorize additional imports by passing the authorized modules as a list
```py
from smolagents import CodeAgent
agent = CodeAgent(tools=[], llm_engine=llm_engine, additional_authorized_imports=['requests', 'bs4'])
agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['requests', 'bs4'])
agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
```
This gives you at the end of the agent run:
@ -165,11 +165,11 @@ You could improve the system prompt, for example, by adding an explanation of th
For maximum flexibility, you can overwrite the whole system prompt template by passing your custom prompt as an argument to the `system_prompt` parameter.
```python
from smolagents import ToolCallingAgent, PythonInterpreterTool, JSON_SYSTEM_PROMPT
from smolagents import ToolCallingAgent, PythonInterpreterTool, TOOL_CALLING_SYSTEM_PROMPT
modified_prompt = JSON_SYSTEM_PROMPT
modified_prompt = TOOL_CALLING_SYSTEM_PROMPT
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], llm_engine=llm_engine, system_prompt=modified_prompt)
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=model, system_prompt=modified_prompt)
```
> [!WARNING]
@ -255,7 +255,7 @@ All these will be automatically baked into the agent's system prompt upon initia
Then you can directly initialize your agent:
```py
from smolagents import CodeAgent
agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
agent = CodeAgent(tools=[model_download_tool], model=model)
agent.run(
"Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
)
@ -287,11 +287,11 @@ To do so, encapsulate the agent in a [`ManagedAgent`] object. This object needs
Here's an example of making an agent that manages a specific web search agent using our [`DuckDuckGoSearchTool`]:
```py
from smolagents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent
from smolagents import CodeAgent, HfApiModel, DuckDuckGoSearchTool, ManagedAgent
llm_engine = HfApiEngine()
model = HfApiModel()
web_agent = CodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=llm_engine)
web_agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=model)
managed_web_agent = ManagedAgent(
agent=web_agent,
@ -300,7 +300,7 @@ managed_web_agent = ManagedAgent(
)
manager_agent = CodeAgent(
tools=[], llm_engine=llm_engine, managed_agents=[managed_web_agent]
tools=[], model=model, managed_agents=[managed_web_agent]
)
manager_agent.run("Who is the CEO of Hugging Face?")
@ -318,17 +318,17 @@ You can use `GradioUI` to interactively submit tasks to your agent and observe i
from smolagents import (
load_tool,
CodeAgent,
HfApiEngine,
HfApiModel,
GradioUI
)
# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image")
llm_engine = HfApiEngine(model_id)
model = HfApiModel(model_id)
# Initialize the agent with the image generation tool
agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
agent = CodeAgent(tools=[image_generation_tool], model=model)
GradioUI(agent).launch()
```

View File

@ -52,19 +52,19 @@ We provide two types of agents, based on the main [`Agent`] class.
[[autodoc]] stream_to_gradio
## Engines
## Models
You're free to create and use your own engines with the Agents framework.
These engines must meet the following specification (see the sketch after the list below):
1. Follow the [messages format](../chat_templating.md) for its input (`List[Dict[str, str]]`) and return a string.
2. Stop generating outputs *before* the sequences passed in the argument `stop_sequences`
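As a hedged illustration, a minimal callable meeting this specification could look like the sketch below; it wraps an OpenAI client purely as an example (the client and model name are illustrative assumptions, not part of this library).
```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def my_engine(messages, stop_sequences=None) -> str:
    # `messages` follows the chat format: List[Dict[str, str]] with "role"/"content" keys.
    response = client.chat.completions.create(
        model="gpt-4o-mini",     # illustrative model choice
        messages=messages,
        stop=stop_sequences,     # generation stops before these sequences appear
        max_tokens=1000,
    )
    return response.choices[0].message.content
```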
### TransformersEngine
### TransformersModel
For convenience, we have added a `TransformersEngine` that implements the points above, taking a pre-initialized `Pipeline` as input.
For convenience, we have added a `TransformersModel` that implements the points above, taking a pre-initialized `Pipeline` as input.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TransformersEngine
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TransformersModel
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -72,18 +72,18 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
engine = TransformersEngine(pipe)
engine = TransformersModel(pipe)
engine([{"role": "user", "content": "Ok!"}], stop_sequences=["great"])
```
[[autodoc]] TransformersEngine
[[autodoc]] TransformersModel
### HfApiEngine
### HfApiModel
The `HfApiEngine` is an engine that wraps an [HF Inference API](https://huggingface.co/docs/api-inference/index) client for the execution of the LLM.
The `HfApiModel` is an engine that wraps an [HF Inference API](https://huggingface.co/docs/api-inference/index) client to execute the LLM.
```python
from transformers import HfApiEngine
from transformers import HfApiModel
messages = [
{"role": "user", "content": "Hello, how are you?"},
@ -91,7 +91,7 @@ messages = [
{"role": "user", "content": "No need to help, take it easy."},
]
HfApiEngine()(messages)
HfApiModel()(messages)
```
[[autodoc]] HfApiEngine
[[autodoc]] HfApiModel

View File

@ -166,7 +166,7 @@ Better ways to guide your LLM engine are:
We provide a supplementary planning step that an agent can run regularly in between normal action steps. In this step, there is no tool call: the LLM is simply asked to update a list of facts it knows and to reflect on what steps it should take next based on those facts.
```py
from smolagents import load_tool, CodeAgent, HfApiEngine, DuckDuckGoSearchTool
from smolagents import load_tool, CodeAgent, HfApiModel, DuckDuckGoSearchTool
from dotenv import load_dotenv
load_dotenv()
@ -178,7 +178,7 @@ search_tool = DuckDuckGoSearchTool()
agent = CodeAgent(
tools=[search_tool],
llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
model=HfApiModel("Qwen/Qwen2.5-72B-Instruct"),
planning_interval=3 # This is where you activate planning!
)

View File

@ -68,10 +68,10 @@ To set the code executor to E2B, simply pass the flag `use_e2b_executor=True` wh
Note that you should add all the tool's dependencies to `additional_authorized_imports`, so that the executor installs them.
```py
from smolagents import CodeAgent, VisitWebpageTool, HfApiEngine
from smolagents import CodeAgent, VisitWebpageTool, HfApiModel
agent = CodeAgent(
tools = [VisitWebpageTool()],
llm_engine=HfApiEngine(),
model=HfApiModel(),
additional_authorized_imports=["requests", "markdownify"],
use_e2b_executor=True
)

View File

@ -114,10 +114,10 @@ And voilà, here's your image! 🏖️
Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.
```python
from smolagents import CodeAgent, HfApiEngine
from smolagents import CodeAgent, HfApiModel
llm_engine = HfApiEngine("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[image_generation_tool], model=model)
agent.run(
"Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
@ -169,7 +169,7 @@ from langchain.agents import load_tools
search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
agent = CodeAgent(tools=[search_tool], llm_engine=llm_engine)
agent = CodeAgent(tools=[search_tool], model=model)
agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?")
```
@ -181,11 +181,11 @@ You can manage an agent's toolbox by adding or replacing a tool.
Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
```python
from smolagents import HfApiEngine
from smolagents import HfApiModel
llm_engine = HfApiEngine("Qwen/Qwen2.5-Coder-32B-Instruct")
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
agent.toolbox.add_tool(model_download_tool)
```
Now we can leverage the new tool:
@ -218,7 +218,7 @@ image_tool_collection = ToolCollection(
collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
token="<YOUR_HUGGINGFACEHUB_API_TOKEN>"
)
agent = CodeAgent(tools=[*image_tool_collection.tools], llm_engine=llm_engine, add_base_tools=True)
agent = CodeAgent(tools=[*image_tool_collection.tools], model=model, add_base_tools=True)
agent.run("Please draw me a picture of rivers and lakes.")
```

View File

@ -1,4 +1,4 @@
from smolagents import Tool, CodeAgent, HfApiEngine
from smolagents import Tool, CodeAgent, HfApiModel
from smolagents.default_tools import VisitWebpageTool
from dotenv import load_dotenv
@ -30,7 +30,7 @@ get_cat_image = GetCatImageTool()
agent = CodeAgent(
tools = [get_cat_image, VisitWebpageTool()],
llm_engine=HfApiEngine(),
model=HfApiModel(),
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
use_e2b_executor=False
)

View File

@ -66,10 +66,10 @@ def sql_engine(query: str) -> str:
output += "\n" + str(row)
return output
from smolagents import CodeAgent, HfApiEngine
from smolagents import CodeAgent, HfApiModel
agent = CodeAgent(
tools=[sql_engine],
llm_engine=HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct"),
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")

View File

@ -1,10 +1,10 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, HfApiEngine, TransformersEngine, LiteLLMEngine
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel
# Choose which LLM engine to use!
llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
llm_engine = LiteLLMEngine("gpt-4o")
model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")
model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct")
model = LiteLLMModel("gpt-4o")
@tool
def get_weather(location: str) -> str:
@ -17,6 +17,6 @@ def get_weather(location: str) -> str:
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
agent = ToolCallingAgent(tools=[get_weather], llm_engine=llm_engine)
agent = ToolCallingAgent(tools=[get_weather], model=model)
print(agent.run("What's the weather like in Paris?"))

View File

@ -26,7 +26,7 @@ if TYPE_CHECKING:
from .agents import *
from .default_tools import *
from .gradio_ui import *
from .llm_engines import *
from .models import *
from .local_python_executor import *
from .e2b_executor import *
from .monitoring import *

View File

@ -36,7 +36,7 @@ from .utils import (
)
from .types import AgentAudio, AgentImage
from .default_tools import FinalAnswerTool
from .llm_engines import MessageRole
from .models import MessageRole
from .monitoring import Monitor
from .prompts import (
CODE_SYSTEM_PROMPT,
@ -160,7 +160,7 @@ class MultiStepAgent:
def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable[[List[Dict[str, str]]], str],
model: Callable[[List[Dict[str, str]]], str],
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
max_iterations: int = 6,
@ -177,7 +177,7 @@ class MultiStepAgent:
if tool_parser is None:
tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__
self.llm_engine = llm_engine
self.model = model
self.system_prompt_template = system_prompt
self.tool_description_template = (
tool_description_template
@ -208,7 +208,7 @@ class MultiStepAgent:
self.logs = []
self.task = None
self.verbose = verbose
self.monitor = Monitor(self.llm_engine)
self.monitor = Monitor(self.model)
self.step_callbacks = step_callbacks if step_callbacks is not None else []
self.step_callbacks.append(self.monitor.update_metrics)
@ -367,7 +367,7 @@ class MultiStepAgent:
}
]
try:
return self.llm_engine(self.input_messages)
return self.model(self.input_messages)
except Exception as e:
error_msg = f"Error in generating final LLM output:\n{e}"
console.print(f"[bold red]{error_msg}[/bold red]")
@ -473,13 +473,15 @@ class MultiStepAgent:
# Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task)
# )
# )
console.print(Panel(
f"\n[bold]{task.strip()}\n",
title="[bold]New run",
subtitle=f"{type(self.llm_engine).__name__} - {(self.llm_engine.model_id if hasattr(self.llm_engine, "model_id") else "")}",
border_style=YELLOW_HEX,
subtitle_align="left",
))
console.print(
Panel(
f"\n[bold]{task.strip()}\n",
title="[bold]New run",
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}",
border_style=YELLOW_HEX,
subtitle_align="left",
)
)
self.logs.append(TaskStep(task=self.task))
@ -616,7 +618,7 @@ class MultiStepAgent:
Now begin!""",
}
answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
answer_facts = self.model([message_prompt_facts, message_prompt_task])
message_system_prompt_plan = {
"role": MessageRole.SYSTEM,
@ -635,7 +637,7 @@ Now begin!""",
answer_facts=answer_facts,
),
}
answer_plan = self.llm_engine(
answer_plan = self.model(
[message_system_prompt_plan, message_user_prompt_plan],
stop_sequences=["<end_plan>"],
)
@ -668,7 +670,7 @@ Now begin!""",
"role": MessageRole.USER,
"content": USER_PROMPT_FACTS_UPDATE,
}
facts_update = self.llm_engine(
facts_update = self.model(
[facts_update_system_prompt] + agent_memory + [facts_update_message]
)
@ -691,7 +693,7 @@ Now begin!""",
remaining_steps=(self.max_iterations - iteration),
),
}
plan_update = self.llm_engine(
plan_update = self.model(
[plan_update_message] + agent_memory + [plan_update_message_user],
stop_sequences=["<end_plan>"],
)
@ -714,13 +716,13 @@ Now begin!""",
class ToolCallingAgent(MultiStepAgent):
"""
This agent uses JSON-like tool calls, using method `llm_engine.get_tool_call` to leverage the LLM engine's tool calling capabilities.
This agent uses JSON-like tool calls, using method `model.get_tool_call` to leverage the LLM engine's tool calling capabilities.
"""
def __init__(
self,
tools: List[Tool],
llm_engine: Callable,
model: Callable,
system_prompt: Optional[str] = None,
planning_interval: Optional[int] = None,
**kwargs,
@ -729,7 +731,7 @@ class ToolCallingAgent(MultiStepAgent):
system_prompt = TOOL_CALLING_SYSTEM_PROMPT
super().__init__(
tools=tools,
llm_engine=llm_engine,
model=model,
system_prompt=system_prompt,
planning_interval=planning_interval,
**kwargs,
@ -748,14 +750,14 @@ class ToolCallingAgent(MultiStepAgent):
log_entry.agent_memory = agent_memory.copy()
try:
tool_name, tool_arguments, tool_call_id = self.llm_engine.get_tool_call(
tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
self.input_messages,
available_tools=list(self.toolbox._tools.values()),
stop_sequences=["Observation:"],
)
except Exception as e:
raise AgentGenerationError(
f"Error in generating tool call with llm_engine:\n{e}"
f"Error in generating tool call with model:\n{e}"
)
log_entry.tool_call = ToolCall(
@ -808,7 +810,7 @@ class CodeAgent(MultiStepAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable,
model: Callable,
system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
@ -820,7 +822,7 @@ class CodeAgent(MultiStepAgent):
system_prompt = CODE_SYSTEM_PROMPT
super().__init__(
tools=tools,
llm_engine=llm_engine,
model=model,
system_prompt=system_prompt,
grammar=grammar,
planning_interval=planning_interval,
@ -871,7 +873,7 @@ class CodeAgent(MultiStepAgent):
additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
)
llm_output = self.llm_engine(
llm_output = self.model(
self.input_messages,
stop_sequences=["<end_action>", "Observation:"],
**additional_args,
@ -879,7 +881,18 @@ class CodeAgent(MultiStepAgent):
log_entry.llm_output = llm_output
except Exception as e:
console.print_exception()
raise AgentGenerationError(f"Error in generating llm_engine output:\n{e}")
raise AgentGenerationError(f"Error in generating model output:\n{e}")
from rich.live import Live
from rich.markdown import Markdown
import time
with Live(console=console, vertical_overflow="visible") as live:
message = ""
for i in range(100):
time.sleep(0.02)
message += str(i)
live.update(Markdown(message))
if self.verbose:
console.print(
@ -908,8 +921,15 @@ class CodeAgent(MultiStepAgent):
)
# Execute
console.print(Panel(
Syntax(code_action, lexer="python", theme="monokai", word_wrap=True, line_numbers=True),
console.print(
Panel(
Syntax(
code_action,
lexer="python",
theme="monokai",
word_wrap=True,
line_numbers=True,
),
title="[bold]Executing this code:",
title_align="left",
)
@ -921,7 +941,10 @@ class CodeAgent(MultiStepAgent):
)
execution_outputs_console = []
if len(execution_logs) > 0:
execution_outputs_console += [Text("Execution logs:", style="bold"), Text(execution_logs)]
execution_outputs_console += [
Text("Execution logs:", style="bold"),
Text(execution_logs),
]
observation = "Execution logs:\n" + execution_logs
except Exception as e:
console.print_exception()
@ -929,7 +952,7 @@ class CodeAgent(MultiStepAgent):
if "'dict' object has no attribute 'read'" in str(e):
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
raise AgentExecutionError(error_msg)
truncated_output = truncate_content(str(output))
observation += "Last output from code snippet:\n" + truncated_output
log_entry.observations = observation
@ -940,17 +963,15 @@ class CodeAgent(MultiStepAgent):
is_final_answer = True
break
execution_outputs_console+= [
execution_outputs_console += [
Text(
f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",
style=(f"bold {YELLOW_HEX}" if is_final_answer else "")
style=(f"bold {YELLOW_HEX}" if is_final_answer else ""),
),
]
console.print(
Group(*execution_outputs_console)
)
console.print(Group(*execution_outputs_console))
log_entry.action_output = output
return (output if is_final_answer else None)
return output if is_final_answer else None
class ManagedAgent:

View File

@ -125,7 +125,7 @@ def get_clean_message_list(
return final_message_list
class HfEngine:
class HfModel:
def __init__(self):
self.last_input_token_count = None
self.last_output_token_count = None
@ -177,7 +177,7 @@ class HfEngine:
return remove_stop_sequences(response, stop_sequences)
class HfApiEngine(HfEngine):
class HfApiModel(HfModel):
"""A class to interact with Hugging Face's Inference API for language model interaction.
This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used either in serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
@ -200,7 +200,7 @@ class HfApiEngine(HfEngine):
Example:
```python
>>> engine = HfApiEngine(
>>> engine = HfApiModel(
... model="Qwen/Qwen2.5-Coder-32B-Instruct",
... token="your_hf_token_here",
... max_tokens=2000
@ -274,7 +274,7 @@ class HfApiEngine(HfEngine):
return tool_call.function.name, tool_call.function.arguments, tool_call.id
class TransformersEngine(HfEngine):
class TransformersModel(HfModel):
"""This engine initializes a model and tokenizer from the given `model_id`."""
def __init__(self, model_id: Optional[str] = None):
@ -391,7 +391,7 @@ class TransformersEngine(HfEngine):
return tool_name, tool_input, call_id
class LiteLLMEngine:
class LiteLLMModel:
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
@ -448,8 +448,8 @@ __all__ = [
"MessageRole",
"tool_role_conversions",
"get_clean_message_list",
"HfEngine",
"TransformersEngine",
"HfApiEngine",
"LiteLLMEngine",
"HfModel",
"TransformersModel",
"HfApiModel",
"LiteLLMModel",
]

View File

@ -19,11 +19,11 @@ from rich.text import Text
class Monitor:
def __init__(self, tracked_llm_engine):
def __init__(self, tracked_model):
self.step_durations = []
self.tracked_llm_engine = tracked_llm_engine
self.tracked_model = tracked_model
if (
getattr(self.tracked_llm_engine, "last_input_token_count", "Not found")
getattr(self.tracked_model, "last_input_token_count", "Not found")
!= "Not found"
):
self.total_input_token_count = 0
@ -41,13 +41,9 @@ class Monitor:
f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
)
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += (
self.tracked_llm_engine.last_input_token_count
)
self.total_output_token_count += (
self.tracked_llm_engine.last_output_token_count
)
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_model.last_input_token_count
self.total_output_token_count += self.tracked_model.last_output_token_count
console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
console_outputs += "]"
console.print(Text(console_outputs, style="dim"))

View File

@ -39,53 +39,32 @@ def get_new_path(suffix="") -> str:
return os.path.join(directory, str(uuid.uuid4()) + suffix)
def fake_json_llm(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Action:
{
"action": "python_interpreter",
"action_input": {"code": "2*3.6452"}
}
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Action:
{
"action": "final_answer",
"action_input": {"answer": "7.2904"}
}
"""
class FakeToolCallModel:
def get_tool_call(
self, messages, available_tools, stop_sequences=None, grammar=None
):
if len(messages) < 3:
return "python_interpreter", {"code": "2*3.6452"}, "call_0"
else:
return "final_answer", {"answer": "7.2904"}, "call_1"
def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
class FakeToolCallModelImage:
def get_tool_call(
self, messages, available_tools, stop_sequences=None, grammar=None
):
if len(messages) < 3:
return (
"fake_image_generation_tool",
{"prompt": "An image of a cat"},
"call_0",
)
if "special_marker" not in prompt:
return """
Thought: I should generate an image. special_marker
Action:
{
"action": "fake_image_generation_tool",
"action_input": {"prompt": "An image of a cat"}
}
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Action:
{
"action": "final_answer",
"action_input": "image.png"
}
"""
else: # We're at step 2
return "final_answer", "image.png", "call_1"
def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
@ -105,7 +84,7 @@ final_answer(7.2904)
"""
def fake_code_llm_error(messages, stop_sequences=None) -> str:
def fake_code_model_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
@ -150,7 +129,7 @@ final_answer(res)
"""
def fake_code_llm_single_step(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
@ -161,7 +140,7 @@ final_answer(result)
"""
def fake_code_llm_no_return(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
@ -175,34 +154,24 @@ print(result)
class AgentTests(unittest.TestCase):
def test_fake_single_step_code_agent(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_single_step
tools=[PythonInterpreterTool()], model=fake_code_model_single_step
)
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
assert isinstance(output, str)
assert output == "7.2904"
def test_fake_json_agent(self):
def test_fake_toolcalling_agent(self):
agent = ToolCallingAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_json_llm
tools=[PythonInterpreterTool()], model=FakeToolCallModel()
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[2].observations == "7.2904"
assert (
agent.logs[3].llm_output
== """
Thought: I can now answer the initial question
Action:
{
"action": "final_answer",
"action_input": {"answer": "7.2904"}
}
"""
)
assert agent.logs[3].llm_output is None
def test_json_agent_handles_image_tool_outputs(self):
def test_toolcalling_agent_handles_image_tool_outputs(self):
from PIL import Image
@tool
@ -215,33 +184,32 @@ Action:
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
agent = ToolCallingAgent(
tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image
tools=[fake_image_generation_tool], model=FakeToolCallModelImage()
)
output = agent.run("Make me an image.")
assert isinstance(output, Image.Image)
assert isinstance(agent.state["image.png"], Image.Image)
def test_fake_code_agent(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float)
assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_call == ToolCall(
tool_name="python_interpreter",
tool_arguments="final_answer(7.2904)",
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
)
def test_additional_args_added_to_task(self):
agent = CodeAgent(tools=[], llm_engine=fake_code_llm)
agent = CodeAgent(tools=[], model=fake_code_model)
agent.run(
"What is 2 multiplied by 3.6452?", additional_instruction="Remember this."
)
assert "Remember this" in agent.task
assert "Remember this" in str(agent.prompt_messages)
assert "Remember this" in str(agent.input_messages)
def test_reset_conversations(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model)
output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
assert output == 7.2904
assert len(agent.logs) == 4
@ -255,21 +223,19 @@ Action:
assert len(agent.logs) == 4
def test_code_agent_code_errors_show_offending_lines(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_error
)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self):
ToolCallingAgent(llm_engine=fake_json_llm, tools=[])
ToolCallingAgent(model=FakeToolCallModel(), tools=[])
def test_fails_max_iterations(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
model=fake_code_model_no_return, # use this callable because it never ends
max_iterations=5,
)
agent.run("What is 2 multiplied by 3.6452?")
@ -278,19 +244,19 @@ Action:
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = CodeAgent(tools=toolset_1, llm_engine=fake_code_llm)
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
assert (
len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = CodeAgent(tools=toolset_2, llm_engine=fake_code_llm)
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
assert (
len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2)
agent = CodeAgent(tools=toolset_3, llm_engine=fake_code_llm)
agent = CodeAgent(tools=toolset_3, model=fake_code_model)
assert (
len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one
@ -298,18 +264,20 @@ Action:
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
agent = ToolCallingAgent(
tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True
tools=toolset_3, model=FakeToolCallModel(), add_base_tools=True
)
assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents
agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True)
assert len(agent.toolbox.tools) == 2 # added final_answer tool + search
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
assert (
len(agent.toolbox.tools) == 3
) # added final_answer tool + search + transcribe
def test_function_persistence_across_steps(self):
agent = CodeAgent(
tools=[],
llm_engine=fake_code_functiondef,
model=fake_code_functiondef,
max_iterations=2,
additional_authorized_imports=["numpy"],
)
@ -317,17 +285,17 @@ Action:
assert res[0] == 0.5
def test_init_managed_agent(self):
agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
agent = CodeAgent(tools=[], model=fake_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
assert managed_agent.name == "managed_agent"
assert managed_agent.description == "Empty"
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef)
agent = CodeAgent(tools=[], model=fake_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
manager_agent = CodeAgent(
tools=[],
llm_engine=fake_code_functiondef,
model=fake_code_functiondef,
managed_agents=[managed_agent],
)
assert "You can also give requests to team members." not in agent.system_prompt

View File

@ -26,7 +26,7 @@ from smolagents import (
class MonitoringTester(unittest.TestCase):
def test_code_agent_metrics(self):
class FakeLLMEngine:
class FakeLLMModel:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
@ -40,7 +40,7 @@ final_answer('This is the final answer.')
agent = CodeAgent(
tools=[],
llm_engine=FakeLLMEngine(),
model=FakeLLMModel(),
max_iterations=1,
)
@ -50,7 +50,7 @@ final_answer('This is the final answer.')
self.assertEqual(agent.monitor.total_output_token_count, 20)
def test_json_agent_metrics(self):
class FakeLLMEngine:
class FakeLLMModel:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
@ -60,7 +60,7 @@ final_answer('This is the final answer.')
agent = ToolCallingAgent(
tools=[],
llm_engine=FakeLLMEngine(),
model=FakeLLMModel(),
max_iterations=1,
)
@ -70,7 +70,7 @@ final_answer('This is the final answer.')
self.assertEqual(agent.monitor.total_output_token_count, 20)
def test_code_agent_metrics_max_iterations(self):
class FakeLLMEngine:
class FakeLLMModel:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
@ -80,7 +80,7 @@ final_answer('This is the final answer.')
agent = CodeAgent(
tools=[],
llm_engine=FakeLLMEngine(),
model=FakeLLMModel(),
max_iterations=1,
)
@ -90,7 +90,7 @@ final_answer('This is the final answer.')
self.assertEqual(agent.monitor.total_output_token_count, 40)
def test_code_agent_metrics_generation_error(self):
class FakeLLMEngine:
class FakeLLMModel:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
@ -100,7 +100,7 @@ final_answer('This is the final answer.')
agent = CodeAgent(
tools=[],
llm_engine=FakeLLMEngine(),
model=FakeLLMModel(),
max_iterations=1,
)
@ -110,7 +110,7 @@ final_answer('This is the final answer.')
self.assertEqual(agent.monitor.total_output_token_count, 40)
def test_streaming_agent_text_output(self):
def dummy_llm_engine(prompt, **kwargs):
def dummy_model(prompt, **kwargs):
return """
Code:
```py
@ -119,7 +119,7 @@ final_answer('This is the final answer.')
agent = CodeAgent(
tools=[],
llm_engine=dummy_llm_engine,
model=dummy_model,
max_iterations=1,
)
@ -132,14 +132,14 @@ final_answer('This is the final answer.')
self.assertIn("This is the final answer.", final_message.content)
def test_streaming_agent_image_output(self):
def dummy_llm_engine(prompt, **kwargs):
def dummy_model(prompt, **kwargs):
return (
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
)
agent = ToolCallingAgent(
tools=[],
llm_engine=dummy_llm_engine,
model=dummy_model,
max_iterations=1,
)
@ -161,12 +161,12 @@ final_answer('This is the final answer.')
self.assertEqual(final_message.content["mime_type"], "image/png")
def test_streaming_with_agent_error(self):
def dummy_llm_engine(prompt, **kwargs):
def dummy_model(prompt, **kwargs):
raise AgentError("Simulated agent error")
agent = CodeAgent(
tools=[],
llm_engine=dummy_llm_engine,
model=dummy_model,
max_iterations=1,
)