Improve OpenAIServerModel by making api_base and api_url optional (will then point to OpenAI server)

This commit is contained in:
Aymeric 2025-01-17 18:03:51 +01:00
parent 68933e7e90
commit d8591dc703
2 changed files with 13 additions and 6 deletions

View File

@ -1099,7 +1099,7 @@ def import_modules(expression, state, authorized_imports):
"popen",
"spawn",
"shutil",
"glob",
"sys",
"pathlib",
"io",
"socket",

View File

@ -561,10 +561,13 @@ class OpenAIServerModel(Model):
Parameters:
model_id (`str`):
The model identifier to use on the server (e.g. "gpt-3.5-turbo").
api_base (`str`):
api_base (`str`, *optional*):
The base URL of the OpenAI-compatible API server.
api_key (`str`):
api_key (`str`, *optional*):
The API key to use for authentication.
custom_role_conversions (`Dict{str, str]`, *optional*):
Custom role conversion mapping to convert message roles in others.
Useful for specific models that do not support specific message roles like "system".
**kwargs:
Additional keyword arguments to pass to the OpenAI API.
"""
@ -572,8 +575,9 @@ class OpenAIServerModel(Model):
def __init__(
self,
model_id: str,
api_base: str,
api_key: str,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
custom_role_conversions: Optional[Dict[str, str]] = None,
**kwargs,
):
try:
@ -589,6 +593,7 @@ class OpenAIServerModel(Model):
api_key=api_key,
)
self.kwargs = kwargs
self.custom_role_conversions = custom_role_conversions
def __call__(
self,
@ -598,7 +603,9 @@ class OpenAIServerModel(Model):
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
messages, role_conversions=(
self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions
)
)
if tools_to_call_from:
response = self.client.chat.completions.create(