Improve inference choice examples (#311)

* Improve inference choice examples
* Fix style

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>

parent 0196dc7b21
commit de7b0ee799

@@ -0,0 +1,51 @@
from typing import Optional

from smolagents import HfApiModel, LiteLLMModel, TransformersModel, tool
from smolagents.agents import CodeAgent, ToolCallingAgent


# Choose which inference type to use!

available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
chosen_inference = "transformers"

print(f"Chose model: {chosen_inference}")

if chosen_inference == "hf_api":
    model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")

elif chosen_inference == "transformers":
    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000)

elif chosen_inference == "ollama":
    model = LiteLLMModel(
        model_id="ollama_chat/llama3.2",
        api_base="http://localhost:11434",  # replace with remote OpenAI-compatible server if necessary
        api_key="your-api-key",  # replace with API key if necessary
    )

elif chosen_inference == "litellm":
    # For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get the weather in the coming days at a given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: whether to report the temperature in Celsius
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))

agent = CodeAgent(tools=[get_weather], model=model)

print("CodeAgent:", agent.run("What's the weather like in Paris?"))
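
Note: the backend switch in the new example is a plain if/elif chain keyed on the `chosen_inference` constant; the same selection could be driven from the command line instead of editing the file. A minimal sketch using only the standard library (the `--inference` flag is illustrative and not part of the committed example):

    import argparse

    # Hypothetical CLI front-end for the example's inference switch (not in the commit).
    parser = argparse.ArgumentParser(description="Select an inference backend")
    parser.add_argument(
        "--inference",
        choices=["hf_api", "transformers", "ollama", "litellm"],
        default="transformers",
        help="Which backend to instantiate",
    )
    chosen_inference = parser.parse_args().inference
    print(f"Chose model: {chosen_inference}")
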
@@ -1,30 +0,0 @@
from typing import Optional

from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent


# Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: the temperature
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))
@@ -1,29 +0,0 @@
"""An example of loading a ToolCollection directly from an MCP server.

Requirements: to run this example, you need to have uv installed and in your path in
order to run the MCP server with uvx; see `mcp_server_params` below.

Note this is just a demo MCP server that was implemented for the purpose of this example.
It only provides a single tool to search among PubMed paper abstracts.

Usage:
>>> uv run examples/tool_calling_agent_mcp.py
"""

import os

from mcp import StdioServerParameters

from smolagents import CodeAgent, HfApiModel, ToolCollection


mcp_server_params = StdioServerParameters(
    command="uvx",
    args=["--quiet", "pubmedmcp@0.1.3"],
    env={"UV_PYTHON": "3.12", **os.environ},
)

with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
    # print(tool_collection.tools[0](request={"term": "efficient treatment hangover"}))
    agent = CodeAgent(tools=tool_collection.tools, model=HfApiModel(), max_steps=4)
    agent.run("Find me one risk associated with drinking alcohol regularly on low doses for humans.")
@@ -1,29 +0,0 @@
from typing import Optional

from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent


model = LiteLLMModel(
    model_id="ollama_chat/llama3.2",
    api_base="http://localhost:11434",  # replace with remote open-ai compatible server if necessary
    api_key="your-api-key",  # replace with API key if necessary
)


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: the temperature
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))
@@ -480,7 +480,6 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
             **kwargs,
         )

@@ -497,9 +496,6 @@
         if max_new_tokens:
             completion_kwargs["max_new_tokens"] = max_new_tokens

-        if stop_sequences:
-            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
-
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,

@@ -518,7 +514,11 @@
         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]

-        out = self.model.generate(**prompt_tensor, **completion_kwargs)
+        out = self.model.generate(
+            **prompt_tensor,
+            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
+            **completion_kwargs,
+        )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
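
Note: the TransformersModel change above passes stop sequences straight to `model.generate(...)` through its `stopping_criteria` parameter instead of staging them in `completion_kwargs`. In transformers, that parameter accepts a `StoppingCriteriaList`; a plausible sketch of what `make_stopping_criteria` might build, assuming it decodes the running output and halts on any stop string (this helper is an assumption, not the library's verbatim code):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList


    class StopOnStrings(StoppingCriteria):
        """Assumed helper: halt generation once the decoded text ends with a stop string."""

        def __init__(self, stop_strings: list[str], tokenizer):
            self.stop_strings = stop_strings
            self.tokenizer = tokenizer

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # Decode the sequence generated so far and check it against every stop string.
            text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            return any(text.endswith(stop) for stop in self.stop_strings)


    def make_stopping_criteria(stop_sequences: list[str], tokenizer) -> StoppingCriteriaList:
        # Usage: model.generate(..., stopping_criteria=make_stopping_criteria(stops, tokenizer))
        return StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])

Passing `stopping_criteria=None` when there are no stop sequences is safe, since generate() treats None as "no extra criteria".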