Improve inference choice examples (#311)
* Improve inference choice examples
* Fix style

---------

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
parent 0196dc7b21
commit de7b0ee799
@@ -0,0 +1,51 @@
from typing import Optional

from smolagents import HfApiModel, LiteLLMModel, TransformersModel, tool
from smolagents.agents import CodeAgent, ToolCallingAgent


# Choose which inference type to use!

available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
chosen_inference = "transformers"

print(f"Chose model {chosen_inference}")

if chosen_inference == "hf_api":
    model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")

elif chosen_inference == "transformers":
    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000)

elif chosen_inference == "ollama":
    model = LiteLLMModel(
        model_id="ollama_chat/llama3.2",
        api_base="http://localhost:11434",  # replace with a remote OpenAI-compatible server if necessary
        api_key="your-api-key",  # replace with an API key if necessary
    )

elif chosen_inference == "litellm":
    # For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get the weather in the next days at the given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: whether to return the temperature in Celsius
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))

agent = CodeAgent(tools=[get_weather], model=model)

print("CodeAgent:", agent.run("What's the weather like in Paris?"))
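A note on the dispatch above: the example lists the valid names in available_inferences but never checks chosen_inference against them. A possible variant is sketched below; the MODEL_FACTORIES mapping and build_model helper are illustrative additions, not part of this commit, and they reuse the same model IDs and constructor arguments as the example.

from smolagents import HfApiModel, LiteLLMModel, TransformersModel

# Illustrative dispatch table (not in the commit): same model IDs and arguments as the example above.
MODEL_FACTORIES = {
    "hf_api": lambda: HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct"),
    "transformers": lambda: TransformersModel(
        model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000
    ),
    "ollama": lambda: LiteLLMModel(
        model_id="ollama_chat/llama3.2",
        api_base="http://localhost:11434",
        api_key="your-api-key",
    ),
    "litellm": lambda: LiteLLMModel(model_id="gpt-4o"),
}


def build_model(chosen_inference: str):
    """Build the model for the chosen inference type, rejecting unknown names."""
    if chosen_inference not in MODEL_FACTORIES:
        raise ValueError(f"Unknown inference type {chosen_inference!r}, expected one of {sorted(MODEL_FACTORIES)}")
    return MODEL_FACTORIES[chosen_inference]()

With such a helper, model = build_model(chosen_inference) would replace the if/elif chain while failing loudly on an unsupported name.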
@@ -1,30 +0,0 @@
from typing import Optional

from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent


# Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: the temperature
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))
@@ -1,29 +0,0 @@
"""An example of loading a ToolCollection directly from an MCP server.

Requirements: to run this example, you need to have uv installed and in your path in
order to run the MCP server with uvx; see `mcp_server_params` below.

Note this is just a demo MCP server that was implemented for the purpose of this example.
It only provides a single tool to search among PubMed paper abstracts.

Usage:
>>> uv run examples/tool_calling_agent_mcp.py
"""

import os

from mcp import StdioServerParameters

from smolagents import CodeAgent, HfApiModel, ToolCollection


mcp_server_params = StdioServerParameters(
    command="uvx",
    args=["--quiet", "pubmedmcp@0.1.3"],
    env={"UV_PYTHON": "3.12", **os.environ},
)

with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
    # print(tool_collection.tools[0](request={"term": "efficient treatment hangover"}))
    agent = CodeAgent(tools=tool_collection.tools, model=HfApiModel(), max_steps=4)
    agent.run("Find me one risk associated with drinking alcohol regularly on low doses for humans.")
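For readers reusing the pattern from the removed MCP example: ToolCollection.from_mcp yields the server's tools as a list, so an agent does not have to receive all of them. The sketch below is a hypothetical variant, assuming each tool exposes a name attribute; the name filter is illustrative and not part of the original example.

import os

from mcp import StdioServerParameters
from smolagents import CodeAgent, HfApiModel, ToolCollection

mcp_server_params = StdioServerParameters(
    command="uvx",
    args=["--quiet", "pubmedmcp@0.1.3"],
    env={"UV_PYTHON": "3.12", **os.environ},
)

with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
    # Hypothetical: hand the agent only the tools whose name mentions "search".
    search_tools = [tool for tool in tool_collection.tools if "search" in tool.name.lower()]
    agent = CodeAgent(tools=search_tools, model=HfApiModel(), max_steps=4)
    print(agent.run("Find me one risk associated with drinking alcohol regularly on low doses for humans."))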
@@ -1,29 +0,0 @@
from typing import Optional

from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent


model = LiteLLMModel(
    model_id="ollama_chat/llama3.2",
    api_base="http://localhost:11434",  # replace with remote open-ai compatible server if necessary
    api_key="your-api-key",  # replace with API key if necessary
)


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: the temperature
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))
@@ -480,7 +480,6 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
             tools_to_call_from=tools_to_call_from,
             **kwargs,
         )
@@ -497,9 +496,6 @@
         if max_new_tokens:
             completion_kwargs["max_new_tokens"] = max_new_tokens
 
-        if stop_sequences:
-            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
-
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
@@ -518,7 +514,11 @@
         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
 
-        out = self.model.generate(**prompt_tensor, **completion_kwargs)
+        out = self.model.generate(
+            **prompt_tensor,
+            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
+            **completion_kwargs,
+        )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
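The TransformersModel change above stops stashing stopping criteria in completion_kwargs and instead passes them directly to model.generate. For context, a stop-sequence criterion of the kind make_stopping_criteria returns can be built on transformers' StoppingCriteria interface; the sketch below is an illustration under that assumption, not the repository's actual implementation.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnStrings(StoppingCriteria):
    """Stop generation as soon as the decoded output ends with any of the stop strings."""

    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(text.endswith(stop) for stop in self.stop_strings)


# Hypothetical equivalent of the call in the diff:
# stopping_criteria = StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])
# out = model.generate(**prompt_tensor, stopping_criteria=stopping_criteria, **completion_kwargs)

Passing the criteria at the generate call keeps completion_kwargs free of objects that only apply to local generation.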