Allow passing kwargs to all models (#222)

* Allow passing kwargs to all models
Aymeric Roucher 2025-01-16 23:03:38 +01:00 committed by GitHub
parent a1d8f3c398
commit b4091cb5ce
4 changed files with 79 additions and 37 deletions


@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
[[autodoc]] VisitWebpageTool
### UserInputTool
[[autodoc]] UserInputTool
## ToolCollection
[[autodoc]] ToolCollection


@@ -144,7 +144,7 @@ class UserInputTool(Tool):
output_type = "string"
def forward(self, question):
user_input = input(f"{question} => ")
user_input = input(f"{question} => Type your answer here:")
return user_input
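For reference, a hedged usage sketch of the updated prompt (assuming `UserInputTool` is importable from the package root, as the reference doc above suggests):

```python
from smolagents import UserInputTool

tool = UserInputTool()
# With this change the console shows: "What is your name? => Type your answer here:"
answer = tool.forward("What is your name?")
print(answer)  # whatever the user typed
```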


@@ -222,7 +222,6 @@ class Model:
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> ChatMessage:
"""Process the input messages and return the model's response.
@@ -233,8 +232,6 @@ class Model:
A list of strings that will stop the generation if encountered in the model's output.
grammar (`str`, *optional*):
The grammar or formatting structure to use in the model's response.
max_tokens (`int`, *optional*):
The maximum count of tokens to generate.
Returns:
`str`: The text content of the model's response.
"""
@@ -244,7 +241,7 @@
class HfApiModel(Model):
"""A class to interact with Hugging Face's Inference API for language model interaction.
This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
This model allows you to communicate with Hugging Face's models using the Inference API. It can be used either in serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
Parameters:
model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
@@ -265,9 +262,10 @@ class HfApiModel(Model):
>>> engine = HfApiModel(
... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
... token="your_hf_token_here",
... max_tokens=5000,
... )
>>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
>>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
>>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
@@ -279,6 +277,7 @@ class HfApiModel(Model):
token: Optional[str] = None,
timeout: Optional[int] = 120,
temperature: float = 0.5,
**kwargs,
):
super().__init__()
self.model_id = model_id
@@ -286,13 +285,13 @@ class HfApiModel(Model):
token = os.getenv("HF_TOKEN")
self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
self.temperature = temperature
self.kwargs = kwargs
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
"""
@@ -308,16 +307,16 @@ class HfApiModel(Model):
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="auto",
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
**self.kwargs,
)
else:
response = self.client.chat.completions.create(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
**self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
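A hedged usage sketch of the new `HfApiModel` surface, mirroring the docstring example above (requires a valid HF token in the environment; the `max_tokens` value is arbitrary):

```python
from smolagents import HfApiModel

model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    max_tokens=1024,  # stored in self.kwargs and forwarded to chat.completions.create(...)
)
messages = [{"role": "user", "content": "Say hello."}]
print(model(messages, stop_sequences=["END"]).content)
```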
@@ -325,16 +324,44 @@ class HfApiModel(Model):
class TransformersModel(Model):
"""This engine initializes a model and tokenizer from the given `model_id`.
"""A class to interact with Hugging Face's Inference API for language model interaction.
This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
Parameters:
model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
The device to load the model on (`"cpu"` or `"cuda"`).
device_map (`str`, *optional*):
The device_map to initialize your model with.
torch_dtype (`str`, *optional*):
The torch_dtype to initialize your model with.
kwargs (dict, *optional*):
Any additional keyword arguments that you want to pass to model.generate(), for instance `max_new_tokens` or `do_sample`.
Raises:
ValueError:
If the model name is not provided.
Example:
```python
>>> engine = TransformersModel(
... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
... device="cuda",
... max_new_tokens=5000,
... )
>>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
>>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
"""
def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
def __init__(
self,
model_id: Optional[str] = None,
device_map: Optional[str] = None,
torch_dtype: Optional[str] = None,
**kwargs,
):
super().__init__()
if not is_torch_available():
raise ImportError("Please install torch in order to use TransformersModel.")
@@ -347,14 +374,14 @@ class TransformersModel(Model):
f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
)
self.model_id = model_id
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
logger.info(f"Using device: {self.device}")
self.kwargs = kwargs
if device_map is None:
device_map = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device_map}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=self.device
model_id, device_map=device_map, torch_dtype=torch_dtype
)
except Exception as e:
logger.warning(
@@ -363,7 +390,7 @@ class TransformersModel(Model):
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=self.device
default_model_id, device_map=device_map, torch_dtype=torch_dtype
)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
@@ -397,7 +424,6 @@ class TransformersModel(Model):
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
@@ -422,10 +448,10 @@ class TransformersModel(Model):
out = self.model.generate(
**prompt_tensor,
max_new_tokens=max_tokens,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
**self.kwargs,
)
generated_tokens = out[0, count_prompt_tokens:]
output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
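The corresponding hedged sketch for `TransformersModel`, mirroring the test added at the bottom of this commit (generation options such as `max_new_tokens` and `do_sample` now flow into `model.generate()`):

```python
from smolagents import TransformersModel

model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
    device_map="auto",
    max_new_tokens=64,  # stored in self.kwargs and forwarded to model.generate(...)
    do_sample=False,
)
messages = [{"role": "user", "content": "Say hello."}]
print(model(messages).content)
```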
@@ -458,6 +484,19 @@ class TransformersModel(Model):
class LiteLLMModel(Model):
"""This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
Parameters:
model_id (`str`):
The model identifier to use on the server (e.g. "gpt-3.5-turbo").
api_base (`str`):
The base URL of the OpenAI-compatible API server.
api_key (`str`):
The API key to use for authentication.
**kwargs:
Additional keyword arguments to pass to the OpenAI API.
"""
def __init__(
self,
model_id="anthropic/claude-3-5-sonnet-20240620",
@@ -482,7 +521,6 @@ class LiteLLMModel(Model):
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
@@ -495,7 +533,6 @@ class LiteLLMModel(Model):
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="required",
stop=stop_sequences,
max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
**self.kwargs,
@@ -505,7 +542,6 @@ class LiteLLMModel(Model):
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
**self.kwargs,
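And a hedged sketch for `LiteLLMModel` (assumes provider credentials, e.g. `ANTHROPIC_API_KEY`, are set in the environment; `max_tokens` is an arbitrary example of a forwarded option):

```python
from smolagents import LiteLLMModel

model = LiteLLMModel(
    model_id="anthropic/claude-3-5-sonnet-20240620",
    max_tokens=512,  # stored in self.kwargs and splatted into litellm.completion(...)
)
messages = [{"role": "user", "content": "Say hello."}]
print(model(messages).content)
```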
@@ -516,7 +552,7 @@ class LiteLLMModel(Model):
class OpenAIServerModel(Model):
"""This engine connects to an OpenAI-compatible API server.
"""This model connects to an OpenAI-compatible API server.
Parameters:
model_id (`str`):
@@ -525,8 +561,6 @@ class OpenAIServerModel(Model):
The base URL of the OpenAI-compatible API server.
api_key (`str`):
The API key to use for authentication.
temperature (`float`, *optional*, defaults to 0.7):
Controls randomness in the model's responses. Values between 0 and 2.
**kwargs:
Additional keyword arguments to pass to the OpenAI API.
"""
@@ -536,7 +570,6 @@ class OpenAIServerModel(Model):
model_id: str,
api_base: str,
api_key: str,
temperature: float = 0.7,
**kwargs,
):
super().__init__()
@@ -545,7 +578,6 @@ class OpenAIServerModel(Model):
base_url=api_base,
api_key=api_key,
)
self.temperature = temperature
self.kwargs = kwargs
def __call__(
@@ -553,7 +585,6 @@ class OpenAIServerModel(Model):
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
@@ -566,8 +597,6 @@ class OpenAIServerModel(Model):
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="auto",
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
**self.kwargs,
)
else:
@@ -575,8 +604,6 @@ class OpenAIServerModel(Model):
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
**self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
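For `OpenAIServerModel`, `temperature` loses its dedicated parameter and now rides along in `**kwargs` like everything else; a hedged sketch (endpoint and model name are placeholders):

```python
import os

from smolagents import OpenAIServerModel

model = OpenAIServerModel(
    model_id="gpt-4o-mini",
    api_base="https://api.openai.com/v1",
    api_key=os.environ["OPENAI_API_KEY"],
    temperature=0.7,  # now just another forwarded kwarg
    max_tokens=512,
)
messages = [{"role": "user", "content": "Say hello."}]
print(model(messages).content)
```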


@@ -16,7 +16,7 @@ import unittest
import json
from typing import Optional
from smolagents import models, tool, ChatMessage, HfApiModel
from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
class ModelTests(unittest.TestCase):
@@ -46,6 +46,17 @@ class ModelTests(unittest.TestCase):
assert data["content"] == "Hello!"
def test_get_hfapi_message_no_tool(self):
model = HfApiModel()
model = HfApiModel(max_tokens=10)
messages = [{"role": "user", "content": "Hello!"}]
model(messages, stop_sequences=["great"])
def test_transformers_message_no_tool(self):
model = TransformersModel(
model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
max_new_tokens=5,
device_map="auto",
do_sample=False,
)
messages = [{"role": "user", "content": "Hello!"}]
output = model(messages, stop_sequences=["great"]).content
assert output == "assistant\nHello"
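A further hedged test sketch, checking only what this diff guarantees (constructor kwargs are stored on the instance for forwarding) without touching the network:

```python
import unittest

from smolagents import HfApiModel


class KwargsStorageTest(unittest.TestCase):
    def test_extra_kwargs_are_stored_for_forwarding(self):
        # Named parameters (model_id, token, timeout, temperature) are consumed
        # by __init__; anything else lands in self.kwargs and is splatted into
        # every chat.completions.create(...) call.
        model = HfApiModel(max_tokens=10)
        self.assertEqual(model.kwargs, {"max_tokens": 10})


if __name__ == "__main__":
    unittest.main()
```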