Support multiple code blobs (#128)

This commit is contained in:
Aymeric Roucher 2025-01-08 23:20:50 +01:00 committed by GitHub
parent e1414f6653
commit 067ae9bc90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 16 deletions

View File

@ -60,7 +60,7 @@ from .utils import (
AgentMaxStepsError,
AgentParsingError,
console,
parse_code_blob,
parse_code_blobs,
parse_json_tool_call,
truncate_content,
)
@ -894,7 +894,7 @@ class CodeAgent(MultiStepAgent):
)
llm_output = self.model(
self.input_messages,
stop_sequences=["<end_action>", "Observation:"],
stop_sequences=["<end_code>", "Observation:"],
**additional_args,
)
log_entry.llm_output = llm_output
@ -920,7 +920,7 @@ class CodeAgent(MultiStepAgent):
# Parse
try:
code_action = fix_final_answer_code(parse_code_blob(llm_output))
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
except Exception as e:
error_msg = (
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."

View File

@ -39,12 +39,12 @@ logger = logging.getLogger(__name__)
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
"type": "regex",
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_code>',
}
DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"type": "regex",
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
}
try:

View File

@ -105,11 +105,11 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
raise ValueError(f"Error in parsing the JSON blob: {e}")
def parse_code_blob(code_blob: str) -> str:
def parse_code_blobs(code_blob: str) -> str:
"""Parses the LLM's output to get any code blob inside. Will retrun the code directly if it's code."""
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL)
if match is None:
matches = re.findall(pattern, code_blob, re.DOTALL)
if len(matches) == 0:
try: # Maybe the LLM outputted a code blob directly
ast.parse(code_blob)
return code_blob
@ -123,7 +123,7 @@ The code blob is invalid, because the regex pattern {pattern} was not found in {
Code:
```py
final_answer("YOUR FINAL ANSWER HERE")
```<end_action>""".strip()
```<end_code>""".strip()
)
raise ValueError(
f"""
@ -132,9 +132,9 @@ Thoughts: Your thoughts
Code:
```py
# Your python code here
```<end_action>""".strip()
```<end_code>""".strip()
)
return match.group(1).strip()
return "\n\n".join(match.strip() for match in matches)
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:

View File

@ -15,16 +15,16 @@
import unittest
import pytest
from smolagents.utils import parse_code_blob
from smolagents.utils import parse_code_blobs
class AgentTextTests(unittest.TestCase):
def test_parse_code_blob(self):
def test_parse_code_blobs(self):
with pytest.raises(ValueError):
parse_code_blob("Wrong blob!")
parse_code_blobs("Wrong blob!")
# Parsing markdown with code blobs should work
output = parse_code_blob("""
output = parse_code_blobs("""
Here is how to solve the problem:
Code:
```py
@ -35,5 +35,25 @@ import numpy as np
# Parsing code blobs should work
code_blob = "import numpy as np"
output = parse_code_blob(code_blob)
output = parse_code_blobs(code_blob)
assert output == code_blob
def test_multiple_code_blobs(self):
test_input = """Here's a function that adds numbers:
```python
def add(a, b):
return a + b
```
And here's a function that multiplies them:
```py
def multiply(a, b):
return a * b
```"""
expected_output = """def add(a, b):
return a + b
def multiply(a, b):
return a * b"""
result = parse_code_blobs(test_input)
assert result == expected_output