Fix tool calls with LiteLLM and tool optional types (#318)

This commit is contained in:
Aymeric Roucher 2025-01-22 18:42:10 +01:00 committed by GitHub
parent ffaa945936
commit fe2f4e735c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 3 deletions

View File

@ -27,6 +27,7 @@ import json
import os import os
import re import re
import types import types
from copy import copy
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -381,7 +382,7 @@ _BASE_TYPE_MAPPING = {
def _get_json_schema_type(param_type: str) -> Dict[str, str]: def _get_json_schema_type(param_type: str) -> Dict[str, str]:
if param_type in _BASE_TYPE_MAPPING: if param_type in _BASE_TYPE_MAPPING:
return _BASE_TYPE_MAPPING[param_type] return copy(_BASE_TYPE_MAPPING[param_type])
if str(param_type) == "Image" and _is_pillow_available(): if str(param_type) == "Image" and _is_pillow_available():
from PIL.Image import Image from PIL.Image import Image

View File

@ -101,6 +101,18 @@ class ChatMessage:
tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls] tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
return cls(role=message.role, content=message.content, tool_calls=tool_calls) return cls(role=message.role, content=message.content, tool_calls=tool_calls)
@classmethod
def from_dict(cls, data: dict) -> "ChatMessage":
if data.get("tool_calls"):
tool_calls = [
ChatMessageToolCall(
function=ChatMessageToolCallDefinition(**tc["function"]), id=tc["id"], type=tc["type"]
)
for tc in data["tool_calls"]
]
data["tool_calls"] = tool_calls
return cls(**data)
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
if isinstance(arguments, dict): if isinstance(arguments, dict):
@ -595,7 +607,9 @@ class LiteLLMModel(Model):
self.last_input_token_count = response.usage.prompt_tokens self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens self.last_output_token_count = response.usage.completion_tokens
message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) message = ChatMessage.from_dict(
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
)
if tools_to_call_from is not None: if tools_to_call_from is not None:
return parse_tool_args_if_needed(message) return parse_tool_args_if_needed(message)
@ -664,7 +678,9 @@ class OpenAIServerModel(Model):
self.last_input_token_count = response.usage.prompt_tokens self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens self.last_output_token_count = response.usage.completion_tokens
message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) message = ChatMessage.from_dict(
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
)
if tools_to_call_from is not None: if tools_to_call_from is not None:
return parse_tool_args_if_needed(message) return parse_tool_args_if_needed(message)
return message return message