From 0abd91cf72afbc687c884b9a4e2df4fbd8ca9fe6 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Mon, 20 Jan 2025 10:44:40 +0100 Subject: [PATCH] Improve tool call argument parsing (#267) * Improve tool call argument parsing --- src/smolagents/agents.py | 6 +++-- src/smolagents/default_tools.py | 2 ++ src/smolagents/models.py | 44 ++++++++++++++++++++++----------- tests/test_models.py | 18 ++++++++++++++ 4 files changed, 53 insertions(+), 17 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index e44a050..22af248 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -755,6 +755,8 @@ class ToolCallingAgent(MultiStepAgent): tools_to_call_from=list(self.tools.values()), stop_sequences=["Observation:"], ) + if model_message.tool_calls is None or len(model_message.tool_calls) == 0: + raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.") tool_call = model_message.tool_calls[0] tool_name, tool_call_id = tool_call.function.name, tool_call.id tool_arguments = tool_call.function.arguments @@ -1022,9 +1024,9 @@ class ManagedAgent: """Adds additional prompting for the managed agent, like 'add more detail in your answer'.""" full_task = self.managed_agent_prompt.format(name=self.name, task=task) if self.additional_prompting: - full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip() + full_task = full_task.replace("\n{additional_prompting}", self.additional_prompting).strip() else: - full_task = full_task.replace("\n{{additional_prompting}}", "").strip() + full_task = full_task.replace("\n{additional_prompting}", "").strip() return full_task def __call__(self, request, **kwargs): diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index c0fa139..6a0913e 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -158,6 +158,8 @@ class DuckDuckGoSearchTool(Tool): def forward(self, query: str) -> str: results = self.ddgs.text(query, max_results=self.max_results) + if len(results) == 0: + raise Exception("No results found! Try a less restrictive/shorter query.") postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results] return "## Search Results\n\n" + "\n\n".join(postprocessed_results) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index f19a291..94d8a57 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -104,6 +104,22 @@ class ChatMessage: return cls(role=message.role, content=message.content, tool_calls=tool_calls) +def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: + if isinstance(arguments, dict): + return arguments + else: + try: + return json.loads(arguments) + except Exception: + return arguments + + +def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage: + for tool_call in message.tool_calls: + tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments) + return message + + class MessageRole(str, Enum): USER = "user" ASSISTANT = "assistant" @@ -181,17 +197,6 @@ def get_clean_message_list( return final_message_list -def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]: - try: - start, end = ( - possible_dictionary.find("{"), - possible_dictionary.rfind("}") + 1, - ) - return json.loads(possible_dictionary[start:end]) - except Exception: - return possible_dictionary - - class Model: def __init__(self): self.last_input_token_count = None @@ -304,7 +309,10 @@ class HfApiModel(Model): ) self.last_input_token_count = response.usage.prompt_tokens self.last_output_token_count = response.usage.completion_tokens - return ChatMessage.from_hf_api(response.choices[0].message) + message = ChatMessage.from_hf_api(response.choices[0].message) + if tools_to_call_from is not None: + return parse_tool_args_if_needed(message) + return message class TransformersModel(Model): @@ -523,7 +531,10 @@ class LiteLLMModel(Model): ) self.last_input_token_count = response.usage.prompt_tokens self.last_output_token_count = response.usage.completion_tokens - return response.choices[0].message + message = response.choices[0].message + if tools_to_call_from is not None: + return parse_tool_args_if_needed(message) + return message class OpenAIServerModel(Model): @@ -582,7 +593,7 @@ class OpenAIServerModel(Model): model=self.model_id, messages=messages, tools=[get_json_schema(tool) for tool in tools_to_call_from], - tool_choice="auto", + tool_choice="required", stop=stop_sequences, **self.kwargs, ) @@ -595,7 +606,10 @@ class OpenAIServerModel(Model): ) self.last_input_token_count = response.usage.prompt_tokens self.last_output_token_count = response.usage.completion_tokens - return response.choices[0].message + message = response.choices[0].message + if tools_to_call_from is not None: + return parse_tool_args_if_needed(message) + return message __all__ = [ diff --git a/tests/test_models.py b/tests/test_models.py index 8faae92..4f897f3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,6 +17,7 @@ import unittest from typing import Optional from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool +from smolagents.models import parse_json_if_needed class ModelTests(unittest.TestCase): @@ -55,3 +56,20 @@ class ModelTests(unittest.TestCase): messages = [{"role": "user", "content": "Hello!"}] output = model(messages, stop_sequences=["great"]).content assert output == "assistant\nHello" + + def test_parse_json_if_needed(self): + args = "abc" + parsed_args = parse_json_if_needed(args) + assert parsed_args == "abc" + + args = '{"a": 3}' + parsed_args = parse_json_if_needed(args) + assert parsed_args == {"a": 3} + + args = "3" + parsed_args = parse_json_if_needed(args) + assert parsed_args == 3 + + args = 3 + parsed_args = parse_json_if_needed(args) + assert parsed_args == 3