Improve tool call argument parsing (#267)
* Improve tool call argument parsing
This commit is contained in:
parent
89a6350fe2
commit
0abd91cf72
|
@ -755,6 +755,8 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
tools_to_call_from=list(self.tools.values()),
|
||||
stop_sequences=["Observation:"],
|
||||
)
|
||||
if model_message.tool_calls is None or len(model_message.tool_calls) == 0:
|
||||
raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.")
|
||||
tool_call = model_message.tool_calls[0]
|
||||
tool_name, tool_call_id = tool_call.function.name, tool_call.id
|
||||
tool_arguments = tool_call.function.arguments
|
||||
|
@ -1022,9 +1024,9 @@ class ManagedAgent:
|
|||
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
|
||||
full_task = self.managed_agent_prompt.format(name=self.name, task=task)
|
||||
if self.additional_prompting:
|
||||
full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip()
|
||||
full_task = full_task.replace("\n{additional_prompting}", self.additional_prompting).strip()
|
||||
else:
|
||||
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
|
||||
full_task = full_task.replace("\n{additional_prompting}", "").strip()
|
||||
return full_task
|
||||
|
||||
def __call__(self, request, **kwargs):
|
||||
|
|
|
@ -158,6 +158,8 @@ class DuckDuckGoSearchTool(Tool):
|
|||
|
||||
def forward(self, query: str) -> str:
|
||||
results = self.ddgs.text(query, max_results=self.max_results)
|
||||
if len(results) == 0:
|
||||
raise Exception("No results found! Try a less restrictive/shorter query.")
|
||||
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
|
||||
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
||||
|
||||
|
|
|
@ -104,6 +104,22 @@ class ChatMessage:
|
|||
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
||||
|
||||
|
||||
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
else:
|
||||
try:
|
||||
return json.loads(arguments)
|
||||
except Exception:
|
||||
return arguments
|
||||
|
||||
|
||||
def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
|
||||
return message
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
@ -181,17 +197,6 @@ def get_clean_message_list(
|
|||
return final_message_list
|
||||
|
||||
|
||||
def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
|
||||
try:
|
||||
start, end = (
|
||||
possible_dictionary.find("{"),
|
||||
possible_dictionary.rfind("}") + 1,
|
||||
)
|
||||
return json.loads(possible_dictionary[start:end])
|
||||
except Exception:
|
||||
return possible_dictionary
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = None
|
||||
|
@ -304,7 +309,10 @@ class HfApiModel(Model):
|
|||
)
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return ChatMessage.from_hf_api(response.choices[0].message)
|
||||
message = ChatMessage.from_hf_api(response.choices[0].message)
|
||||
if tools_to_call_from is not None:
|
||||
return parse_tool_args_if_needed(message)
|
||||
return message
|
||||
|
||||
|
||||
class TransformersModel(Model):
|
||||
|
@ -523,7 +531,10 @@ class LiteLLMModel(Model):
|
|||
)
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return response.choices[0].message
|
||||
message = response.choices[0].message
|
||||
if tools_to_call_from is not None:
|
||||
return parse_tool_args_if_needed(message)
|
||||
return message
|
||||
|
||||
|
||||
class OpenAIServerModel(Model):
|
||||
|
@ -582,7 +593,7 @@ class OpenAIServerModel(Model):
|
|||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=[get_json_schema(tool) for tool in tools_to_call_from],
|
||||
tool_choice="auto",
|
||||
tool_choice="required",
|
||||
stop=stop_sequences,
|
||||
**self.kwargs,
|
||||
)
|
||||
|
@ -595,7 +606,10 @@ class OpenAIServerModel(Model):
|
|||
)
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return response.choices[0].message
|
||||
message = response.choices[0].message
|
||||
if tools_to_call_from is not None:
|
||||
return parse_tool_args_if_needed(message)
|
||||
return message
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -17,6 +17,7 @@ import unittest
|
|||
from typing import Optional
|
||||
|
||||
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
||||
from smolagents.models import parse_json_if_needed
|
||||
|
||||
|
||||
class ModelTests(unittest.TestCase):
|
||||
|
@ -55,3 +56,20 @@ class ModelTests(unittest.TestCase):
|
|||
messages = [{"role": "user", "content": "Hello!"}]
|
||||
output = model(messages, stop_sequences=["great"]).content
|
||||
assert output == "assistant\nHello"
|
||||
|
||||
def test_parse_json_if_needed(self):
|
||||
args = "abc"
|
||||
parsed_args = parse_json_if_needed(args)
|
||||
assert parsed_args == "abc"
|
||||
|
||||
args = '{"a": 3}'
|
||||
parsed_args = parse_json_if_needed(args)
|
||||
assert parsed_args == {"a": 3}
|
||||
|
||||
args = "3"
|
||||
parsed_args = parse_json_if_needed(args)
|
||||
assert parsed_args == 3
|
||||
|
||||
args = 3
|
||||
parsed_args = parse_json_if_needed(args)
|
||||
assert parsed_args == 3
|
||||
|
|
Loading…
Reference in New Issue