Support multiple code blobs (#128)
This commit is contained in:
parent
e1414f6653
commit
067ae9bc90
|
@ -60,7 +60,7 @@ from .utils import (
|
|||
AgentMaxStepsError,
|
||||
AgentParsingError,
|
||||
console,
|
||||
parse_code_blob,
|
||||
parse_code_blobs,
|
||||
parse_json_tool_call,
|
||||
truncate_content,
|
||||
)
|
||||
|
@ -894,7 +894,7 @@ class CodeAgent(MultiStepAgent):
|
|||
)
|
||||
llm_output = self.model(
|
||||
self.input_messages,
|
||||
stop_sequences=["<end_action>", "Observation:"],
|
||||
stop_sequences=["<end_code>", "Observation:"],
|
||||
**additional_args,
|
||||
)
|
||||
log_entry.llm_output = llm_output
|
||||
|
@ -920,7 +920,7 @@ class CodeAgent(MultiStepAgent):
|
|||
|
||||
# Parse
|
||||
try:
|
||||
code_action = fix_final_answer_code(parse_code_blob(llm_output))
|
||||
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||
|
|
|
@ -39,12 +39,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
|
||||
"type": "regex",
|
||||
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
|
||||
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_code>',
|
||||
}
|
||||
|
||||
DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
|
||||
"type": "regex",
|
||||
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
|
||||
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
|
@ -105,11 +105,11 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
|||
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
||||
|
||||
|
||||
def parse_code_blob(code_blob: str) -> str:
|
||||
def parse_code_blobs(code_blob: str) -> str:
|
||||
"""Parses the LLM's output to get any code blob inside. Will retrun the code directly if it's code."""
|
||||
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
||||
match = re.search(pattern, code_blob, re.DOTALL)
|
||||
if match is None:
|
||||
matches = re.findall(pattern, code_blob, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
try: # Maybe the LLM outputted a code blob directly
|
||||
ast.parse(code_blob)
|
||||
return code_blob
|
||||
|
@ -123,7 +123,7 @@ The code blob is invalid, because the regex pattern {pattern} was not found in {
|
|||
Code:
|
||||
```py
|
||||
final_answer("YOUR FINAL ANSWER HERE")
|
||||
```<end_action>""".strip()
|
||||
```<end_code>""".strip()
|
||||
)
|
||||
raise ValueError(
|
||||
f"""
|
||||
|
@ -132,9 +132,9 @@ Thoughts: Your thoughts
|
|||
Code:
|
||||
```py
|
||||
# Your python code here
|
||||
```<end_action>""".strip()
|
||||
```<end_code>""".strip()
|
||||
)
|
||||
return match.group(1).strip()
|
||||
return "\n\n".join(match.strip() for match in matches)
|
||||
|
||||
|
||||
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
|
||||
|
|
|
@ -15,16 +15,16 @@
|
|||
import unittest
|
||||
import pytest
|
||||
|
||||
from smolagents.utils import parse_code_blob
|
||||
from smolagents.utils import parse_code_blobs
|
||||
|
||||
|
||||
class AgentTextTests(unittest.TestCase):
|
||||
def test_parse_code_blob(self):
|
||||
def test_parse_code_blobs(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_code_blob("Wrong blob!")
|
||||
parse_code_blobs("Wrong blob!")
|
||||
|
||||
# Parsing mardkwon with code blobs should work
|
||||
output = parse_code_blob("""
|
||||
output = parse_code_blobs("""
|
||||
Here is how to solve the problem:
|
||||
Code:
|
||||
```py
|
||||
|
@ -35,5 +35,25 @@ import numpy as np
|
|||
|
||||
# Parsing code blobs should work
|
||||
code_blob = "import numpy as np"
|
||||
output = parse_code_blob(code_blob)
|
||||
output = parse_code_blobs(code_blob)
|
||||
assert output == code_blob
|
||||
|
||||
def test_multiple_code_blobs(self):
|
||||
test_input = """Here's a function that adds numbers:
|
||||
```python
|
||||
def add(a, b):
|
||||
return a + b
|
||||
```
|
||||
And here's a function that multiplies them:
|
||||
```py
|
||||
def multiply(a, b):
|
||||
return a * b
|
||||
```"""
|
||||
|
||||
expected_output = """def add(a, b):
|
||||
return a + b
|
||||
|
||||
def multiply(a, b):
|
||||
return a * b"""
|
||||
result = parse_code_blobs(test_input)
|
||||
assert result == expected_output
|
||||
|
|
Loading…
Reference in New Issue