Improve inference choice examples (#311)

* Improve inference choice examples

* Fix style

---------

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Aymeric Roucher 2025-01-24 16:32:35 +01:00 committed by GitHub
parent 0196dc7b21
commit de7b0ee799
5 changed files with 56 additions and 93 deletions

View File

@@ -0,0 +1,51 @@
from typing import Optional

from smolagents import HfApiModel, LiteLLMModel, TransformersModel, tool
from smolagents.agents import CodeAgent, ToolCallingAgent

# Choose which inference type to use!
available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
chosen_inference = "transformers"

print(f"Chosen inference: {chosen_inference}")

if chosen_inference == "hf_api":
    model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")

elif chosen_inference == "transformers":
    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000)

elif chosen_inference == "ollama":
    model = LiteLLMModel(
        model_id="ollama_chat/llama3.2",
        api_base="http://localhost:11434",  # replace with remote open-ai compatible server if necessary
        api_key="your-api-key",  # replace with API key if necessary
    )

elif chosen_inference == "litellm":
    # For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get the weather in the next days at a given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: whether to return the temperature in Celsius
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)
print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))

agent = CodeAgent(tools=[get_weather], model=model)
print("CodeAgent:", agent.run("What's the weather like in Paris?"))

View File

@@ -1,30 +0,0 @@
from typing import Optional

from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent

# Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: the temperature
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))

View File

@@ -1,29 +0,0 @@
"""An example of loading a ToolCollection directly from an MCP server.

Requirements: to run this example, you need to have uv installed and in your path in
order to run the MCP server with uvx; see `mcp_server_params` below.

Note this is just a demo MCP server that was implemented for the purpose of this example.
It only provides a single tool to search among PubMed paper abstracts.

Usage:
>>> uv run examples/tool_calling_agent_mcp.py
"""

import os

from mcp import StdioServerParameters

from smolagents import CodeAgent, HfApiModel, ToolCollection

mcp_server_params = StdioServerParameters(
    command="uvx",
    args=["--quiet", "pubmedmcp@0.1.3"],
    env={"UV_PYTHON": "3.12", **os.environ},
)

with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
    # print(tool_collection.tools[0](request={"term": "efficient treatment hangover"}))
    agent = CodeAgent(tools=tool_collection.tools, model=HfApiModel(), max_steps=4)
    agent.run("Find me one risk associated with drinking alcohol regularly on low doses for humans.")

View File

@@ -1,29 +0,0 @@
from typing import Optional

from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent

model = LiteLLMModel(
    model_id="ollama_chat/llama3.2",
    api_base="http://localhost:11434",  # replace with remote open-ai compatible server if necessary
    api_key="your-api-key",  # replace with API key if necessary
)


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: the temperature
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))

View File

@@ -480,7 +480,6 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
             **kwargs,
         )
@@ -497,9 +496,6 @@ class TransformersModel(Model):
         if max_new_tokens:
             completion_kwargs["max_new_tokens"] = max_new_tokens
-        if stop_sequences:
-            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
@@ -518,7 +514,11 @@
         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
-        out = self.model.generate(**prompt_tensor, **completion_kwargs)
+        out = self.model.generate(
+            **prompt_tensor,
+            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
+            **completion_kwargs,
+        )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
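This hunk moves stop-sequence handling out of completion_kwargs and passes it directly as the stopping_criteria argument of model.generate(). The make_stopping_criteria helper itself is not shown in this diff; below is a minimal sketch of what such a helper could look like using the transformers StoppingCriteria API (the class name, the explicit tokenizer argument, and the 20-token decode window are assumptions, not the library's actual implementation):

from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnStrings(StoppingCriteria):
    """Stop generation once any stop string appears at the end of the decoded output."""

    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only a short tail of the sequence: long enough to contain any stop string.
        tail = self.tokenizer.decode(input_ids[0][-20:], skip_special_tokens=True)
        return any(stop in tail for stop in self.stop_strings)


def make_stopping_criteria(stop_sequences, tokenizer):
    return StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])

Passing stopping_criteria=None, as the new generate() call does when there are no stop sequences, is accepted by transformers and simply leaves early stopping disabled.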