fix tool code generation in case of multiline descriptions (#613)
This commit is contained in:
parent
a2e92c74a9
commit
d9da0a70ad
|
@ -221,7 +221,7 @@ class Tool:
|
||||||
|
|
||||||
class {class_name}(Tool):
|
class {class_name}(Tool):
|
||||||
name = "{self.name}"
|
name = "{self.name}"
|
||||||
description = "{self.description}"
|
description = {json.dumps(textwrap.dedent(self.description).strip())}
|
||||||
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
|
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
|
||||||
output_type = "{self.output_type}"
|
output_type = "{self.output_type}"
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -299,10 +299,12 @@ def instance_to_source(instance, base_cls=None):
|
||||||
|
|
||||||
for name, value in class_attrs.items():
|
for name, value in class_attrs.items():
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
|
# multiline value
|
||||||
if "\n" in value:
|
if "\n" in value:
|
||||||
class_lines.append(f' {name} = """{value}"""')
|
escaped_value = value.replace('"""', r"\"\"\"") # Escape triple quotes
|
||||||
|
class_lines.append(f' {name} = """{escaped_value}"""')
|
||||||
else:
|
else:
|
||||||
class_lines.append(f' {name} = "{value}"')
|
class_lines.append(f" {name} = {json.dumps(value)}")
|
||||||
else:
|
else:
|
||||||
class_lines.append(f" {name} = {repr(value)}")
|
class_lines.append(f" {name} = {repr(value)}")
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -437,6 +438,46 @@ class ToolTests(unittest.TestCase):
|
||||||
assert get_weather.inputs["locations"]["type"] == "array"
|
assert get_weather.inputs["locations"]["type"] == "array"
|
||||||
assert get_weather.inputs["months"]["type"] == "array"
|
assert get_weather.inputs["months"]["type"] == "array"
|
||||||
|
|
||||||
|
def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self):
|
||||||
|
@tool
|
||||||
|
def get_weather(location: Any) -> None:
|
||||||
|
"""
|
||||||
|
Get weather in the next days at given location.
|
||||||
|
And works pretty well.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location: The location to get the weather for.
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
get_weather.save(tmp_dir)
|
||||||
|
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
|
||||||
|
source_code = f.read()
|
||||||
|
compile(source_code, f.name, "exec")
|
||||||
|
|
||||||
|
def test_saving_tool_produces_valid_python_code_with_complex_name(self):
|
||||||
|
# Test one cannot save tool with additional args in init
|
||||||
|
class FailTool(Tool):
|
||||||
|
name = 'spe"\rcific'
|
||||||
|
description = """test \n\r
|
||||||
|
description"""
|
||||||
|
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(self)
|
||||||
|
|
||||||
|
def forward(self, string_input):
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
fail_tool = FailTool()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
fail_tool.save(tmp_dir)
|
||||||
|
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
|
||||||
|
source_code = f.read()
|
||||||
|
compile(source_code, f.name, "exec")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_server_parameters():
|
def mock_server_parameters():
|
||||||
|
|
Loading…
Reference in New Issue