From 398c932250f4f0ecafc5bf9351b4166662cb121b Mon Sep 17 00:00:00 2001 From: kingdomad <34766852+kingdomad@users.noreply.github.com> Date: Wed, 22 Jan 2025 18:27:36 +0800 Subject: [PATCH] refactor(models): restructure model parameter handling (#227) * refactor(models): restructure model parameter handling - Introduce base-class level default parameters (temperature, max_tokens) - Optimize parameter handling: method args can override base config - Unify parameter handling across model implementations --- src/smolagents/models.py | 215 ++++++++++++++++++++++++--------------- 1 file changed, 133 insertions(+), 82 deletions(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 29f08db..9240aed 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -196,9 +196,59 @@ def get_clean_message_list( class Model: - def __init__(self): + def __init__(self, **kwargs): self.last_input_token_count = None self.last_output_token_count = None + # Set default values for common parameters + kwargs.setdefault("temperature", 0.5) + kwargs.setdefault("max_tokens", 4096) + self.kwargs = kwargs + + def _prepare_completion_kwargs( + self, + messages: List[Dict[str, str]], + stop_sequences: Optional[List[str]] = None, + grammar: Optional[str] = None, + tools_to_call_from: Optional[List[Tool]] = None, + custom_role_conversions: Optional[Dict[str, str]] = None, + **kwargs, + ) -> Dict: + """ + Prepare parameters required for model invocation, handling parameter priorities. + + Parameter priority from high to low: + 1. Explicitly passed kwargs + 2. Specific parameters (stop_sequences, grammar, etc.) + 3. Default values in self.kwargs + """ + # Clean and standardize the message list + messages = get_clean_message_list(messages, role_conversions=custom_role_conversions or tool_role_conversions) + + # Use self.kwargs as the base configuration + completion_kwargs = { + **self.kwargs, + "messages": messages, + } + + # Handle specific parameters + if stop_sequences is not None: + completion_kwargs["stop"] = stop_sequences + if grammar is not None: + completion_kwargs["grammar"] = grammar + + # Handle tools parameter + if tools_to_call_from: + completion_kwargs.update( + { + "tools": [get_json_schema(tool) for tool in tools_to_call_from], + "tool_choice": "required", + } + ) + + # Finally, use the passed-in kwargs to override all settings + completion_kwargs.update(kwargs) + + return completion_kwargs def get_token_counts(self) -> Dict[str, int]: return { @@ -211,6 +261,8 @@ class Model: messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None, + tools_to_call_from: Optional[List[Tool]] = None, + **kwargs, ) -> ChatMessage: """Process the input messages and return the model's response. @@ -221,8 +273,13 @@ class Model: A list of strings that will stop the generation if encountered in the model's output. grammar (`str`, *optional*): The grammar or formatting structure to use in the model's response. + tools_to_call_from (`List[Tool]`, *optional*): + A list of tools that the model can use to generate responses. + **kwargs: + Additional keyword arguments to be passed to the underlying model. + Returns: - `str`: The text content of the model's response. + `ChatMessage`: A chat message object containing the model's response. """ pass # To be implemented in child classes! 
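The merge order that `_prepare_completion_kwargs` enforces is easiest to see in isolation. Below is a minimal sketch of the three priority levels, assuming the base `Model` class exactly as added above; the message content and override values are illustrative only:

```python
from smolagents.models import Model

# Constructor kwargs override the base-class defaults (temperature=0.5,
# max_tokens=4096) and are stored in self.kwargs.
model = Model(temperature=0.7)

messages = [{"role": "user", "content": "Hello"}]
completion_kwargs = model._prepare_completion_kwargs(
    messages,
    stop_sequences=["<end>"],  # specific parameter, mapped to the "stop" key
    temperature=0.2,           # explicitly passed kwarg: highest priority
)

assert completion_kwargs["temperature"] == 0.2  # explicit kwargs beat self.kwargs
assert completion_kwargs["stop"] == ["<end>"]   # specific parameter applied
assert completion_kwargs["max_tokens"] == 4096  # untouched default survives
```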
@@ -265,16 +322,13 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
-        temperature: float = 0.5,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
-        self.temperature = temperature
-        self.kwargs = kwargs

     def __call__(
         self,
@@ -282,29 +336,18 @@
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        """
-        Gets an LLM output message for the given list of input messages.
-        If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
-        """
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
-        if tools_to_call_from:
-            response = self.client.chat.completions.create(
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="auto",
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        response = self.client.chat_completion(**completion_kwargs)
+
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
         message = ChatMessage.from_hf_api(response.choices[0].message)
@@ -358,7 +401,7 @@ class TransformersModel(Model):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if not is_torch_available() or not _is_package_available("transformers"):
            raise ModuleNotFoundError(
                "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
@@ -418,12 +461,36 @@ class TransformersModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        messages = completion_kwargs.pop("messages")
+        stop_sequences = completion_kwargs.pop("stop", None)
+
+        max_new_tokens = (
+            kwargs.get("max_new_tokens")
+            or kwargs.get("max_tokens")
+            or self.kwargs.get("max_new_tokens")
+            or self.kwargs.get("max_tokens")
+        )
+
+        if max_new_tokens:
+            completion_kwargs["max_new_tokens"] = max_new_tokens
+
+        # Chat-completion-only keys are not valid `generate` kwargs and would
+        # raise a ValueError in transformers, so drop them here
+        for key in ("max_tokens", "grammar", "tool_choice"):
+            completion_kwargs.pop(key, None)
+
+        if stop_sequences:
+            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
+
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tools=completion_kwargs.pop("tools", []),
                 return_tensors="pt",
                 return_dict=True,
                 add_generation_prompt=True,
@@ -434,14 +501,11 @@ class TransformersModel(Model):
                 return_tensors="pt",
                 return_dict=True,
             )
+        prompt_tensor = prompt_tensor.to(self.model.device)
count_prompt_tokens = prompt_tensor["input_ids"].shape[1] - out = self.model.generate( - **prompt_tensor, - stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None), - **self.kwargs, - ) + out = self.model.generate(**prompt_tensor, **completion_kwargs) generated_tokens = out[0, count_prompt_tokens:] output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) self.last_input_token_count = count_prompt_tokens @@ -449,6 +513,7 @@ class TransformersModel(Model): if stop_sequences is not None: output = remove_stop_sequences(output, stop_sequences) + if tools_to_call_from is None: return ChatMessage(role="assistant", content=output) else: @@ -498,13 +563,12 @@ class LiteLLMModel(Model): "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" ) - super().__init__() + super().__init__(**kwargs) self.model_id = model_id # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs litellm.add_function_to_prompt = True self.api_base = api_base self.api_key = api_key - self.kwargs = kwargs def __call__( self, @@ -512,34 +576,28 @@ class LiteLLMModel(Model): stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, + **kwargs, ) -> ChatMessage: import litellm - messages = get_clean_message_list(messages, role_conversions=tool_role_conversions) + completion_kwargs = self._prepare_completion_kwargs( + messages=messages, + stop_sequences=stop_sequences, + grammar=grammar, + tools_to_call_from=tools_to_call_from, + model=self.model_id, + api_base=self.api_base, + api_key=self.api_key, + **kwargs, + ) + + response = litellm.completion(**completion_kwargs) - if tools_to_call_from: - response = litellm.completion( - model=self.model_id, - messages=messages, - tools=[get_json_schema(tool) for tool in tools_to_call_from], - tool_choice="required", - stop=stop_sequences, - api_base=self.api_base, - api_key=self.api_key, - **self.kwargs, - ) - else: - response = litellm.completion( - model=self.model_id, - messages=messages, - stop=stop_sequences, - api_base=self.api_base, - api_key=self.api_key, - **self.kwargs, - ) self.last_input_token_count = response.usage.prompt_tokens self.last_output_token_count = response.usage.completion_tokens - message = response.choices[0].message + + message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) + if tools_to_call_from is not None: return parse_tool_args_if_needed(message) return message @@ -576,13 +634,13 @@ class OpenAIServerModel(Model): raise ModuleNotFoundError( "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`" ) from None - super().__init__() + + super().__init__(**kwargs) self.model_id = model_id self.client = openai.OpenAI( base_url=api_base, api_key=api_key, ) - self.kwargs = kwargs self.custom_role_conversions = custom_role_conversions def __call__( @@ -591,30 +649,23 @@ class OpenAIServerModel(Model): stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, + **kwargs, ) -> ChatMessage: - messages = get_clean_message_list( - messages, - role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions), + completion_kwargs = self._prepare_completion_kwargs( + messages=messages, + stop_sequences=stop_sequences, + grammar=grammar, + tools_to_call_from=tools_to_call_from, + model=self.model_id, + 
custom_role_conversions=self.custom_role_conversions, + **kwargs, ) - if tools_to_call_from: - response = self.client.chat.completions.create( - model=self.model_id, - messages=messages, - tools=[get_json_schema(tool) for tool in tools_to_call_from], - tool_choice="required", - stop=stop_sequences, - **self.kwargs, - ) - else: - response = self.client.chat.completions.create( - model=self.model_id, - messages=messages, - stop=stop_sequences, - **self.kwargs, - ) + + response = self.client.chat.completions.create(**completion_kwargs) self.last_input_token_count = response.usage.prompt_tokens self.last_output_token_count = response.usage.completion_tokens - message = response.choices[0].message + + message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) if tools_to_call_from is not None: return parse_tool_args_if_needed(message) return message
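With every backend funneling through the same helper, per-call overrides behave uniformly across implementations. A usage sketch, assuming the patched `HfApiModel` and an `HF_TOKEN` available in the environment:

```python
from smolagents import HfApiModel

# Constructor kwargs become this instance's defaults via the base class
model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0.5)

messages = [{"role": "user", "content": "Summarize this diff in one sentence."}]

# Uses the stored defaults: temperature=0.5, max_tokens=4096
reply = model(messages)

# Per-call kwargs override the defaults for this request only;
# self.kwargs itself is never mutated
short_reply = model(messages, temperature=0.9, max_tokens=256)
```

The same call sites work for `TransformersModel`, where a passed `max_tokens` (or `max_new_tokens`) is resolved into `generate`'s `max_new_tokens` by the fallback chain shown above.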