refactor(models): restructure model parameter handling (#227)

* refactor(models): restructure model parameter handling
  - Introduce base-class level default parameters (temperature, max_tokens)
  - Optimize parameter handling: method args can override base config
  - Unify parameter handling across model implementations

This commit is contained in:
parent 117014d2e1
commit 398c932250
@@ -196,9 +196,59 @@ def get_clean_message_list(
 class Model:
-    def __init__(self):
+    def __init__(self, **kwargs):
         self.last_input_token_count = None
         self.last_output_token_count = None
+        # Set default values for common parameters
+        kwargs.setdefault("temperature", 0.5)
+        kwargs.setdefault("max_tokens", 4096)
+        self.kwargs = kwargs
+
+    def _prepare_completion_kwargs(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        custom_role_conversions: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ) -> Dict:
+        """
+        Prepare parameters required for model invocation, handling parameter priorities.
+
+        Parameter priority from high to low:
+        1. Explicitly passed kwargs
+        2. Specific parameters (stop_sequences, grammar, etc.)
+        3. Default values in self.kwargs
+        """
+        # Clean and standardize the message list
+        messages = get_clean_message_list(messages, role_conversions=custom_role_conversions or tool_role_conversions)
+
+        # Use self.kwargs as the base configuration
+        completion_kwargs = {
+            **self.kwargs,
+            "messages": messages,
+        }
+
+        # Handle specific parameters
+        if stop_sequences is not None:
+            completion_kwargs["stop"] = stop_sequences
+        if grammar is not None:
+            completion_kwargs["grammar"] = grammar
+
+        # Handle tools parameter
+        if tools_to_call_from:
+            completion_kwargs.update(
+                {
+                    "tools": [get_json_schema(tool) for tool in tools_to_call_from],
+                    "tool_choice": "required",
+                }
+            )
+
+        # Finally, use the passed-in kwargs to override all settings
+        completion_kwargs.update(kwargs)
+
+        return completion_kwargs
+
     def get_token_counts(self) -> Dict[str, int]:
         return {
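Note on the priority logic above: `_prepare_completion_kwargs` builds the final call dict in three layers, so later layers win. A minimal standalone sketch of that merge order (function and variable names here are illustrative, not part of the diff):

from typing import Dict, List, Optional

def merge_kwargs(
    base: Dict,                          # stands in for self.kwargs from Model.__init__
    stop_sequences: Optional[List[str]] = None,
    **kwargs,
) -> Dict:
    merged = {**base}                    # 3. lowest priority: constructor defaults
    if stop_sequences is not None:
        merged["stop"] = stop_sequences  # 2. specific named parameters
    merged.update(kwargs)                # 1. highest priority: explicit call-time kwargs
    return merged

print(merge_kwargs({"temperature": 0.5, "max_tokens": 4096}, temperature=0.9))
# -> {'temperature': 0.9, 'max_tokens': 4096}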
@@ -211,6 +261,8 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.

@@ -221,8 +273,13 @@ class Model:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
+            tools_to_call_from (`List[Tool]`, *optional*):
+                A list of tools that the model can use to generate responses.
+            **kwargs:
+                Additional keyword arguments to be passed to the underlying model.
+
         Returns:
-            `str`: The text content of the model's response.
+            `ChatMessage`: A chat message object containing the model's response.
         """
         pass  # To be implemented in child classes!
@@ -265,16 +322,13 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
-        temperature: float = 0.5,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
-        self.temperature = temperature
-        self.kwargs = kwargs

     def __call__(
         self,
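With this hunk, `temperature` is no longer a dedicated `HfApiModel` argument: it travels through `**kwargs` into `Model.__init__`, where `setdefault` only fills it in when the caller has not set it. A hypothetical construction sketch:

default_model = HfApiModel()              # self.kwargs -> {"temperature": 0.5, "max_tokens": 4096}
warm_model = HfApiModel(temperature=0.9)  # the caller's value survives the setdefault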
@@ -282,29 +336,18 @@ class HfApiModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        """
-        Gets an LLM output message for the given list of input messages.
-        If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
-        """
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
-        if tools_to_call_from:
-            response = self.client.chat.completions.create(
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="auto",
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        response = self.client.chat_completion(**completion_kwargs)
+
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
         message = ChatMessage.from_hf_api(response.choices[0].message)
@@ -358,7 +401,7 @@ class TransformersModel(Model):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if not is_torch_available() or not _is_package_available("transformers"):
             raise ModuleNotFoundError(
                 "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
@@ -418,12 +461,36 @@ class TransformersModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        messages = completion_kwargs.pop("messages")
+        stop_sequences = completion_kwargs.pop("stop", None)
+
+        max_new_tokens = (
+            kwargs.get("max_new_tokens")
+            or kwargs.get("max_tokens")
+            or self.kwargs.get("max_new_tokens")
+            or self.kwargs.get("max_tokens")
+        )
+
+        if max_new_tokens:
+            completion_kwargs["max_new_tokens"] = max_new_tokens
+
+        if stop_sequences:
+            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
+
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tools=completion_kwargs.pop("tools", []),
                 return_tensors="pt",
                 return_dict=True,
                 add_generation_prompt=True,
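The alias resolution above exists because `transformers`' `generate` expects `max_new_tokens`, while the base-class default is stored as `max_tokens`. Call-time kwargs are checked before constructor defaults, and at each level `max_new_tokens` beats `max_tokens`. A small sketch of the fallback chain in isolation (values illustrative):

call_kwargs = {}                                        # nothing passed at call time
self_kwargs = {"temperature": 0.5, "max_tokens": 4096}  # base-class defaults

max_new_tokens = (
    call_kwargs.get("max_new_tokens")
    or call_kwargs.get("max_tokens")
    or self_kwargs.get("max_new_tokens")
    or self_kwargs.get("max_tokens")
)
print(max_new_tokens)  # -> 4096, from the base-class max_tokens default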
@@ -434,14 +501,11 @@ class TransformersModel(Model):
                 return_tensors="pt",
                 return_dict=True,
             )

         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]

-        out = self.model.generate(
-            **prompt_tensor,
-            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
-            **self.kwargs,
-        )
+        out = self.model.generate(**prompt_tensor, **completion_kwargs)
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
@@ -449,6 +513,7 @@ class TransformersModel(Model):

         if stop_sequences is not None:
             output = remove_stop_sequences(output, stop_sequences)

         if tools_to_call_from is None:
             return ChatMessage(role="assistant", content=output)
         else:
@@ -498,13 +563,12 @@ class LiteLLMModel(Model):
                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
             )

-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
         litellm.add_function_to_prompt = True
         self.api_base = api_base
         self.api_key = api_key
-        self.kwargs = kwargs

     def __call__(
         self,
@@ -512,34 +576,28 @@ class LiteLLMModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
         import litellm

-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
-        if tools_to_call_from:
-            response = litellm.completion(
-                model=self.model_id,
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="required",
-                stop=stop_sequences,
-                api_base=self.api_base,
-                api_key=self.api_key,
-                **self.kwargs,
-            )
-        else:
-            response = litellm.completion(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                api_base=self.api_base,
-                api_key=self.api_key,
-                **self.kwargs,
-            )
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            model=self.model_id,
+            api_base=self.api_base,
+            api_key=self.api_key,
+            **kwargs,
+        )
+
+        response = litellm.completion(**completion_kwargs)
+
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = response.choices[0].message
+
+        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))

         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
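The new `message = ChatMessage(**...model_dump(include={...}))` line normalizes the provider's message object into smolagents' own `ChatMessage`, keeping only the fields it models. A minimal sketch of that filtering, using a hypothetical pydantic stand-in for the SDK's message type:

from typing import List, Optional
from pydantic import BaseModel

class ProviderMessage(BaseModel):           # hypothetical stand-in for the SDK message
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    provider_extra: Optional[str] = None    # extra field we want to drop

msg = ProviderMessage(role="assistant", content="hi", provider_extra="x")
print(msg.model_dump(include={"role", "content", "tool_calls"}))
# -> {'role': 'assistant', 'content': 'hi', 'tool_calls': None}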
@@ -576,13 +634,13 @@ class OpenAIServerModel(Model):
             raise ModuleNotFoundError(
                 "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
             ) from None
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         self.client = openai.OpenAI(
             base_url=api_base,
             api_key=api_key,
         )
-        self.kwargs = kwargs
         self.custom_role_conversions = custom_role_conversions

     def __call__(
@@ -591,30 +649,23 @@ class OpenAIServerModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        messages = get_clean_message_list(
-            messages,
-            role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions),
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            model=self.model_id,
+            custom_role_conversions=self.custom_role_conversions,
+            **kwargs,
         )
-        if tools_to_call_from:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="required",
-                stop=stop_sequences,
-                **self.kwargs,
-            )
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                **self.kwargs,
-            )
+
+        response = self.client.chat.completions.create(**completion_kwargs)
+
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = response.choices[0].message
+        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
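Taken together, callers can now set defaults once at construction and override them per call, uniformly across every backend. A hypothetical end-to-end sketch (the model id and key are placeholders, not part of the diff):

model = OpenAIServerModel(model_id="gpt-4o-mini", api_key="...", temperature=0.2)

# Constructor default temperature=0.2 applies here.
reply = model(messages=[{"role": "user", "content": "Hello"}])

# Per-call kwargs have the highest priority and override the default for this call only.
haiku = model(messages=[{"role": "user", "content": "Write a haiku"}], temperature=1.0)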