diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 11be825..172757e 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -132,8 +132,9 @@ def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage: - for tool_call in message.tool_calls: - tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments) + if message.tool_calls is not None: + for tool_call in message.tool_calls: + tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments) return message diff --git a/tests/test_models.py b/tests/test_models.py index ea70983..8c83a9f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -107,6 +107,11 @@ class ModelTests(unittest.TestCase): output = model(messages, stop_sequences=["great"]).content assert output == "Hello! How can" + def test_parse_tool_args_if_needed(self): + original_message = ChatMessage(role="user", content=[{"type": "text", "text": "Hello!"}]) + parsed_message = models.parse_tool_args_if_needed(original_message) + assert parsed_message == original_message + def test_parse_json_if_needed(self): args = "abc" parsed_args = parse_json_if_needed(args)