Fix tool calls with LiteLLM and tool optional types (#318)
This commit is contained in:
parent
ffaa945936
commit
fe2f4e735c
|
@ -27,6 +27,7 @@ import json
|
|||
import os
|
||||
import re
|
||||
import types
|
||||
from copy import copy
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
|
@ -381,7 +382,7 @@ _BASE_TYPE_MAPPING = {
|
|||
|
||||
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||
if param_type in _BASE_TYPE_MAPPING:
|
||||
return _BASE_TYPE_MAPPING[param_type]
|
||||
return copy(_BASE_TYPE_MAPPING[param_type])
|
||||
if str(param_type) == "Image" and _is_pillow_available():
|
||||
from PIL.Image import Image
|
||||
|
||||
|
|
|
@ -101,6 +101,18 @@ class ChatMessage:
|
|||
tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
|
||||
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ChatMessage":
|
||||
if data.get("tool_calls"):
|
||||
tool_calls = [
|
||||
ChatMessageToolCall(
|
||||
function=ChatMessageToolCallDefinition(**tc["function"]), id=tc["id"], type=tc["type"]
|
||||
)
|
||||
for tc in data["tool_calls"]
|
||||
]
|
||||
data["tool_calls"] = tool_calls
|
||||
return cls(**data)
|
||||
|
||||
|
||||
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
|
||||
if isinstance(arguments, dict):
|
||||
|
@ -595,7 +607,9 @@ class LiteLLMModel(Model):
|
|||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
|
||||
message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
|
||||
message = ChatMessage.from_dict(
|
||||
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
|
||||
)
|
||||
|
||||
if tools_to_call_from is not None:
|
||||
return parse_tool_args_if_needed(message)
|
||||
|
@ -664,7 +678,9 @@ class OpenAIServerModel(Model):
|
|||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
|
||||
message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
|
||||
message = ChatMessage.from_dict(
|
||||
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
|
||||
)
|
||||
if tools_to_call_from is not None:
|
||||
return parse_tool_args_if_needed(message)
|
||||
return message
|
||||
|
|
Loading…
Reference in New Issue