Allow passing kwargs to all models (#222)
* Allow passing kwargs to all models
commit b4091cb5ce
parent a1d8f3c398
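At a high level, the commit moves generation options (`max_tokens`, `temperature`, and friends) out of the models' `__call__` signatures and into constructor `**kwargs` that each model stores as `self.kwargs` and forwards to its backend on every call. A minimal sketch of the resulting usage pattern (model choice and values are illustrative, and a valid `HF_TOKEN` is assumed in the environment):

```python
from smolagents import HfApiModel

# Generation options are now supplied once, at construction time; the model
# stores them in self.kwargs and splats them into every backend request.
model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    max_tokens=1500,  # previously a per-call argument of __call__
)
messages = [{"role": "user", "content": "Hello!"}]
response = model(messages, stop_sequences=["END"])
print(response.content)
```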
@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
 
 [[autodoc]] VisitWebpageTool
 
+### UserInputTool
+
+[[autodoc]] UserInputTool
+
 ## ToolCollection
 
 [[autodoc]] ToolCollection
@@ -144,7 +144,7 @@ class UserInputTool(Tool):
     output_type = "string"
 
     def forward(self, question):
-        user_input = input(f"{question} => ")
+        user_input = input(f"{question} => Type your answer here:")
         return user_input
 
 
@@ -222,7 +222,6 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.
 
@@ -233,8 +232,6 @@ class Model:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
-            max_tokens (`int`, *optional*):
-                The maximum count of tokens to generate.
         Returns:
             `str`: The text content of the model's response.
         """
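Since `max_tokens` disappears from the base `Model.__call__` signature, existing call sites need a small migration; a hedged before/after sketch:

```python
from smolagents import HfApiModel

messages = [{"role": "user", "content": "Hello!"}]

# Before this commit: the limit was passed on each call.
# response = model(messages, stop_sequences=["END"], max_tokens=1500)

# After this commit: the limit is bound at construction and applies to every call.
model = HfApiModel(max_tokens=1500)
response = model(messages, stop_sequences=["END"])
```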
@@ -244,7 +241,7 @@ class Model:
 class HfApiModel(Model):
     """A class to interact with Hugging Face's Inference API for language model interaction.
 
-    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+    This model allows you to communicate with Hugging Face's models using the Inference API. It can be used both in serverless mode and with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
@@ -265,9 +262,10 @@ class HfApiModel(Model):
     >>> engine = HfApiModel(
     ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
     ...     token="your_hf_token_here",
+    ...     max_tokens=5000,
     ...     )
     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-    >>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
+    >>> response = engine(messages, stop_sequences=["END"])
     >>> print(response)
     "Quantum mechanics is the branch of physics that studies..."
     ```
@@ -279,6 +277,7 @@ class HfApiModel(Model):
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
         temperature: float = 0.5,
+        **kwargs,
     ):
         super().__init__()
         self.model_id = model_id
@@ -286,13 +285,13 @@ class HfApiModel(Model):
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
         self.temperature = temperature
+        self.kwargs = kwargs
 
     def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         """
@@ -308,16 +307,16 @@ class HfApiModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         else:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
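Because `**self.kwargs` is splatted into the `InferenceClient` chat-completion call above, any option that endpoint accepts can now be set on `HfApiModel` without further changes to smolagents; a sketch with illustrative values:

```python
from smolagents import HfApiModel

# max_tokens and top_p are standard chat-completion parameters; both end up
# in self.kwargs and are forwarded on every request.
model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    max_tokens=1500,
    top_p=0.9,
)
```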
@@ -325,16 +324,44 @@ class HfApiModel(Model):
 
 
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`.
+    """A class that runs a Hugging Face Transformers model locally for language model interaction.
 
+    This model initializes a model and tokenizer from the given `model_id` and runs generation locally, supporting features like stop sequences and additional `generate()` kwargs.
+
     Parameters:
         model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
-        device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
-            The device to load the model on (`"cpu"` or `"cuda"`).
+        device_map (`str`, *optional*):
+            The device_map to initialize your model with.
+        torch_dtype (`str`, *optional*):
+            The torch_dtype to initialize your model with.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens`.
+    Raises:
+        ValueError:
+            If the model name is not provided.
+
+    Example:
+    ```python
+    >>> engine = TransformersModel(
+    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    ...     device_map="cuda",
+    ...     max_new_tokens=5000,
+    ...     )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
+    """
 
-    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device_map: Optional[str] = None,
+        torch_dtype: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__()
         if not is_torch_available():
             raise ImportError("Please install torch in order to use TransformersModel.")
@@ -347,14 +374,14 @@ class TransformersModel(Model):
                 f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        logger.info(f"Using device: {self.device}")
+        self.kwargs = kwargs
+        if device_map is None:
+            device_map = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device_map}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
         except Exception as e:
             logger.warning(
@@ -363,7 +390,7 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
@@ -397,7 +424,6 @@ class TransformersModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -422,10 +448,10 @@ class TransformersModel(Model):
 
         out = self.model.generate(
             **prompt_tensor,
-            max_new_tokens=max_tokens,
             stopping_criteria=(
                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None
             ),
+            **self.kwargs,
         )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
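The same forwarding happens locally: `TransformersModel` splats `self.kwargs` into `model.generate()`, so any `transformers` generation option can be set at construction, while `device_map` and `torch_dtype` are consumed by `from_pretrained()`. A sketch (model choice illustrative):

```python
from smolagents import TransformersModel

# device_map and torch_dtype configure model loading; the remaining kwargs
# (max_new_tokens, do_sample) are stored and passed to generate().
model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
    device_map="auto",
    torch_dtype="bfloat16",
    max_new_tokens=256,
    do_sample=False,
)
```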
@@ -458,6 +484,19 @@ class TransformersModel(Model):
 
 
 class LiteLLMModel(Model):
+    """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+        api_base (`str`):
+            The base URL of the OpenAI-compatible API server.
+        api_key (`str`):
+            The API key to use for authentication.
+        **kwargs:
+            Additional keyword arguments to pass to the OpenAI API.
+    """
+
     def __init__(
         self,
         model_id="anthropic/claude-3-5-sonnet-20240620",
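`LiteLLMModel` follows the same pattern: its constructor kwargs are stored and splatted into each `litellm.completion` call. A sketch, assuming `litellm` is installed and the provider key (here `ANTHROPIC_API_KEY`) is set in the environment:

```python
from smolagents import LiteLLMModel

# temperature and max_tokens are standard completion parameters understood
# by litellm; they ride along in self.kwargs on every call.
model = LiteLLMModel(
    model_id="anthropic/claude-3-5-sonnet-20240620",
    temperature=0.2,
    max_tokens=1024,
)
```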
@@ -482,7 +521,6 @@ class LiteLLMModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -495,7 +533,6 @@ class LiteLLMModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -505,7 +542,6 @@ class LiteLLMModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -516,7 +552,7 @@ class LiteLLMModel(Model):
 
 
 class OpenAIServerModel(Model):
-    """This engine connects to an OpenAI-compatible API server.
+    """This model connects to an OpenAI-compatible API server.
 
     Parameters:
         model_id (`str`):
@@ -525,8 +561,6 @@ class OpenAIServerModel(Model):
             The base URL of the OpenAI-compatible API server.
         api_key (`str`):
             The API key to use for authentication.
-        temperature (`float`, *optional*, defaults to 0.7):
-            Controls randomness in the model's responses. Values between 0 and 2.
         **kwargs:
             Additional keyword arguments to pass to the OpenAI API.
     """
@@ -536,7 +570,6 @@ class OpenAIServerModel(Model):
         model_id: str,
         api_base: str,
         api_key: str,
-        temperature: float = 0.7,
         **kwargs,
     ):
         super().__init__()
@@ -545,7 +578,6 @@ class OpenAIServerModel(Model):
             base_url=api_base,
             api_key=api_key,
         )
-        self.temperature = temperature
         self.kwargs = kwargs
 
     def __call__(
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
stop_sequences: Optional[List[str]] = None,
|
stop_sequences: Optional[List[str]] = None,
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> ChatMessage:
|
) -> ChatMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
|
@@ -566,8 +597,6 @@ class OpenAIServerModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         else:
@@ -575,8 +604,6 @@ class OpenAIServerModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
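With its dedicated `temperature` parameter gone, `OpenAIServerModel` treats temperature like any other OpenAI chat-completion option: set it in `**kwargs` at construction. A sketch with placeholder endpoint and key:

```python
from smolagents.models import OpenAIServerModel

model = OpenAIServerModel(
    model_id="gpt-4o",                     # illustrative model name
    api_base="https://api.openai.com/v1",  # placeholder endpoint
    api_key="sk-...",                      # placeholder key
    temperature=0.7,  # formerly a named constructor argument
    max_tokens=1500,  # formerly a __call__ argument
)
```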
@@ -16,7 +16,7 @@ import unittest
 import json
 from typing import Optional
 
-from smolagents import models, tool, ChatMessage, HfApiModel
+from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
 
 
 class ModelTests(unittest.TestCase):
@@ -46,6 +46,17 @@ class ModelTests(unittest.TestCase):
         assert data["content"] == "Hello!"
 
     def test_get_hfapi_message_no_tool(self):
-        model = HfApiModel()
+        model = HfApiModel(max_tokens=10)
         messages = [{"role": "user", "content": "Hello!"}]
         model(messages, stop_sequences=["great"])
+
+    def test_transformers_message_no_tool(self):
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
+            max_new_tokens=5,
+            device_map="auto",
+            do_sample=False,
+        )
+        messages = [{"role": "user", "content": "Hello!"}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output == "assistant\nHello"