Allow passing kwargs to all models (#222)

* Allow passing kwargs to all models
Author: Aymeric Roucher, 2025-01-16 23:03:38 +01:00, committed by GitHub
commit b4091cb5ce (parent a1d8f3c398)
4 changed files with 79 additions and 37 deletions
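
In short: `max_tokens` disappears from every model's `__call__` signature, and generation options are instead accepted once at construction and forwarded to the underlying client via `**self.kwargs`. A minimal sketch of the calling convention before and after this commit (values are illustrative):

```python
from smolagents import HfApiModel

messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]

# Before this commit: max_tokens was a per-call argument (default 1500).
# response = HfApiModel()(messages, stop_sequences=["END"], max_tokens=1500)

# After this commit: generation options are bound at init and forwarded
# to every underlying completion call via **self.kwargs.
model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=1500)
response = model(messages, stop_sequences=["END"])
print(response.content)
```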

File 1 of 4 (tools API reference docs)

@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
 [[autodoc]] VisitWebpageTool
 
+### UserInputTool
+
+[[autodoc]] UserInputTool
+
 ## ToolCollection
 
 [[autodoc]] ToolCollection

File 2 of 4 (UserInputTool source)

@@ -144,7 +144,7 @@ class UserInputTool(Tool):
     output_type = "string"
 
     def forward(self, question):
-        user_input = input(f"{question} => ")
+        user_input = input(f"{question} => Type your answer here:")
         return user_input
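
For reference, the tool just relays the question to stdin; a quick sketch of exercising the updated prompt directly (assuming `UserInputTool` is importable from the package root, as the docs change above suggests):

```python
from smolagents import UserInputTool

tool = UserInputTool()
# Prints "What is your name? => Type your answer here:" and waits for stdin.
answer = tool.forward("What is your name?")
print(f"Got: {answer}")
```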

File 3 of 4 (models source)

@@ -222,7 +222,6 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.
@@ -233,8 +232,6 @@ class Model:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
-            max_tokens (`int`, *optional*):
-                The maximum count of tokens to generate.
 
         Returns:
             `str`: The text content of the model's response.
         """
@@ -244,7 +241,7 @@ class Model:
 class HfApiModel(Model):
     """A class to interact with Hugging Face's Inference API for language model interaction.
 
-    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+    This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
@@ -265,9 +262,10 @@ class HfApiModel(Model):
     >>> engine = HfApiModel(
     ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
     ...     token="your_hf_token_here",
+    ...     max_tokens=5000,
     ... )
     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-    >>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
+    >>> response = engine(messages, stop_sequences=["END"])
     >>> print(response)
     "Quantum mechanics is the branch of physics that studies..."
     ```
@@ -279,6 +277,7 @@ class HfApiModel(Model):
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
         temperature: float = 0.5,
+        **kwargs,
     ):
         super().__init__()
         self.model_id = model_id
@@ -286,13 +285,13 @@ class HfApiModel(Model):
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
         self.temperature = temperature
+        self.kwargs = kwargs
 
     def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         """
@@ -308,16 +307,16 @@ class HfApiModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         else:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
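
With `**kwargs` captured in `__init__` and spliced into both `chat.completions.create` calls, any option the InferenceClient accepts can now be set once per model instance rather than per call. A minimal sketch (`max_tokens` and `seed` values are illustrative):

```python
from smolagents import HfApiModel

# Each kwarg below is stored in self.kwargs and forwarded to
# client.chat.completions.create on every invocation.
model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    token="your_hf_token_here",
    max_tokens=1500,  # replaces the removed per-call argument
    seed=42,          # illustrative: any chat-completion option works
)
response = model([{"role": "user", "content": "Hello!"}], stop_sequences=["END"])
print(response.content)
```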
@@ -325,16 +324,44 @@ class HfApiModel(Model):
 
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`.
+    """A class that runs a local Transformers model for language model interaction.
+
+    This model initializes a model and tokenizer from the given `model_id` and runs generation locally, supporting features like stop sequences and custom generation kwargs.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
-        device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
-            The device to load the model on (`"cpu"` or `"cuda"`).
+        device_map (`str`, *optional*):
+            The device_map to initialize your model with.
+        torch_dtype (`str`, *optional*):
+            The torch_dtype to initialize your model with.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to pass to `model.generate()`, for instance `max_new_tokens`.
+
+    Raises:
+        ValueError:
+            If the model name is not provided.
+
+    Example:
+    ```python
+    >>> engine = TransformersModel(
+    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    ...     device_map="cuda",
+    ...     max_new_tokens=5000,
+    ... )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
     """
 
-    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device_map: Optional[str] = None,
+        torch_dtype: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__()
         if not is_torch_available():
             raise ImportError("Please install torch in order to use TransformersModel.")
@@ -347,14 +374,14 @@ class TransformersModel(Model):
                 f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        logger.info(f"Using device: {self.device}")
+        self.kwargs = kwargs
+        if device_map is None:
+            device_map = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device_map}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
         except Exception as e:
             logger.warning(
@@ -363,7 +390,7 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                default_model_id, device_map=device_map, torch_dtype=torch_dtype
             )
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
@@ -397,7 +424,6 @@ class TransformersModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -422,10 +448,10 @@ class TransformersModel(Model):
         out = self.model.generate(
             **prompt_tensor,
-            max_new_tokens=max_tokens,
             stopping_criteria=(
                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None
             ),
+            **self.kwargs,
         )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
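
Since the hardcoded `max_new_tokens=max_tokens` is gone, output length for `TransformersModel` is now controlled entirely through the generate kwargs bound at construction. A minimal sketch, reusing the small model from the tests below:

```python
from smolagents import TransformersModel

# max_new_tokens and do_sample travel through **self.kwargs into
# self.model.generate(), alongside the stopping criteria built per call.
model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",  # small model, as in the tests
    device_map="auto",
    max_new_tokens=64,
    do_sample=False,
)
print(model([{"role": "user", "content": "Hello!"}]).content)
```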
@@ -458,6 +484,19 @@ class TransformersModel(Model):
 
 class LiteLLMModel(Model):
+    """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use with LiteLLM (e.g. "anthropic/claude-3-5-sonnet-20240620").
+        api_base (`str`):
+            The base URL of the provider's OpenAI-compatible API server.
+        api_key (`str`):
+            The API key to use for authentication.
+        **kwargs:
+            Additional keyword arguments to pass to the underlying completion call.
+    """
+
     def __init__(
         self,
         model_id="anthropic/claude-3-5-sonnet-20240620",
@@ -482,7 +521,6 @@ class LiteLLMModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -495,7 +533,6 @@ class LiteLLMModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -505,7 +542,6 @@ class LiteLLMModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
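
The same pattern applies to `LiteLLMModel`: anything LiteLLM's completion call accepts can be bound at construction and is forwarded on every invocation. A minimal sketch (the environment variable is an illustrative assumption):

```python
import os
from smolagents import LiteLLMModel

# max_tokens is now just another entry in self.kwargs, unpacked into
# the completion call together with api_base and api_key.
model = LiteLLMModel(
    model_id="anthropic/claude-3-5-sonnet-20240620",
    api_key=os.environ["ANTHROPIC_API_KEY"],  # illustrative
    max_tokens=1024,
)
response = model([{"role": "user", "content": "Hello!"}])
print(response.content)
```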
@@ -516,7 +552,7 @@ class LiteLLMModel(Model):
 
 class OpenAIServerModel(Model):
-    """This engine connects to an OpenAI-compatible API server.
+    """This model connects to an OpenAI-compatible API server.
 
     Parameters:
         model_id (`str`):
@@ -525,8 +561,6 @@ class OpenAIServerModel(Model):
             The base URL of the OpenAI-compatible API server.
         api_key (`str`):
             The API key to use for authentication.
-        temperature (`float`, *optional*, defaults to 0.7):
-            Controls randomness in the model's responses. Values between 0 and 2.
         **kwargs:
             Additional keyword arguments to pass to the OpenAI API.
     """
@@ -536,7 +570,6 @@ class OpenAIServerModel(Model):
         model_id: str,
         api_base: str,
         api_key: str,
-        temperature: float = 0.7,
         **kwargs,
     ):
         super().__init__()
@@ -545,7 +578,6 @@ class OpenAIServerModel(Model):
             base_url=api_base,
             api_key=api_key,
         )
-        self.temperature = temperature
        self.kwargs = kwargs
 
     def __call__(
@@ -553,7 +585,6 @@ class OpenAIServerModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -566,8 +597,6 @@ class OpenAIServerModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         else:
@@ -575,8 +604,6 @@ class OpenAIServerModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
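
With the dedicated `temperature` parameter removed, sampling settings for `OpenAIServerModel` go through `**kwargs` like everything else. A minimal sketch (endpoint and key are illustrative placeholders):

```python
import os
from smolagents import OpenAIServerModel

# temperature and max_tokens are plain kwargs now: stored in self.kwargs
# and unpacked into client.chat.completions.create on each call.
model = OpenAIServerModel(
    model_id="gpt-3.5-turbo",
    api_base="https://api.openai.com/v1",  # illustrative endpoint
    api_key=os.environ["OPENAI_API_KEY"],  # illustrative
    temperature=0.7,
    max_tokens=1500,
)
response = model([{"role": "user", "content": "Hello!"}])
print(response.content)
```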

File 4 of 4 (model tests)

@@ -16,7 +16,7 @@ import unittest
 import json
 from typing import Optional
 
-from smolagents import models, tool, ChatMessage, HfApiModel
+from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
 
 
 class ModelTests(unittest.TestCase):
@@ -46,6 +46,17 @@ class ModelTests(unittest.TestCase):
         assert data["content"] == "Hello!"
 
     def test_get_hfapi_message_no_tool(self):
-        model = HfApiModel()
+        model = HfApiModel(max_tokens=10)
         messages = [{"role": "user", "content": "Hello!"}]
         model(messages, stop_sequences=["great"])
+
+    def test_transformers_message_no_tool(self):
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
+            max_new_tokens=5,
+            device_map="auto",
+            do_sample=False,
+        )
+        messages = [{"role": "user", "content": "Hello!"}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output == "assistant\nHello"