Fix tool calls with LiteLLM and tool optional types (#318)
This commit is contained in:
parent
ffaa945936
commit
fe2f4e735c
|
@ -27,6 +27,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import types
|
import types
|
||||||
|
from copy import copy
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
@ -381,7 +382,7 @@ _BASE_TYPE_MAPPING = {
|
||||||
|
|
||||||
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||||
if param_type in _BASE_TYPE_MAPPING:
|
if param_type in _BASE_TYPE_MAPPING:
|
||||||
return _BASE_TYPE_MAPPING[param_type]
|
return copy(_BASE_TYPE_MAPPING[param_type])
|
||||||
if str(param_type) == "Image" and _is_pillow_available():
|
if str(param_type) == "Image" and _is_pillow_available():
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
|
|
@ -101,6 +101,18 @@ class ChatMessage:
|
||||||
tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
|
tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
|
||||||
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "ChatMessage":
|
||||||
|
if data.get("tool_calls"):
|
||||||
|
tool_calls = [
|
||||||
|
ChatMessageToolCall(
|
||||||
|
function=ChatMessageToolCallDefinition(**tc["function"]), id=tc["id"], type=tc["type"]
|
||||||
|
)
|
||||||
|
for tc in data["tool_calls"]
|
||||||
|
]
|
||||||
|
data["tool_calls"] = tool_calls
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
|
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
|
||||||
if isinstance(arguments, dict):
|
if isinstance(arguments, dict):
|
||||||
|
@ -595,7 +607,9 @@ class LiteLLMModel(Model):
|
||||||
self.last_input_token_count = response.usage.prompt_tokens
|
self.last_input_token_count = response.usage.prompt_tokens
|
||||||
self.last_output_token_count = response.usage.completion_tokens
|
self.last_output_token_count = response.usage.completion_tokens
|
||||||
|
|
||||||
message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
|
message = ChatMessage.from_dict(
|
||||||
|
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
|
||||||
|
)
|
||||||
|
|
||||||
if tools_to_call_from is not None:
|
if tools_to_call_from is not None:
|
||||||
return parse_tool_args_if_needed(message)
|
return parse_tool_args_if_needed(message)
|
||||||
|
@ -664,7 +678,9 @@ class OpenAIServerModel(Model):
|
||||||
self.last_input_token_count = response.usage.prompt_tokens
|
self.last_input_token_count = response.usage.prompt_tokens
|
||||||
self.last_output_token_count = response.usage.completion_tokens
|
self.last_output_token_count = response.usage.completion_tokens
|
||||||
|
|
||||||
message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
|
message = ChatMessage.from_dict(
|
||||||
|
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
|
||||||
|
)
|
||||||
if tools_to_call_from is not None:
|
if tools_to_call_from is not None:
|
||||||
return parse_tool_args_if_needed(message)
|
return parse_tool_args_if_needed(message)
|
||||||
return message
|
return message
|
||||||
|
|
Loading…
Reference in New Issue