Allow passing kwargs to all models (#222)
* Allow passing kwargs to all models
parent a1d8f3c398
commit b4091cb5ce
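In short: per-call generation arguments (most visibly `max_tokens`) move into each model's constructor as `**kwargs`, which the model stores and replays on every underlying client call. A minimal sketch of the resulting usage, based on the signatures in the hunks below (the values are illustrative):

```python
from smolagents import HfApiModel

# Options such as max_tokens are now set once at construction time and
# forwarded to every underlying client call, instead of being passed
# to each __call__.
model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    max_tokens=1500,
)

messages = [{"role": "user", "content": "Hello!"}]
print(model(messages, stop_sequences=["END"]).content)
```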
@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
 
 [[autodoc]] VisitWebpageTool
 
+### UserInputTool
+
+[[autodoc]] UserInputTool
+
 ## ToolCollection
 
 [[autodoc]] ToolCollection
@@ -144,7 +144,7 @@ class UserInputTool(Tool):
     output_type = "string"
 
     def forward(self, question):
-        user_input = input(f"{question} => ")
+        user_input = input(f"{question} => Type your answer here:")
         return user_input
@@ -222,7 +222,6 @@ class Model:
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
-        max_tokens: int = 1500,
    ) -> ChatMessage:
        """Process the input messages and return the model's response.
@@ -233,8 +232,6 @@ class Model:
                A list of strings that will stop the generation if encountered in the model's output.
            grammar (`str`, *optional*):
                The grammar or formatting structure to use in the model's response.
-            max_tokens (`int`, *optional*):
-                The maximum count of tokens to generate.
        Returns:
            `str`: The text content of the model's response.
        """
@@ -244,7 +241,7 @@ class Model:
 class HfApiModel(Model):
     """A class to interact with Hugging Face's Inference API for language model interaction.
 
-    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+    This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
@@ -265,9 +262,10 @@ class HfApiModel(Model):
     >>> engine = HfApiModel(
     ... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
     ... token="your_hf_token_here",
+    ... max_tokens=5000,
     ... )
     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-    >>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
+    >>> response = engine(messages, stop_sequences=["END"])
     >>> print(response)
     "Quantum mechanics is the branch of physics that studies..."
     ```
@@ -279,6 +277,7 @@ class HfApiModel(Model):
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
         temperature: float = 0.5,
+        **kwargs,
     ):
         super().__init__()
         self.model_id = model_id
@@ -286,13 +285,13 @@ class HfApiModel(Model):
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
         self.temperature = temperature
+        self.kwargs = kwargs
 
     def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         """
@@ -308,16 +307,16 @@ class HfApiModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         else:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
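Both branches forward the stored options the same way: whatever arrived in `**kwargs` at init is splatted into `chat.completions.create`. A self-contained sketch of this capture-and-replay pattern (the class below is hypothetical, standing in for the real client call):

```python
from typing import Dict, List, Optional

class KwargsForwardingModel:
    """Hypothetical stand-in showing the pattern HfApiModel uses above."""

    def __init__(self, model_id: str, temperature: float = 0.5, **kwargs):
        self.model_id = model_id
        self.temperature = temperature
        self.kwargs = kwargs  # e.g. {"max_tokens": 1500}, captured once

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
    ) -> dict:
        # Stand-in for self.client.chat.completions.create(...): the stored
        # kwargs are replayed on every call.
        return dict(
            model=self.model_id,
            messages=messages,
            stop=stop_sequences,
            temperature=self.temperature,
            **self.kwargs,
        )

engine = KwargsForwardingModel("some/model", max_tokens=10)
assert engine([{"role": "user", "content": "Hi"}])["max_tokens"] == 10
```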
@@ -325,16 +324,44 @@ class HfApiModel(Model):
 
 
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`.
+    """A class to interact with Hugging Face's Inference API for language model interaction.
+
+    This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+
     Parameters:
-        model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
+        model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
-        device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
-            The device to load the model on (`"cpu"` or `"cuda"`).
+        device_map (`str`, *optional*):
+            The device_map to initialize your model with.
+        torch_dtype (`str`, *optional*):
+            The torch_dtype to initialize your model with.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
+    Raises:
+        ValueError:
+            If the model name is not provided.
+
+    Example:
+    ```python
+    >>> engine = TransformersModel(
+    ... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    ... device="cuda",
+    ... max_new_tokens=5000,
+    ... )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
     """
 
-    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device_map: Optional[str] = None,
+        torch_dtype: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__()
         if not is_torch_available():
             raise ImportError("Please install torch in order to use TransformersModel.")
@@ -347,14 +374,14 @@ class TransformersModel(Model):
                 f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        logger.info(f"Using device: {self.device}")
+        self.kwargs = kwargs
+        if device_map is None:
+            device_map = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device_map}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
         except Exception as e:
             logger.warning(
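Note the breaking change in this hunk: `TransformersModel` no longer takes `device`; placement now goes through `device_map`, and `torch_dtype` is handed straight to `from_pretrained`. A sketch of migrated caller code, assuming the constructor shown above (model ID and dtype are illustrative):

```python
from smolagents import TransformersModel

# Before: TransformersModel(model_id=..., device="cuda")
# After:
model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
    device_map="auto",       # replaces the removed `device` argument
    torch_dtype="bfloat16",  # forwarded to AutoModelForCausalLM.from_pretrained
)
```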
@@ -363,7 +390,7 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
@@ -397,7 +424,6 @@ class TransformersModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -422,10 +448,10 @@ class TransformersModel(Model):
 
         out = self.model.generate(
             **prompt_tensor,
-            max_new_tokens=max_tokens,
             stopping_criteria=(
                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None
             ),
+            **self.kwargs,
         )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
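Because `self.kwargs` is now splatted into `model.generate()`, any `transformers` generation flag can be fixed at construction time; in particular `max_new_tokens` must come in this way, since the explicit `max_new_tokens=max_tokens` argument is gone. Mirroring the new test at the bottom of this diff:

```python
from smolagents import TransformersModel

model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
    max_new_tokens=5,   # forwarded to model.generate()
    do_sample=False,    # any generate() kwarg works the same way
    device_map="auto",
)
messages = [{"role": "user", "content": "Hello!"}]
print(model(messages, stop_sequences=["great"]).content)
```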
@@ -458,6 +484,19 @@ class TransformersModel(Model):
 
 
 class LiteLLMModel(Model):
+    """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+        api_base (`str`):
+            The base URL of the OpenAI-compatible API server.
+        api_key (`str`):
+            The API key to use for authentication.
+        **kwargs:
+            Additional keyword arguments to pass to the OpenAI API.
+    """
+
     def __init__(
         self,
         model_id="anthropic/claude-3-5-sonnet-20240620",
@@ -482,7 +521,6 @@ class LiteLLMModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -495,7 +533,6 @@ class LiteLLMModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -505,7 +542,6 @@ class LiteLLMModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
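For `LiteLLMModel` the only behavioral change is dropping the explicit `max_tokens` argument; extra options were apparently already captured in `**self.kwargs` and sent with each completion call. Under that reading, a caller would now write something like the following (values illustrative; `temperature` is assumed to be a valid completion kwarg here):

```python
from smolagents import LiteLLMModel

model = LiteLLMModel(
    model_id="anthropic/claude-3-5-sonnet-20240620",
    temperature=0.2,  # assumed: forwarded with every completion call
    max_tokens=1024,  # replaces the removed per-call max_tokens
)
```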
@@ -516,7 +552,7 @@ class LiteLLMModel(Model):
 
 
 class OpenAIServerModel(Model):
-    """This engine connects to an OpenAI-compatible API server.
+    """This model connects to an OpenAI-compatible API server.
 
     Parameters:
         model_id (`str`):
@@ -525,8 +561,6 @@ class OpenAIServerModel(Model):
             The base URL of the OpenAI-compatible API server.
         api_key (`str`):
             The API key to use for authentication.
-        temperature (`float`, *optional*, defaults to 0.7):
-            Controls randomness in the model's responses. Values between 0 and 2.
         **kwargs:
             Additional keyword arguments to pass to the OpenAI API.
     """
@@ -536,7 +570,6 @@ class OpenAIServerModel(Model):
         model_id: str,
         api_base: str,
         api_key: str,
-        temperature: float = 0.7,
         **kwargs,
     ):
         super().__init__()
@@ -545,7 +578,6 @@ class OpenAIServerModel(Model):
             base_url=api_base,
             api_key=api_key,
         )
-        self.temperature = temperature
         self.kwargs = kwargs
 
     def __call__(
@@ -553,7 +585,6 @@ class OpenAIServerModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -566,8 +597,6 @@ class OpenAIServerModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         else:
@@ -575,8 +604,6 @@ class OpenAIServerModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
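`OpenAIServerModel` loses its dedicated `temperature` parameter along with `max_tokens`; if either is wanted, both would now travel through `**kwargs` to the client call. A sketch under that assumption (URL and key are placeholders):

```python
from smolagents import OpenAIServerModel

model = OpenAIServerModel(
    model_id="gpt-3.5-turbo",
    api_base="http://localhost:8000/v1",  # placeholder URL
    api_key="sk-...",                     # placeholder key
    temperature=0.7,  # assumed: now forwarded via **kwargs
    max_tokens=512,   # assumed: likewise forwarded
)
```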
@@ -16,7 +16,7 @@ import unittest
 import json
 from typing import Optional
 
-from smolagents import models, tool, ChatMessage, HfApiModel
+from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
 
 
 class ModelTests(unittest.TestCase):
@@ -46,6 +46,17 @@ class ModelTests(unittest.TestCase):
         assert data["content"] == "Hello!"
 
     def test_get_hfapi_message_no_tool(self):
-        model = HfApiModel()
+        model = HfApiModel(max_tokens=10)
         messages = [{"role": "user", "content": "Hello!"}]
         model(messages, stop_sequences=["great"])
+
+    def test_transformers_message_no_tool(self):
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
+            max_new_tokens=5,
+            device_map="auto",
+            do_sample=False,
+        )
+        messages = [{"role": "user", "content": "Hello!"}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output == "assistant\nHello"