Implement OpenAIServerModel (#109)

This commit is contained in:
Zetaphor 2025-01-08 15:39:41 -06:00 committed by GitHub
parent d3cd0f9e09
commit b4528d6a6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 0 deletions

View File

@ -27,6 +27,7 @@ dependencies = [
"python-dotenv>=1.0.1",
"e2b-code-interpreter>=1.0.3",
"litellm>=1.55.10",
"openai>=1.58.1",
]
[tool.ruff]

View File

@ -31,6 +31,7 @@ from transformers import (
StoppingCriteria,
StoppingCriteriaList,
)
import openai
from .tools import Tool
from .utils import parse_json_tool_call
@ -487,6 +488,99 @@ class LiteLLMModel(Model):
return tool_calls.function.name, arguments, tool_calls.id
class OpenAIServerModel(Model):
    """Model backed by any server that speaks the OpenAI chat-completions API.

    Parameters:
        model_id (`str`):
            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
        api_base (`str`):
            The base URL of the OpenAI-compatible API server.
        api_key (`str`):
            The API key to use for authentication.
        temperature (`float`, *optional*, defaults to 0.7):
            Controls randomness in the model's responses. Values between 0 and 2.
        **kwargs:
            Additional keyword arguments forwarded on every API call.
    """

    def __init__(
        self,
        model_id: str,
        api_base: str,
        api_key: str,
        temperature: float = 0.7,
        **kwargs,
    ):
        super().__init__()
        self.model_id = model_id
        # One client instance reused across all completions.
        self.client = openai.OpenAI(base_url=api_base, api_key=api_key)
        self.temperature = temperature
        # Extra options (e.g. top_p) merged into each request.
        self.kwargs = kwargs

    def generate(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        max_tokens: int = 1500,
    ) -> str:
        """Generate a text completion for the given message list.

        Note: `grammar` is accepted for interface compatibility but is not
        forwarded to the OpenAI API.
        """
        clean_messages = get_clean_message_list(
            messages, role_conversions=tool_role_conversions
        )
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=clean_messages,
            stop=stop_sequences,
            max_tokens=max_tokens,
            temperature=self.temperature,
            **self.kwargs,
        )
        # Record token usage so callers can track cost.
        usage = response.usage
        self.last_input_token_count = usage.prompt_tokens
        self.last_output_token_count = usage.completion_tokens
        return response.choices[0].message.content

    def get_tool_call(
        self,
        messages: List[Dict[str, str]],
        available_tools: List[Tool],
        stop_sequences: Optional[List[str]] = None,
        max_tokens: int = 500,
    ) -> Tuple[str, Union[str, Dict], str]:
        """Ask the model for a tool call and return (name, arguments, call_id).

        `arguments` is the JSON-decoded payload when it parses, otherwise the
        raw string as returned by the server.
        """
        clean_messages = get_clean_message_list(
            messages, role_conversions=tool_role_conversions
        )
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=clean_messages,
            tools=[get_json_schema(tool) for tool in available_tools],
            tool_choice="auto",
            stop=stop_sequences,
            max_tokens=max_tokens,
            temperature=self.temperature,
            **self.kwargs,
        )
        # Only the first proposed tool call is honored.
        call = response.choices[0].message.tool_calls[0]
        usage = response.usage
        self.last_input_token_count = usage.prompt_tokens
        self.last_output_token_count = usage.completion_tokens
        raw_arguments = call.function.arguments
        try:
            parsed_arguments = json.loads(raw_arguments)
        except json.JSONDecodeError:
            # Server returned non-JSON arguments; pass them through untouched.
            parsed_arguments = raw_arguments
        return call.function.name, parsed_arguments, call.id
__all__ = [
"MessageRole",
"tool_role_conversions",
@ -495,4 +589,5 @@ __all__ = [
"TransformersModel",
"HfApiModel",
"LiteLLMModel",
"OpenAIServerModel",
]