Support multiple code blobs (#128)
This commit is contained in:
parent
e1414f6653
commit
067ae9bc90
|
@ -60,7 +60,7 @@ from .utils import (
|
||||||
AgentMaxStepsError,
|
AgentMaxStepsError,
|
||||||
AgentParsingError,
|
AgentParsingError,
|
||||||
console,
|
console,
|
||||||
parse_code_blob,
|
parse_code_blobs,
|
||||||
parse_json_tool_call,
|
parse_json_tool_call,
|
||||||
truncate_content,
|
truncate_content,
|
||||||
)
|
)
|
||||||
|
@ -894,7 +894,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
)
|
)
|
||||||
llm_output = self.model(
|
llm_output = self.model(
|
||||||
self.input_messages,
|
self.input_messages,
|
||||||
stop_sequences=["<end_action>", "Observation:"],
|
stop_sequences=["<end_code>", "Observation:"],
|
||||||
**additional_args,
|
**additional_args,
|
||||||
)
|
)
|
||||||
log_entry.llm_output = llm_output
|
log_entry.llm_output = llm_output
|
||||||
|
@ -920,7 +920,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
try:
|
try:
|
||||||
code_action = fix_final_answer_code(parse_code_blob(llm_output))
|
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||||
|
|
|
@ -39,12 +39,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
|
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
|
||||||
"type": "regex",
|
"type": "regex",
|
||||||
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
|
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_code>',
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
|
DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
|
||||||
"type": "regex",
|
"type": "regex",
|
||||||
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
|
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -105,11 +105,11 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||||
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
||||||
|
|
||||||
|
|
||||||
def parse_code_blob(code_blob: str) -> str:
|
def parse_code_blobs(code_blob: str) -> str:
|
||||||
"""Parses the LLM's output to get any code blob inside. Will return the code directly if it's code."""
|
"""Parses the LLM's output to get any code blob inside. Will return the code directly if it's code."""
|
||||||
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
||||||
match = re.search(pattern, code_blob, re.DOTALL)
|
matches = re.findall(pattern, code_blob, re.DOTALL)
|
||||||
if match is None:
|
if len(matches) == 0:
|
||||||
try: # Maybe the LLM outputted a code blob directly
|
try: # Maybe the LLM outputted a code blob directly
|
||||||
ast.parse(code_blob)
|
ast.parse(code_blob)
|
||||||
return code_blob
|
return code_blob
|
||||||
|
@ -123,7 +123,7 @@ The code blob is invalid, because the regex pattern {pattern} was not found in {
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
final_answer("YOUR FINAL ANSWER HERE")
|
final_answer("YOUR FINAL ANSWER HERE")
|
||||||
```<end_action>""".strip()
|
```<end_code>""".strip()
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"""
|
f"""
|
||||||
|
@ -132,9 +132,9 @@ Thoughts: Your thoughts
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
# Your python code here
|
# Your python code here
|
||||||
```<end_action>""".strip()
|
```<end_code>""".strip()
|
||||||
)
|
)
|
||||||
return match.group(1).strip()
|
return "\n\n".join(match.strip() for match in matches)
|
||||||
|
|
||||||
|
|
||||||
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
|
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
|
||||||
|
|
|
@ -15,16 +15,16 @@
|
||||||
import unittest
|
import unittest
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents.utils import parse_code_blob
|
from smolagents.utils import parse_code_blobs
|
||||||
|
|
||||||
|
|
||||||
class AgentTextTests(unittest.TestCase):
|
class AgentTextTests(unittest.TestCase):
|
||||||
def test_parse_code_blob(self):
|
def test_parse_code_blobs(self):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
parse_code_blob("Wrong blob!")
|
parse_code_blobs("Wrong blob!")
|
||||||
|
|
||||||
# Parsing markdown with code blobs should work
|
# Parsing markdown with code blobs should work
|
||||||
output = parse_code_blob("""
|
output = parse_code_blobs("""
|
||||||
Here is how to solve the problem:
|
Here is how to solve the problem:
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
|
@ -35,5 +35,25 @@ import numpy as np
|
||||||
|
|
||||||
# Parsing code blobs should work
|
# Parsing code blobs should work
|
||||||
code_blob = "import numpy as np"
|
code_blob = "import numpy as np"
|
||||||
output = parse_code_blob(code_blob)
|
output = parse_code_blobs(code_blob)
|
||||||
assert output == code_blob
|
assert output == code_blob
|
||||||
|
|
||||||
|
def test_multiple_code_blobs(self):
|
||||||
|
test_input = """Here's a function that adds numbers:
|
||||||
|
```python
|
||||||
|
def add(a, b):
|
||||||
|
return a + b
|
||||||
|
```
|
||||||
|
And here's a function that multiplies them:
|
||||||
|
```py
|
||||||
|
def multiply(a, b):
|
||||||
|
return a * b
|
||||||
|
```"""
|
||||||
|
|
||||||
|
expected_output = """def add(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
def multiply(a, b):
|
||||||
|
return a * b"""
|
||||||
|
result = parse_code_blobs(test_input)
|
||||||
|
assert result == expected_output
|
||||||
|
|
Loading…
Reference in New Issue