Improve tool call argument parsing (#267)

* Improve tool call argument parsing
Aymeric Roucher 2025-01-20 10:44:40 +01:00 committed by GitHub
parent 89a6350fe2
commit 0abd91cf72
4 changed files with 53 additions and 17 deletions
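
In short: depending on the backend, tool-call arguments come back either as an already-parsed dict or as a JSON-encoded string. This commit adds a parse_json_if_needed helper in smolagents.models (diff below) and has each model wrapper apply it whenever tools_to_call_from is passed, so ToolCallingAgent always receives a dict when the arguments are valid JSON. Illustration only; the query value is made up:

    from smolagents.models import parse_json_if_needed

    parse_json_if_needed({"query": "weather in Paris"})    # dict passes through unchanged
    parse_json_if_needed('{"query": "weather in Paris"}')  # JSON string -> {'query': 'weather in Paris'}
    parse_json_if_needed("plain text")                     # non-JSON string is returned as-is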


@@ -755,6 +755,8 @@ class ToolCallingAgent(MultiStepAgent):
             tools_to_call_from=list(self.tools.values()),
             stop_sequences=["Observation:"],
         )
+        if model_message.tool_calls is None or len(model_message.tool_calls) == 0:
+            raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.")
         tool_call = model_message.tool_calls[0]
         tool_name, tool_call_id = tool_call.function.name, tool_call.id
         tool_arguments = tool_call.function.arguments
@@ -1022,9 +1024,9 @@ class ManagedAgent:
         """Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
         full_task = self.managed_agent_prompt.format(name=self.name, task=task)
         if self.additional_prompting:
-            full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip()
+            full_task = full_task.replace("\n{additional_prompting}", self.additional_prompting).strip()
         else:
-            full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
+            full_task = full_task.replace("\n{additional_prompting}", "").strip()
         return full_task

     def __call__(self, request, **kwargs):
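
The brace change above follows from str.format semantics: for .format(name=..., task=...) not to raise a KeyError on the placeholder, the raw managed_agent_prompt has to escape it as {{additional_prompting}}, and after formatting the string therefore contains single braces, which is what .replace() must search for. A small sketch with a made-up stand-in template (the real managed_agent_prompt is not part of this diff):

    template = "You're {name}.\nTask: {task}\n{{additional_prompting}}"  # made-up stand-in prompt
    full_task = template.format(name="search_agent", task="Find the weather in Paris")
    # After .format(), the escaped braces have become single braces:
    assert "\n{additional_prompting}" in full_task
    # So the old needle "\n{{additional_prompting}}" could never match after formatting
    # and the placeholder leaked into the task; the new needle strips it cleanly.
    print(full_task.replace("\n{additional_prompting}", "").strip())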


@@ -158,6 +158,8 @@ class DuckDuckGoSearchTool(Tool):
     def forward(self, query: str) -> str:
         results = self.ddgs.text(query, max_results=self.max_results)
+        if len(results) == 0:
+            raise Exception("No results found! Try a less restrictive/shorter query.")
         postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
         return "## Search Results\n\n" + "\n\n".join(postprocessed_results)


@@ -104,6 +104,22 @@ class ChatMessage:
         return cls(role=message.role, content=message.content, tool_calls=tool_calls)


+def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
+    if isinstance(arguments, dict):
+        return arguments
+    else:
+        try:
+            return json.loads(arguments)
+        except Exception:
+            return arguments
+
+
+def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
+    for tool_call in message.tool_calls:
+        tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
+    return message
+
+
 class MessageRole(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
@@ -181,17 +197,6 @@ def get_clean_message_list(
     return final_message_list


-def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
-    try:
-        start, end = (
-            possible_dictionary.find("{"),
-            possible_dictionary.rfind("}") + 1,
-        )
-        return json.loads(possible_dictionary[start:end])
-    except Exception:
-        return possible_dictionary
-
-
 class Model:
     def __init__(self):
         self.last_input_token_count = None
@@ -304,7 +309,10 @@ class HfApiModel(Model):
         )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return ChatMessage.from_hf_api(response.choices[0].message)
+        message = ChatMessage.from_hf_api(response.choices[0].message)
+        if tools_to_call_from is not None:
+            return parse_tool_args_if_needed(message)
+        return message


 class TransformersModel(Model):
@@ -523,7 +531,10 @@
         )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message
+        message = response.choices[0].message
+        if tools_to_call_from is not None:
+            return parse_tool_args_if_needed(message)
+        return message


 class OpenAIServerModel(Model):
@@ -582,7 +593,7 @@
             model=self.model_id,
             messages=messages,
             tools=[get_json_schema(tool) for tool in tools_to_call_from],
-            tool_choice="auto",
+            tool_choice="required",
             stop=stop_sequences,
             **self.kwargs,
         )
@@ -595,7 +606,10 @@
         )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message
+        message = response.choices[0].message
+        if tools_to_call_from is not None:
+            return parse_tool_args_if_needed(message)
+        return message


 __all__ = [
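
For a rough end-to-end picture, here is a sketch of what the normalization does to a returned message. The SimpleNamespace objects are only stand-ins for the provider's tool-call structures, and the tool name and arguments are invented; the wrappers above apply parse_tool_args_if_needed only when tools_to_call_from was passed.

    from types import SimpleNamespace
    from smolagents.models import parse_tool_args_if_needed

    # Stand-in for a chat response whose tool-call arguments came back as a JSON string.
    message = SimpleNamespace(
        tool_calls=[
            SimpleNamespace(
                id="call_0",
                function=SimpleNamespace(name="web_search", arguments='{"query": "weather in Paris"}'),
            )
        ]
    )

    message = parse_tool_args_if_needed(message)
    print(type(message.tool_calls[0].function.arguments))  # <class 'dict'>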


@@ -17,6 +17,7 @@ import unittest
 from typing import Optional

 from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
+from smolagents.models import parse_json_if_needed


 class ModelTests(unittest.TestCase):
@@ -55,3 +56,20 @@ class ModelTests(unittest.TestCase):
         messages = [{"role": "user", "content": "Hello!"}]
         output = model(messages, stop_sequences=["great"]).content
         assert output == "assistant\nHello"
+
+    def test_parse_json_if_needed(self):
+        args = "abc"
+        parsed_args = parse_json_if_needed(args)
+        assert parsed_args == "abc"
+
+        args = '{"a": 3}'
+        parsed_args = parse_json_if_needed(args)
+        assert parsed_args == {"a": 3}
+
+        args = "3"
+        parsed_args = parse_json_if_needed(args)
+        assert parsed_args == 3
+
+        args = 3
+        parsed_args = parse_json_if_needed(args)
+        assert parsed_args == 3