diff --git a/pyproject.toml b/pyproject.toml
index 57eaa7c..bfd970d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "python-dotenv>=1.0.1",
     "e2b-code-interpreter>=1.0.3",
     "litellm>=1.55.10",
+    "openai>=1.58.1",
 ]
 
 [tool.ruff]
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 9b1fd55..8be29a2 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -31,6 +31,7 @@ from transformers import (
     StoppingCriteria,
     StoppingCriteriaList,
 )
+import openai
 
 from .tools import Tool
 from .utils import parse_json_tool_call
@@ -487,6 +488,99 @@ class LiteLLMModel(Model):
         return tool_calls.function.name, arguments, tool_calls.id
 
 
+class OpenAIServerModel(Model):
+    """This engine connects to an OpenAI-compatible API server.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+        api_base (`str`):
+            The base URL of the OpenAI-compatible API server.
+        api_key (`str`):
+            The API key to use for authentication.
+        temperature (`float`, *optional*, defaults to 0.7):
+            Controls randomness in the model's responses. Values between 0 and 2.
+        **kwargs:
+            Additional keyword arguments to pass to the OpenAI API.
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        api_base: str,
+        api_key: str,
+        temperature: float = 0.7,
+        **kwargs,
+    ):
+        super().__init__()
+        self.model_id = model_id
+        self.client = openai.OpenAI(
+            base_url=api_base,
+            api_key=api_key,
+        )
+        self.temperature = temperature
+        self.kwargs = kwargs
+
+    def generate(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        """Generates a text completion for the given message list."""
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        response = self.client.chat.completions.create(
+            model=self.model_id,
+            messages=messages,
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+            temperature=self.temperature,
+            **self.kwargs,
+        )
+
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return response.choices[0].message.content
+
+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+        stop_sequences: Optional[List[str]] = None,
+        max_tokens: int = 500,
+    ) -> Tuple[str, Union[str, Dict], str]:
+        """Generates a tool call for the given message list."""
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        response = self.client.chat.completions.create(
+            model=self.model_id,
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="auto",
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+            temperature=self.temperature,
+            **self.kwargs,
+        )
+
+        tool_call = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+
+        try:
+            arguments = json.loads(tool_call.function.arguments)
+        except json.JSONDecodeError:
+            arguments = tool_call.function.arguments
+
+        return tool_call.function.name, arguments, tool_call.id
+
+
 __all__ = [
     "MessageRole",
     "tool_role_conversions",
@@ -495,4 +589,5 @@ __all__ = [
     "TransformersModel",
     "HfApiModel",
     "LiteLLMModel",
+    "OpenAIServerModel",
 ]
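
Usage sketch (not part of this diff): a minimal call through the new class, assuming an OpenAI-compatible server is already running. The base URL, API key, and model name below are placeholders, not values defined by this change.

from smolagents.models import OpenAIServerModel

# Placeholder endpoint and credentials; point these at your own server.
model = OpenAIServerModel(
    model_id="gpt-3.5-turbo",
    api_base="http://localhost:8000/v1",
    api_key="sk-placeholder",
)

messages = [{"role": "user", "content": "In one sentence, what does a code agent do?"}]
print(model.generate(messages, max_tokens=256))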
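
And a sketch of the tool-calling path, reusing the `model` instance above. `GetWeatherTool` is a hypothetical tool written against the `Tool` subclass conventions in smolagents at this point (class-level `name`, `description`, `inputs`, `output_type`, plus a `forward` method); the exact attributes may differ between versions.

from smolagents.tools import Tool

# Hypothetical tool for illustration only; not part of this diff.
class GetWeatherTool(Tool):
    name = "get_weather"
    description = "Returns the current weather for a given city."
    inputs = {"city": {"type": "string", "description": "Name of the city."}}
    output_type = "string"

    def forward(self, city: str) -> str:
        return f"Sunny in {city}"  # stub implementation

tool_name, arguments, call_id = model.get_tool_call(
    [{"role": "user", "content": "What is the weather in Paris?"}],
    available_tools=[GetWeatherTool()],
)
print(tool_name, arguments, call_id)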