smolagents/src/smolagents/models.py

460 lines
16 KiB
Python

#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
import litellm
import logging
import os
import random
from huggingface_hub import InferenceClient
from .tools import Tool
from .utils import parse_json_tool_call
logger = logging.getLogger(__name__)
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
"type": "regex",
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
}
DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"type": "regex",
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
}
class MessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
TOOL_CALL = "tool-call"
TOOL_RESPONSE = "tool-response"
@classmethod
def roles(cls):
return [r.value for r in cls]
tool_role_conversions = {
MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}
def get_json_schema(tool: Tool) -> Dict:
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
k: {k2: v2.replace("any", "object") for k2, v2 in v.items()}
for k, v in tool.inputs.items()
},
"required": list(tool.inputs.keys()),
},
},
}
def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
for stop_seq in stop_sequences:
if content[-len(stop_seq) :] == stop_seq:
content = content[: -len(stop_seq)]
return content
def get_clean_message_list(
message_list: List[Dict[str, str]],
role_conversions: Dict[MessageRole, MessageRole] = {},
) -> List[Dict[str, str]]:
"""
Subsequent messages with the same role will be concatenated to a single message.
Args:
message_list (`List[Dict[str, str]]`): List of chat messages.
"""
final_message_list = []
message_list = deepcopy(message_list) # Avoid modifying the original list
for message in message_list:
# if not set(message.keys()) == {"role", "content"}:
# raise ValueError("Message should contain only 'role' and 'content' keys!")
role = message["role"]
if role not in MessageRole.roles():
raise ValueError(
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
)
if role in role_conversions:
message["role"] = role_conversions[role]
if (
len(final_message_list) > 0
and message["role"] == final_message_list[-1]["role"]
):
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
else:
final_message_list.append(message)
return final_message_list
class HfModel:
def __init__(self):
self.last_input_token_count = None
self.last_output_token_count = None
def get_token_counts(self):
return {
"input_token_count": self.last_input_token_count,
"output_token_count": self.last_output_token_count,
}
def generate(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
):
raise NotImplementedError
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
"""Process the input messages and return the model's response.
Parameters:
messages (`List[Dict[str, str]]`):
A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
stop_sequences (`List[str]`, *optional*):
A list of strings that will stop the generation if encountered in the model's output.
grammar (`str`, *optional*):
The grammar or formatting structure to use in the model's response.
max_tokens (`int`, *optional*):
The maximum count of tokens to generate.
Returns:
`str`: The text content of the model's response.
"""
if not isinstance(messages, List):
raise ValueError(
"Messages should be a list of dictionaries with 'role' and 'content' keys."
)
if stop_sequences is None:
stop_sequences = []
response = self.generate(messages, stop_sequences, grammar, max_tokens)
return remove_stop_sequences(response, stop_sequences)
class HfApiModel(HfModel):
"""A class to interact with Hugging Face's Inference API for language model interaction.
This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
Parameters:
model (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
token (`str`, *optional*):
Token used by the Hugging Face API for authentication. This token need to be authorized 'Make calls to the serverless Inference API'.
If the model is gated (like Llama-3 models), the token also needs 'Read access to contents of all public gated repos you can access'.
If not provided, the class will try to use environment variable 'HF_TOKEN', else use the token stored in the Hugging Face CLI configuration.
max_tokens (`int`, *optional*, defaults to 1500):
The maximum number of tokens allowed in the output.
timeout (`int`, *optional*, defaults to 120):
Timeout for the API request, in seconds.
Raises:
ValueError:
If the model name is not provided.
Example:
```python
>>> engine = HfApiModel(
... model="Qwen/Qwen2.5-Coder-32B-Instruct",
... token="your_hf_token_here",
... max_tokens=2000
... )
>>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
>>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
"""
def __init__(
self,
model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
token: Optional[str] = None,
timeout: Optional[int] = 120,
):
super().__init__()
self.model_id = model_id
if token is None:
token = os.getenv("HF_TOKEN")
self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
def generate(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
"""Generates a text completion for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
# Send messages to the Hugging Face Inference API
if grammar is not None:
output = self.client.chat_completion(
messages,
stop=stop_sequences,
response_format=grammar,
max_tokens=max_tokens,
)
else:
output = self.client.chat.completions.create(
messages, stop=stop_sequences, max_tokens=max_tokens
)
response = output.choices[0].message.content
self.last_input_token_count = output.usage.prompt_tokens
self.last_output_token_count = output.usage.completion_tokens
return response
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
):
"""Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
response = self.client.chat.completions.create(
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tool_choice="auto",
)
tool_call = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return tool_call.function.name, tool_call.function.arguments, tool_call.id
class TransformersModel(HfModel):
"""This engine initializes a model and tokenizer from the given `model_id`."""
def __init__(self, model_id: Optional[str] = None):
super().__init__()
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
if model_id is None:
model_id = default_model_id
logger.warning(
f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
)
self.model_id = model_id
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id)
except Exception as e:
logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
)
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):
def __init__(self, stop_strings: List[str], tokenizer):
self.stop_strings = stop_strings
self.tokenizer = tokenizer
self.stream = ""
def reset(self):
self.stream = ""
def __call__(self, input_ids, scores, **kwargs):
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
self.stream += generated
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
return True
return False
return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])
def generate(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
# Get LLM output
prompt = self.tokenizer.apply_chat_template(
messages,
return_tensors="pt",
return_dict=True,
)
prompt = prompt.to(self.model.device)
count_prompt_tokens = prompt["input_ids"].shape[1]
out = self.model.generate(
**prompt,
max_new_tokens=max_tokens,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
)
generated_tokens = out[0, count_prompt_tokens:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.last_input_token_count = count_prompt_tokens
self.last_output_token_count = len(generated_tokens)
if stop_sequences is not None:
response = remove_stop_sequences(response, stop_sequences)
return response
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
max_tokens: int = 500,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
prompt = self.tokenizer.apply_chat_template(
messages,
tools=[get_json_schema(tool) for tool in available_tools],
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
)
prompt = prompt.to(self.model.device)
count_prompt_tokens = prompt["input_ids"].shape[1]
out = self.model.generate(
**prompt,
max_new_tokens=max_tokens,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
)
generated_tokens = out[0, count_prompt_tokens:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.last_input_token_count = count_prompt_tokens
self.last_output_token_count = len(generated_tokens)
if stop_sequences is not None:
response = remove_stop_sequences(response, stop_sequences)
tool_name, tool_input = parse_json_tool_call(response)
call_id = "".join(random.choices("0123456789", k=5))
return tool_name, tool_input, call_id
class LiteLLMModel:
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
litellm.add_function_to_prompt = True
self.last_input_token_count = 0
self.last_output_token_count = 0
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
response = litellm.completion(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return response.choices[0].message.content
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
max_tokens: int = 1500,
):
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
response = litellm.completion(
model=self.model_id,
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tool_choice="required",
stop=stop_sequences,
max_tokens=max_tokens,
)
tool_calls = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id
__all__ = [
"MessageRole",
"tool_role_conversions",
"get_clean_message_list",
"HfModel",
"TransformersModel",
"HfApiModel",
"LiteLLMModel",
]