diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 15275e7..2e35bcb 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -221,7 +221,7 @@ class Tool: class {class_name}(Tool): name = "{self.name}" - description = "{self.description}" + description = {json.dumps(textwrap.dedent(self.description).strip())} inputs = {json.dumps(self.inputs, separators=(",", ":"))} output_type = "{self.output_type}" """ diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index b9868dc..3f7219b 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -299,10 +299,12 @@ def instance_to_source(instance, base_cls=None): for name, value in class_attrs.items(): if isinstance(value, str): + # multiline value if "\n" in value: - class_lines.append(f' {name} = """{value}"""') + escaped_value = value.replace('"""', r"\"\"\"") # Escape triple quotes + class_lines.append(f' {name} = """{escaped_value}"""') else: - class_lines.append(f' {name} = "{value}"') + class_lines.append(f" {name} = {json.dumps(value)}") else: class_lines.append(f" {name} = {repr(value)}") diff --git a/tests/test_tools.py b/tests/test_tools.py index fcc05d5..cb8a8ee 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import tempfile import unittest from pathlib import Path @@ -437,6 +438,46 @@ class ToolTests(unittest.TestCase): assert get_weather.inputs["locations"]["type"] == "array" assert get_weather.inputs["months"]["type"] == "array" + def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self): + @tool + def get_weather(location: Any) -> None: + """ + Get weather in the next days at given location. + And works pretty well. + + Args: + location: The location to get the weather for. + """ + return + + with tempfile.TemporaryDirectory() as tmp_dir: + get_weather.save(tmp_dir) + with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f: + source_code = f.read() + compile(source_code, f.name, "exec") + + def test_saving_tool_produces_valid_python_code_with_complex_name(self): + # Test one cannot save tool with additional args in init + class FailTool(Tool): + name = 'spe"\rcific' + description = """test \n\r + description""" + inputs = {"string_input": {"type": "string", "description": "input description"}} + output_type = "string" + + def __init__(self): + super().__init__(self) + + def forward(self, string_input): + return "foo" + + fail_tool = FailTool() + with tempfile.TemporaryDirectory() as tmp_dir: + fail_tool.save(tmp_dir) + with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f: + source_code = f.read() + compile(source_code, f.name, "exec") + @pytest.fixture def mock_server_parameters():