Improve tool call argument parsing (#267)
* Improve tool call argument parsing
This commit is contained in:
parent
89a6350fe2
commit
0abd91cf72
|
@ -755,6 +755,8 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
tools_to_call_from=list(self.tools.values()),
|
tools_to_call_from=list(self.tools.values()),
|
||||||
stop_sequences=["Observation:"],
|
stop_sequences=["Observation:"],
|
||||||
)
|
)
|
||||||
|
if model_message.tool_calls is None or len(model_message.tool_calls) == 0:
|
||||||
|
raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.")
|
||||||
tool_call = model_message.tool_calls[0]
|
tool_call = model_message.tool_calls[0]
|
||||||
tool_name, tool_call_id = tool_call.function.name, tool_call.id
|
tool_name, tool_call_id = tool_call.function.name, tool_call.id
|
||||||
tool_arguments = tool_call.function.arguments
|
tool_arguments = tool_call.function.arguments
|
||||||
|
@ -1022,9 +1024,9 @@ class ManagedAgent:
|
||||||
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
|
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
|
||||||
full_task = self.managed_agent_prompt.format(name=self.name, task=task)
|
full_task = self.managed_agent_prompt.format(name=self.name, task=task)
|
||||||
if self.additional_prompting:
|
if self.additional_prompting:
|
||||||
full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip()
|
full_task = full_task.replace("\n{additional_prompting}", self.additional_prompting).strip()
|
||||||
else:
|
else:
|
||||||
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
|
full_task = full_task.replace("\n{additional_prompting}", "").strip()
|
||||||
return full_task
|
return full_task
|
||||||
|
|
||||||
def __call__(self, request, **kwargs):
|
def __call__(self, request, **kwargs):
|
||||||
|
|
|
@ -158,6 +158,8 @@ class DuckDuckGoSearchTool(Tool):
|
||||||
|
|
||||||
def forward(self, query: str) -> str:
|
def forward(self, query: str) -> str:
|
||||||
results = self.ddgs.text(query, max_results=self.max_results)
|
results = self.ddgs.text(query, max_results=self.max_results)
|
||||||
|
if len(results) == 0:
|
||||||
|
raise Exception("No results found! Try a less restrictive/shorter query.")
|
||||||
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
|
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
|
||||||
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
||||||
|
|
||||||
|
|
|
@ -104,6 +104,22 @@ class ChatMessage:
|
||||||
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
|
||||||
|
if isinstance(arguments, dict):
|
||||||
|
return arguments
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return json.loads(arguments)
|
||||||
|
except Exception:
|
||||||
|
return arguments
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
class MessageRole(str, Enum):
|
class MessageRole(str, Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
|
@ -181,17 +197,6 @@ def get_clean_message_list(
|
||||||
return final_message_list
|
return final_message_list
|
||||||
|
|
||||||
|
|
||||||
def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
|
|
||||||
try:
|
|
||||||
start, end = (
|
|
||||||
possible_dictionary.find("{"),
|
|
||||||
possible_dictionary.rfind("}") + 1,
|
|
||||||
)
|
|
||||||
return json.loads(possible_dictionary[start:end])
|
|
||||||
except Exception:
|
|
||||||
return possible_dictionary
|
|
||||||
|
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.last_input_token_count = None
|
self.last_input_token_count = None
|
||||||
|
@ -304,7 +309,10 @@ class HfApiModel(Model):
|
||||||
)
|
)
|
||||||
self.last_input_token_count = response.usage.prompt_tokens
|
self.last_input_token_count = response.usage.prompt_tokens
|
||||||
self.last_output_token_count = response.usage.completion_tokens
|
self.last_output_token_count = response.usage.completion_tokens
|
||||||
return ChatMessage.from_hf_api(response.choices[0].message)
|
message = ChatMessage.from_hf_api(response.choices[0].message)
|
||||||
|
if tools_to_call_from is not None:
|
||||||
|
return parse_tool_args_if_needed(message)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
class TransformersModel(Model):
|
class TransformersModel(Model):
|
||||||
|
@ -523,7 +531,10 @@ class LiteLLMModel(Model):
|
||||||
)
|
)
|
||||||
self.last_input_token_count = response.usage.prompt_tokens
|
self.last_input_token_count = response.usage.prompt_tokens
|
||||||
self.last_output_token_count = response.usage.completion_tokens
|
self.last_output_token_count = response.usage.completion_tokens
|
||||||
return response.choices[0].message
|
message = response.choices[0].message
|
||||||
|
if tools_to_call_from is not None:
|
||||||
|
return parse_tool_args_if_needed(message)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
class OpenAIServerModel(Model):
|
class OpenAIServerModel(Model):
|
||||||
|
@ -582,7 +593,7 @@ class OpenAIServerModel(Model):
|
||||||
model=self.model_id,
|
model=self.model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=[get_json_schema(tool) for tool in tools_to_call_from],
|
tools=[get_json_schema(tool) for tool in tools_to_call_from],
|
||||||
tool_choice="auto",
|
tool_choice="required",
|
||||||
stop=stop_sequences,
|
stop=stop_sequences,
|
||||||
**self.kwargs,
|
**self.kwargs,
|
||||||
)
|
)
|
||||||
|
@ -595,7 +606,10 @@ class OpenAIServerModel(Model):
|
||||||
)
|
)
|
||||||
self.last_input_token_count = response.usage.prompt_tokens
|
self.last_input_token_count = response.usage.prompt_tokens
|
||||||
self.last_output_token_count = response.usage.completion_tokens
|
self.last_output_token_count = response.usage.completion_tokens
|
||||||
return response.choices[0].message
|
message = response.choices[0].message
|
||||||
|
if tools_to_call_from is not None:
|
||||||
|
return parse_tool_args_if_needed(message)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
|
@ -17,6 +17,7 @@ import unittest
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
||||||
|
from smolagents.models import parse_json_if_needed
|
||||||
|
|
||||||
|
|
||||||
class ModelTests(unittest.TestCase):
|
class ModelTests(unittest.TestCase):
|
||||||
|
@ -55,3 +56,20 @@ class ModelTests(unittest.TestCase):
|
||||||
messages = [{"role": "user", "content": "Hello!"}]
|
messages = [{"role": "user", "content": "Hello!"}]
|
||||||
output = model(messages, stop_sequences=["great"]).content
|
output = model(messages, stop_sequences=["great"]).content
|
||||||
assert output == "assistant\nHello"
|
assert output == "assistant\nHello"
|
||||||
|
|
||||||
|
def test_parse_json_if_needed(self):
|
||||||
|
args = "abc"
|
||||||
|
parsed_args = parse_json_if_needed(args)
|
||||||
|
assert parsed_args == "abc"
|
||||||
|
|
||||||
|
args = '{"a": 3}'
|
||||||
|
parsed_args = parse_json_if_needed(args)
|
||||||
|
assert parsed_args == {"a": 3}
|
||||||
|
|
||||||
|
args = "3"
|
||||||
|
parsed_args = parse_json_if_needed(args)
|
||||||
|
assert parsed_args == 3
|
||||||
|
|
||||||
|
args = 3
|
||||||
|
parsed_args = parse_json_if_needed(args)
|
||||||
|
assert parsed_args == 3
|
||||||
|
|
Loading…
Reference in New Issue