From b4091cb5ce234bf0ff531abfe36035fcbb3323f7 Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Thu, 16 Jan 2025 23:03:38 +0100
Subject: [PATCH] Allow passing kwargs to all models (#222)

* Allow passing kwargs to all models
---
 docs/source/en/reference/tools.md |  4 ++
 src/smolagents/default_tools.py   |  2 +-
 src/smolagents/models.py          | 95 ++++++++++++++++++++-----------
 tests/test_models.py              | 15 ++++-
 4 files changed, 79 insertions(+), 37 deletions(-)

diff --git a/docs/source/en/reference/tools.md b/docs/source/en/reference/tools.md
index 022ad35..9d78774 100644
--- a/docs/source/en/reference/tools.md
+++ b/docs/source/en/reference/tools.md
@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
 
 [[autodoc]] VisitWebpageTool
 
+### UserInputTool
+
+[[autodoc]] UserInputTool
+
 ## ToolCollection
 
 [[autodoc]] ToolCollection
diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index 3a6be33..14a46ae 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -144,7 +144,7 @@ class UserInputTool(Tool):
     output_type = "string"
 
     def forward(self, question):
-        user_input = input(f"{question} => ")
+        user_input = input(f"{question} => Type your answer here: ")
         return user_input
 
 
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 6a1ae77..111fa0c 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -222,7 +222,6 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.
 
@@ -233,8 +232,6 @@
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
-            max_tokens (`int`, *optional*):
-                The maximum count of tokens to generate.
         Returns:
             `str`: The text content of the model's response.
         """
@@ -244,7 +241,7 @@
 class HfApiModel(Model):
     """A class to interact with Hugging Face's Inference API for language model interaction.
 
-    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+    This model allows you to communicate with Hugging Face's models using the Inference API. It can be used both in serverless mode and with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
@@ -265,9 +262,10 @@
     >>> engine = HfApiModel(
     ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
     ...     token="your_hf_token_here",
+    ...     max_tokens=5000,
     ... )
     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-    >>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
+    >>> response = engine(messages, stop_sequences=["END"])
     >>> print(response)
     "Quantum mechanics is the branch of physics that studies..."
     ```
@@ -279,6 +277,7 @@ class HfApiModel(Model):
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
         temperature: float = 0.5,
+        **kwargs,
     ):
         super().__init__()
         self.model_id = model_id
@@ -286,13 +285,13 @@ class HfApiModel(Model):
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
         self.temperature = temperature
+        self.kwargs = kwargs
 
     def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         """
@@ -308,16 +307,16 @@ class HfApiModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         else:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
@@ -325,16 +324,44 @@
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`.
+    """A class to interact with a local Hugging Face Transformers model.
+
+    This model initializes a model and tokenizer from the given `model_id` and runs generation locally with the Transformers library, supporting features like stop sequences and custom generation kwargs.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
-        device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
-            The device to load the model on (`"cpu"` or `"cuda"`).
+        device_map (`str`, *optional*):
+            The device_map to initialize your model with (defaults to `"cuda"` if available, else `"cpu"`).
+        torch_dtype (`str`, *optional*):
+            The torch_dtype to initialize your model with.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `do_sample`.
+    Raises:
+        ImportError:
+            If `torch` is not installed.
+
+    Example:
+    ```python
+    >>> engine = TransformersModel(
+    ...     model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+    ...     device_map="cuda",
+    ...     max_new_tokens=5000,
+    ... )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
     """
 
-    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device_map: Optional[str] = None,
+        torch_dtype: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__()
         if not is_torch_available():
             raise ImportError("Please install torch in order to use TransformersModel.")
@@ -347,14 +374,14 @@ class TransformersModel(Model):
                 f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        logger.info(f"Using device: {self.device}")
+        self.kwargs = kwargs
+        if device_map is None:
+            device_map = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device_map}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
         except Exception as e:
             logger.warning(
@@ -363,7 +390,7 @@ class TransformersModel(Model):
             )
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                default_model_id, device_map=device_map, torch_dtype=torch_dtype
             )
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
@@ -397,7 +424,6 @@ class TransformersModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -422,10 +448,10 @@ class TransformersModel(Model):
 
         out = self.model.generate(
             **prompt_tensor,
-            max_new_tokens=max_tokens,
             stopping_criteria=(
                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None
             ),
+            **self.kwargs,
         )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
@@ -458,6 +484,19 @@
 
 
 class LiteLLMModel(Model):
+    """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use with LiteLLM (e.g. "anthropic/claude-3-5-sonnet-20240620").
+        api_base (`str`, *optional*):
+            The base URL of the provider API to call the model.
+        api_key (`str`, *optional*):
+            The API key to use for authentication.
+        **kwargs:
+            Additional keyword arguments to pass to the LiteLLM completion call.
+    """
+
     def __init__(
         self,
         model_id="anthropic/claude-3-5-sonnet-20240620",
@@ -482,7 +521,6 @@ class LiteLLMModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -495,7 +533,6 @@ class LiteLLMModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -505,7 +542,6 @@ class LiteLLMModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -516,7 +552,7 @@
 
 
 class OpenAIServerModel(Model):
-    """This engine connects to an OpenAI-compatible API server.
+    """This model connects to an OpenAI-compatible API server.
 
     Parameters:
         model_id (`str`):
             The model identifier to use on the server (e.g. "gpt-3.5-turbo").
@@ -525,8 +561,6 @@ class OpenAIServerModel(Model):
         api_base (`str`):
             The base URL of the OpenAI-compatible API server.
         api_key (`str`):
             The API key to use for authentication.
-        temperature (`float`, *optional*, defaults to 0.7):
-            Controls randomness in the model's responses. Values between 0 and 2.
         **kwargs:
             Additional keyword arguments to pass to the OpenAI API.
     """
@@ -536,7 +570,6 @@ class OpenAIServerModel(Model):
         model_id: str,
         api_base: str,
         api_key: str,
-        temperature: float = 0.7,
         **kwargs,
     ):
         super().__init__()
@@ -545,7 +578,6 @@ class OpenAIServerModel(Model):
             base_url=api_base,
             api_key=api_key,
         )
-        self.temperature = temperature
         self.kwargs = kwargs
 
     def __call__(
@@ -554,7 +586,6 @@ class OpenAIServerModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -566,8 +597,6 @@ class OpenAIServerModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         else:
@@ -575,8 +604,6 @@ class OpenAIServerModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
diff --git a/tests/test_models.py b/tests/test_models.py
index 5a3a821..9921631 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -16,7 +16,7 @@ import unittest
 import json
 from typing import Optional
 
-from smolagents import models, tool, ChatMessage, HfApiModel
+from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
 
 
 class ModelTests(unittest.TestCase):
@@ -46,6 +46,17 @@ class ModelTests(unittest.TestCase):
         assert data["content"] == "Hello!"
 
     def test_get_hfapi_message_no_tool(self):
-        model = HfApiModel()
+        model = HfApiModel(max_tokens=10)
         messages = [{"role": "user", "content": "Hello!"}]
         model(messages, stop_sequences=["great"])
+
+    def test_transformers_message_no_tool(self):
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
+            max_new_tokens=5,
+            device_map="auto",
+            do_sample=False,
+        )
+        messages = [{"role": "user", "content": "Hello!"}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output == "assistant\nHello"
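
Taken together, the patch replaces the hardcoded per-call `max_tokens=1500` with constructor-level `**kwargs` that each model class stores as `self.kwargs` and forwards to its underlying completion or generate call. Below is a minimal usage sketch of that flow, assuming the post-patch API; the OpenAI-compatible model name, endpoint URL, and API key are illustrative placeholders, not values prescribed by the patch:

```python
# Minimal sketch of the behavior this patch enables (illustrative values).
from smolagents import HfApiModel, OpenAIServerModel, TransformersModel

messages = [{"role": "user", "content": "Say hello in one short sentence."}]

# HfApiModel: **kwargs (e.g. max_tokens) are stored on the instance at
# construction and splatted into client.chat.completions.create() on each call.
api_model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=100)
print(api_model(messages, stop_sequences=["END"]).content)

# TransformersModel: **kwargs (e.g. max_new_tokens, do_sample) are forwarded
# to model.generate(), as exercised by the new test_transformers_message_no_tool.
local_model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
    device_map="auto",
    max_new_tokens=16,
    do_sample=False,
)
print(local_model(messages).content)

# OpenAIServerModel: temperature is no longer a dedicated constructor
# parameter; it now travels through **kwargs like any other OpenAI option.
server_model = OpenAIServerModel(
    model_id="gpt-4o-mini",               # illustrative model name
    api_base="http://localhost:8000/v1",  # illustrative endpoint
    api_key="sk-placeholder",             # placeholder key
    temperature=0.7,
    max_tokens=100,
)
print(server_model(messages).content)
```

The design consequence is that generation options become part of the model object's configuration rather than of each call, which is what lets agent code invoke every `Model` subclass through the same minimal `__call__` signature.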