fix tool code generation in case of multiline descriptions (#613)

This commit is contained in:
Alex 2025-02-14 12:28:08 +01:00 committed by GitHub
parent a2e92c74a9
commit d9da0a70ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 3 deletions

View File

@ -221,7 +221,7 @@ class Tool:
class {class_name}(Tool): class {class_name}(Tool):
name = "{self.name}" name = "{self.name}"
description = "{self.description}" description = {json.dumps(textwrap.dedent(self.description).strip())}
inputs = {json.dumps(self.inputs, separators=(",", ":"))} inputs = {json.dumps(self.inputs, separators=(",", ":"))}
output_type = "{self.output_type}" output_type = "{self.output_type}"
""" """

View File

@ -299,10 +299,12 @@ def instance_to_source(instance, base_cls=None):
for name, value in class_attrs.items(): for name, value in class_attrs.items():
if isinstance(value, str): if isinstance(value, str):
# multiline value
if "\n" in value: if "\n" in value:
class_lines.append(f' {name} = """{value}"""') escaped_value = value.replace('"""', r"\"\"\"") # Escape triple quotes
class_lines.append(f' {name} = """{escaped_value}"""')
else: else:
class_lines.append(f' {name} = "{value}"') class_lines.append(f" {name} = {json.dumps(value)}")
else: else:
class_lines.append(f" {name} = {repr(value)}") class_lines.append(f" {name} = {repr(value)}")

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
@ -437,6 +438,46 @@ class ToolTests(unittest.TestCase):
assert get_weather.inputs["locations"]["type"] == "array" assert get_weather.inputs["locations"]["type"] == "array"
assert get_weather.inputs["months"]["type"] == "array" assert get_weather.inputs["months"]["type"] == "array"
def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self):
@tool
def get_weather(location: Any) -> None:
"""
Get weather in the next days at given location.
And works pretty well.
Args:
location: The location to get the weather for.
"""
return
with tempfile.TemporaryDirectory() as tmp_dir:
get_weather.save(tmp_dir)
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
source_code = f.read()
compile(source_code, f.name, "exec")
def test_saving_tool_produces_valid_python_code_with_complex_name(self):
# Test one cannot save tool with additional args in init
class FailTool(Tool):
name = 'spe"\rcific'
description = """test \n\r
description"""
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"
def __init__(self):
super().__init__(self)
def forward(self, string_input):
return "foo"
fail_tool = FailTool()
with tempfile.TemporaryDirectory() as tmp_dir:
fail_tool.save(tmp_dir)
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
source_code = f.read()
compile(source_code, f.name, "exec")
@pytest.fixture @pytest.fixture
def mock_server_parameters(): def mock_server_parameters():