Support multiple code blobs (#128)

This commit is contained in:
Aymeric Roucher 2025-01-08 23:20:50 +01:00 committed by GitHub
parent e1414f6653
commit 067ae9bc90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 16 deletions

View File

@ -60,7 +60,7 @@ from .utils import (
AgentMaxStepsError, AgentMaxStepsError,
AgentParsingError, AgentParsingError,
console, console,
parse_code_blob, parse_code_blobs,
parse_json_tool_call, parse_json_tool_call,
truncate_content, truncate_content,
) )
@ -894,7 +894,7 @@ class CodeAgent(MultiStepAgent):
) )
llm_output = self.model( llm_output = self.model(
self.input_messages, self.input_messages,
stop_sequences=["<end_action>", "Observation:"], stop_sequences=["<end_code>", "Observation:"],
**additional_args, **additional_args,
) )
log_entry.llm_output = llm_output log_entry.llm_output = llm_output
@ -920,7 +920,7 @@ class CodeAgent(MultiStepAgent):
# Parse # Parse
try: try:
code_action = fix_final_answer_code(parse_code_blob(llm_output)) code_action = fix_final_answer_code(parse_code_blobs(llm_output))
except Exception as e: except Exception as e:
error_msg = ( error_msg = (
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."

View File

@ -39,12 +39,12 @@ logger = logging.getLogger(__name__)
DEFAULT_JSONAGENT_REGEX_GRAMMAR = { DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
"type": "regex", "type": "regex",
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>', "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_code>',
} }
DEFAULT_CODEAGENT_REGEX_GRAMMAR = { DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"type": "regex", "type": "regex",
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>", "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
} }
try: try:

View File

@ -105,11 +105,11 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
raise ValueError(f"Error in parsing the JSON blob: {e}") raise ValueError(f"Error in parsing the JSON blob: {e}")
def parse_code_blob(code_blob: str) -> str: def parse_code_blobs(code_blob: str) -> str:
"""Parses the LLM's output to get any code blob inside. Will retrun the code directly if it's code.""" """Parses the LLM's output to get any code blob inside. Will retrun the code directly if it's code."""
pattern = r"```(?:py|python)?\n(.*?)\n```" pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL) matches = re.findall(pattern, code_blob, re.DOTALL)
if match is None: if len(matches) == 0:
try: # Maybe the LLM outputted a code blob directly try: # Maybe the LLM outputted a code blob directly
ast.parse(code_blob) ast.parse(code_blob)
return code_blob return code_blob
@ -123,7 +123,7 @@ The code blob is invalid, because the regex pattern {pattern} was not found in {
Code: Code:
```py ```py
final_answer("YOUR FINAL ANSWER HERE") final_answer("YOUR FINAL ANSWER HERE")
```<end_action>""".strip() ```<end_code>""".strip()
) )
raise ValueError( raise ValueError(
f""" f"""
@ -132,9 +132,9 @@ Thoughts: Your thoughts
Code: Code:
```py ```py
# Your python code here # Your python code here
```<end_action>""".strip() ```<end_code>""".strip()
) )
return match.group(1).strip() return "\n\n".join(match.strip() for match in matches)
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:

View File

@ -15,16 +15,16 @@
import unittest import unittest
import pytest import pytest
from smolagents.utils import parse_code_blob from smolagents.utils import parse_code_blobs
class AgentTextTests(unittest.TestCase): class AgentTextTests(unittest.TestCase):
def test_parse_code_blob(self): def test_parse_code_blobs(self):
with pytest.raises(ValueError): with pytest.raises(ValueError):
parse_code_blob("Wrong blob!") parse_code_blobs("Wrong blob!")
# Parsing mardkwon with code blobs should work # Parsing mardkwon with code blobs should work
output = parse_code_blob(""" output = parse_code_blobs("""
Here is how to solve the problem: Here is how to solve the problem:
Code: Code:
```py ```py
@ -35,5 +35,25 @@ import numpy as np
# Parsing code blobs should work # Parsing code blobs should work
code_blob = "import numpy as np" code_blob = "import numpy as np"
output = parse_code_blob(code_blob) output = parse_code_blobs(code_blob)
assert output == code_blob assert output == code_blob
def test_multiple_code_blobs(self):
test_input = """Here's a function that adds numbers:
```python
def add(a, b):
return a + b
```
And here's a function that multiplies them:
```py
def multiply(a, b):
return a * b
```"""
expected_output = """def add(a, b):
return a + b
def multiply(a, b):
return a * b"""
result = parse_code_blobs(test_input)
assert result == expected_output