Implement OpenAIServerModel (#109)
parent d3cd0f9e09
commit b4528d6a6f
@@ -27,6 +27,7 @@ dependencies = [
     "python-dotenv>=1.0.1",
     "e2b-code-interpreter>=1.0.3",
     "litellm>=1.55.10",
+    "openai>=1.58.1",
 ]
 
 [tool.ruff]
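
Note, not part of the commit: the hunk above pins the new client library at "openai>=1.58.1" in the project's dependency list (presumably pyproject.toml). A minimal sketch of checking that floor at runtime, assuming the `packaging` helper is installed:

import openai
from packaging.version import Version

# Mirrors the "openai>=1.58.1" pin added in the dependency hunk above.
assert Version(openai.__version__) >= Version("1.58.1"), "openai client too old"
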
@@ -31,6 +31,7 @@ from transformers import (
     StoppingCriteria,
     StoppingCriteriaList,
 )
+import openai
 
 from .tools import Tool
 from .utils import parse_json_tool_call
@@ -487,6 +488,99 @@ class LiteLLMModel(Model):
         return tool_calls.function.name, arguments, tool_calls.id
 
 
+class OpenAIServerModel(Model):
+    """This engine connects to an OpenAI-compatible API server.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+        api_base (`str`):
+            The base URL of the OpenAI-compatible API server.
+        api_key (`str`):
+            The API key to use for authentication.
+        temperature (`float`, *optional*, defaults to 0.7):
+            Controls randomness in the model's responses. Values between 0 and 2.
+        **kwargs:
+            Additional keyword arguments to pass to the OpenAI API.
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        api_base: str,
+        api_key: str,
+        temperature: float = 0.7,
+        **kwargs,
+    ):
+        super().__init__()
+        self.model_id = model_id
+        self.client = openai.OpenAI(
+            base_url=api_base,
+            api_key=api_key,
+        )
+        self.temperature = temperature
+        self.kwargs = kwargs
+
+    def generate(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        """Generates a text completion for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        response = self.client.chat.completions.create(
+            model=self.model_id,
+            messages=messages,
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+            temperature=self.temperature,
+            **self.kwargs,
+        )
+
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return response.choices[0].message.content
+
+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+        stop_sequences: Optional[List[str]] = None,
+        max_tokens: int = 500,
+    ) -> Tuple[str, Union[str, Dict], str]:
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        response = self.client.chat.completions.create(
+            model=self.model_id,
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="auto",
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+            temperature=self.temperature,
+            **self.kwargs,
+        )
+
+        tool_calls = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+
+        try:
+            arguments = json.loads(tool_calls.function.arguments)
+        except json.JSONDecodeError:
+            arguments = tool_calls.function.arguments
+
+        return tool_calls.function.name, arguments, tool_calls.id
+
+
 __all__ = [
     "MessageRole",
     "tool_role_conversions",
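
A minimal usage sketch for the class added above, not part of the diff. The import path, endpoint URL, model name, and key below are placeholder assumptions, since the diff does not show the file or package name:

from smolagents.models import OpenAIServerModel  # import path assumed

model = OpenAIServerModel(
    model_id="gpt-3.5-turbo",              # any model the server exposes
    api_base="https://api.openai.com/v1",  # or any OpenAI-compatible server
    api_key="sk-...",                      # placeholder key
    temperature=0.2,
)

# generate() cleans the message list, calls chat.completions.create on the
# stored client, records token usage, and returns the first choice's content.
text = model.generate(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    max_tokens=16,
)
print(text)
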
@@ -495,4 +589,5 @@ __all__ = [
     "TransformersModel",
     "HfApiModel",
     "LiteLLMModel",
+    "OpenAIServerModel",
 ]
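
One behavior of get_tool_call worth isolating: the server returns tool-call arguments as a JSON string, and the method falls back to the raw string when decoding fails. A self-contained restatement of that try/except (the helper name is illustrative, not from the commit):

import json

def decode_tool_arguments(raw: str):
    try:
        return json.loads(raw)  # usual case: a dict of argument values
    except json.JSONDecodeError:
        return raw  # defensive fallback: malformed output passed through

print(decode_tool_arguments('{"city": "Paris"}'))  # {'city': 'Paris'}
print(decode_tool_arguments("not json"))           # not json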