Auto correct wrong assignments to final_answer (#123)

* Auto correct wrong assignments to final_answer
This commit is contained in:
Aymeric Roucher 2025-01-08 19:04:11 +01:00 committed by GitHub
parent e5d879feab
commit d3cd0f9e09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 694 additions and 303 deletions

View File

@ -56,4 +56,7 @@ jobs:
uv run pytest -sv ./tests/test_tools.py
- name: Types tests
run: |
uv run pytest -sv ./tests/test_types.py
uv run pytest -sv ./tests/test_types.py
- name: Utils tests
run: |
uv run pytest -sv ./tests/test_utils.py

File diff suppressed because one or more lines are too long

View File

@ -26,7 +26,11 @@ from rich.text import Text
from .default_tools import FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
from .local_python_executor import (
BASE_BUILTIN_MODULES,
LocalPythonInterpreter,
fix_final_answer_code,
)
from .models import MessageRole
from .monitoring import Monitor
from .prompts import (
@ -895,7 +899,6 @@ class CodeAgent(MultiStepAgent):
)
log_entry.llm_output = llm_output
except Exception as e:
console.print_exception()
raise AgentGenerationError(f"Error in generating model output:\n{e}")
if self.verbose:
@ -917,10 +920,11 @@ class CodeAgent(MultiStepAgent):
# Parse
try:
code_action = parse_code_blob(llm_output)
code_action = fix_final_answer_code(parse_code_blob(llm_output))
except Exception as e:
console.print_exception()
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
error_msg = (
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
)
raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall(
@ -944,8 +948,9 @@ class CodeAgent(MultiStepAgent):
)
)
observation = ""
is_final_answer = False
try:
output, execution_logs = self.python_executor(
output, execution_logs, is_final_answer = self.python_executor(
code_action,
self.state,
)
@ -976,12 +981,6 @@ class CodeAgent(MultiStepAgent):
observation += "Last output from code snippet:\n" + truncated_output
log_entry.observations = observation
is_final_answer = False
for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer":
is_final_answer = True
break
execution_outputs_console += [
Text(
f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",

View File

@ -112,7 +112,7 @@ class PythonInterpreterTool(Tool):
state=state,
static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports,
)
)[0] # The second element is boolean is_final_answer
)
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
except Exception as e:

View File

@ -18,6 +18,7 @@ import ast
import builtins
import difflib
import math
import re
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, Tuple
@ -129,6 +130,34 @@ def get_iterable(obj):
raise InterpreterError("Object is not iterable")
def fix_final_answer_code(code: str) -> str:
    """Rename a variable called ``final_answer`` so it stops shadowing the tool.

    An LLM sometimes assigns to a variable named ``final_answer`` and then calls
    ``final_answer(...)`` in the same snippet, which would break the tool call.
    When both an assignment and a call are present, every standalone use of the
    name that is not a call is renamed to ``final_answer_variable``; calls such
    as ``final_answer(x)`` and attributes such as ``obj.final_answer`` are kept.

    The rewrite is purely textual (regex based, no parsing), so occurrences
    inside string literals or keyword arguments are matched as well.

    Args:
        code: The code snippet produced by the model.

    Returns:
        The snippet with the variable renamed, or ``code`` unchanged when it
        does not contain both a ``final_answer(`` call and an assignment.
    """
    # `final_answer` as a standalone name: neither an attribute
    # (`obj.final_answer`) nor the tail of a longer identifier.
    standalone_name = r"(?<!\.)(?<!\w)\bfinal_answer"

    # `=(?!=)` accepts an assignment but rejects the `==` comparison operator,
    # which must not be mistaken for a rebinding of the tool name.
    assignment_pattern = standalone_name + r"\s*=(?!=)"

    if "final_answer(" not in code or not re.search(assignment_pattern, code):
        # If the tool is not called in this blob, renaming is hazardous: it
        # could corrupt what the model remembers for its next steps. Leave the
        # code untouched and let the subsequent assignment error surface.
        return code

    # One pass is enough: any standalone `final_answer` that is not followed by
    # an opening parenthesis is a variable use, assignment targets included.
    variable_pattern = standalone_name + r"\b(?!\s*\()"
    return re.sub(variable_pattern, "final_answer_variable", code)
def evaluate_unaryop(expression, state, static_tools, custom_tools):
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
if isinstance(expression.op, ast.USub):
@ -224,6 +253,10 @@ def create_function(func_def, state, static_tools, custom_tools):
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
except ReturnException as e:
result = e.value
if func_def.name == "__init__":
return None
return result
return new_func
@ -484,41 +517,31 @@ def evaluate_call(call, state, static_tools, custom_tools):
for keyword in call.keywords
}
if (
isinstance(func, type) and len(func.__module__.split(".")) > 1
): # Check for user-defined classes
# Instantiate the class using its constructor
obj = func.__new__(func) # Create a new instance of the class
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
obj.__init__(*args, **kwargs) # Call the __init__ method correctly
return obj
else:
if func_name == "super":
if not args:
if "__class__" in state and "self" in state:
return super(state["__class__"], state["self"])
else:
raise InterpreterError("super() needs at least one argument")
cls = args[0]
if not isinstance(cls, type):
raise InterpreterError("super() argument 1 must be type")
if len(args) == 1:
return super(cls)
elif len(args) == 2:
instance = args[1]
return super(cls, instance)
if func_name == "super":
if not args:
if "__class__" in state and "self" in state:
return super(state["__class__"], state["self"])
else:
raise InterpreterError("super() takes at most 2 arguments")
raise InterpreterError("super() needs at least one argument")
cls = args[0]
if not isinstance(cls, type):
raise InterpreterError("super() argument 1 must be type")
if len(args) == 1:
return super(cls)
elif len(args) == 2:
instance = args[1]
return super(cls, instance)
else:
if func_name == "print":
output = " ".join(map(str, args))
global PRINT_OUTPUTS
PRINT_OUTPUTS += output + "\n"
# cap the number of lines
return None
else: # Assume it's a callable object
output = func(*args, **kwargs)
return output
raise InterpreterError("super() takes at most 2 arguments")
else:
if func_name == "print":
output = " ".join(map(str, args))
global PRINT_OUTPUTS
PRINT_OUTPUTS += output + "\n"
# cap the number of lines
return None
else: # Assume it's a callable object
return func(*args, **kwargs)
def evaluate_subscript(subscript, state, static_tools, custom_tools):
@ -990,6 +1013,11 @@ def truncate_print_outputs(
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
class FinalAnswerException(Exception):
    """Signal that ``final_answer`` was called, carrying the value it was given.

    Args:
        value: The object passed to ``final_answer``; exposed as ``self.value``.
    """

    def __init__(self, value):
        # Forward the value to Exception so that `args`, `str(exc)` and
        # copy/pickle reconstruction behave like any other exception.
        super().__init__(value)
        self.value = value
def evaluate_python_code(
code: str,
static_tools: Optional[Dict[str, Callable]] = None,
@ -1029,6 +1057,12 @@ def evaluate_python_code(
PRINT_OUTPUTS = ""
global OPERATIONS_COUNT
OPERATIONS_COUNT = 0
def final_answer(value):
raise FinalAnswerException(value)
static_tools["final_answer"] = final_answer
try:
for node in expression.body:
result = evaluate_ast(
@ -1037,7 +1071,14 @@ def evaluate_python_code(
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
)
return result
is_final_answer = False
return result, is_final_answer
except FinalAnswerException as e:
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
)
is_final_answer = True
return e.value, is_final_answer
except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
@ -1059,9 +1100,11 @@ class LocalPythonInterpreter:
}
# TODO: assert self.authorized imports are all installed locally
def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]:
def __call__(
self, code_action: str, additional_variables: Dict
) -> Tuple[Any, str, bool]:
self.state.update(additional_variables)
output = evaluate_python_code(
output, is_final_answer = evaluate_python_code(
code_action,
static_tools=self.static_tools,
custom_tools=self.custom_tools,
@ -1069,7 +1112,7 @@ class LocalPythonInterpreter:
authorized_imports=self.authorized_imports,
)
logs = self.state["print_outputs"]
return output, logs
return output, logs, is_final_answer
__all__ = ["evaluate_python_code", "LocalPythonInterpreter"]

View File

@ -373,7 +373,7 @@ Here are the rules you should always follow to solve your task:
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
7. Never create any notional variables in our code, as having these in your logs will derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.

View File

@ -106,26 +106,35 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
def parse_code_blob(code_blob: str) -> str:
try:
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL)
if match is None:
raise ValueError(
f"No match ground for regex pattern {pattern} in {code_blob=}."
)
return match.group(1).strip()
"""Parses the LLM's output to get any code blob inside. Will return the code directly if it's code."""
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL)
if match is None:
try: # Maybe the LLM outputted a code blob directly
ast.parse(code_blob)
return code_blob
except SyntaxError:
pass
except Exception as e:
if "final" in code_blob and "answer" in code_blob:
raise ValueError(
f"""
The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. It seems like you're trying to return the final answer, you can do it as follows:
Code:
```py
final_answer("YOUR FINAL ANSWER HERE")
```<end_action>""".strip()
)
raise ValueError(
f"""
The code blob you used is invalid: due to the following error: {e}
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. Make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
```<end_action>"""
```<end_action>""".strip()
)
return match.group(1).strip()
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:

View File

@ -444,3 +444,18 @@ final_answer("Final report.")
report = manager_toolcalling_agent.run("Fake question.")
assert report == "Final report."
def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
return """Code:
```py
def nested_answer():
final_answer("Correct!")
nested_answer()
```<end_code>"""
agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
output = agent.run("Count to 3")
assert output == "Correct!"

View File

@ -23,6 +23,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import (
InterpreterError,
evaluate_python_code,
fix_final_answer_code,
)
from smolagents.types import AGENT_TYPE_MAPPING
@ -79,19 +80,19 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self):
code = "x = 3"
state = {}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
code = "x = y"
state = {"y": 5}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
code = "a=1;b=None"
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
# evaluate returns the value of the last assignment.
assert result is None
@ -107,7 +108,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_call(self):
code = "y = add_two(x)"
state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
@ -119,14 +120,14 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_constant(self):
code = "x = 3"
state = {}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
def test_evaluate_dict(self):
code = "test_dict = {'x': x, 'y': add_two(x)}"
state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
@ -135,7 +136,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_expression(self):
code = "x = 3\ny = 5"
state = {}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
@ -143,7 +144,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_f_string(self):
code = "text = f'This is x: {x}.'"
state = {"x": 3}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == "This is x: 3."
self.assertDictEqual(
@ -153,13 +154,13 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
state = {"x": 3}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 2
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
state = {"x": 8}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
@ -167,27 +168,27 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_list(self):
code = "test_list = [x, add_two(x)]"
state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertListEqual(result, [3, 5])
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
def test_evaluate_name(self):
code = "y = x"
state = {"x": 3}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
def test_evaluate_subscript(self):
code = "test_list = [x, add_two(x)]\ntest_list[1]"
state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
@ -215,14 +216,14 @@ for result in search_results:
def test_evaluate_for(self):
code = "x = 0\nfor i in range(3):\n x = i"
state = {}
result = evaluate_python_code(code, {"range": range}, state=state)
result, _ = evaluate_python_code(code, {"range": range}, state=state)
assert result == 2
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
def test_evaluate_binop(self):
code = "y + x"
state = {"x": 3, "y": 6}
result = evaluate_python_code(code, {}, state=state)
result, _ = evaluate_python_code(code, {}, state=state)
assert result == 9
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
@ -234,27 +235,27 @@ def recur_fibo(n):
else:
return(recur_fibo(n-1) + recur_fibo(n-2))
recur_fibo(6)"""
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result == 8
def test_evaluate_string_methods(self):
code = "'hello'.replace('h', 'o').split('e')"
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result == ["o", "llo"]
def test_evaluate_slicing(self):
code = "'hello'[1:3][::-1]"
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result == "le"
def test_access_attributes(self):
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result is int
def test_list_comprehension(self):
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result == "t-h-e-s-e-a-g-u-l-l"
def test_string_indexing(self):
@ -267,12 +268,12 @@ for block in text_block:
for col in range(len(text_block[0])):
sentence += block[col]
"""
result = evaluate_python_code(code, {"len": len, "range": range}, state={})
result, _ = evaluate_python_code(code, {"len": len, "range": range}, state={})
assert result == "THESEAGULL"
def test_tuples(self):
code = "x = (1, 2, 3)\nx[1]"
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result == 2
code = """
@ -325,35 +326,35 @@ print(check_digits)
def test_listcomp(self):
code = "x = [i for i in range(3)]"
result = evaluate_python_code(code, {"range": range}, state={})
result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == [0, 1, 2]
def test_break_continue(self):
code = "for i in range(10):\n if i == 5:\n break\ni"
result = evaluate_python_code(code, {"range": range}, state={})
result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == 5
code = "for i in range(10):\n if i == 5:\n continue\ni"
result = evaluate_python_code(code, {"range": range}, state={})
result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == 9
def test_call_int(self):
code = "import math\nstr(math.ceil(149))"
result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
result, _ = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
assert result == "149"
def test_lambda(self):
code = "f = lambda x: x + 2\nf(3)"
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result == 5
def test_dictcomp(self):
code = "x = {i: i**2 for i in range(3)}"
result = evaluate_python_code(code, {"range": range}, state={})
result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert result == {102: "b"}
@ -362,17 +363,17 @@ print(check_digits)
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
"""
result = evaluate_python_code(code, {}, state={})
result, _ = evaluate_python_code(code, {}, state={})
assert result == {"A": ("a", "b"), "B": ("a", "b")}
def test_tuple_assignment(self):
code = "a, b = 0, 1\nb"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 1
def test_while(self):
code = "i = 0\nwhile i < 3:\n i += 1\ni"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 3
# test infinite loop
@ -393,7 +394,7 @@ while i < n and house_positions[i] <= loc:
def test_generator(self):
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == [1, 4, 9, 16, 25]
def test_boolops(self):
@ -403,7 +404,7 @@ else:
best_city = "Manhattan"
best_city
"""
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Brooklyn"
@ -416,7 +417,7 @@ else:
best_city = "Manhattan"
best_city
"""
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Sacramento"
@ -431,51 +432,51 @@ if char.isalpha():
def test_imports(self):
code = "import math\nmath.sqrt(4)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0
code = (
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
)
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose"
code = "import time, re\ntime.sleep(0.1)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result is None
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 1
code = "import itertools\nlist(itertools.islice(range(10), 3))"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == [0, 1, 2]
code = "import re\nre.search('a', 'abc').group()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "a"
code = "import stat\nstat.S_ISREG(0o100644)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.8
code = "import unicodedata\nunicodedata.name('A')"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "LATIN CAPITAL LETTER A"
# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
@ -491,25 +492,25 @@ if char.isalpha():
def test_multiple_comparators(self):
code = "0 <= -1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 4 < 4 and 0 <= 3 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 3 < 4 and 0 <= 3 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result
def test_print_output(self):
code = "print('Hello world!')\nprint('Ok no one cares')"
state = {}
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
assert result is None
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
@ -525,7 +526,7 @@ function()"""
def test_tuple_target_in_iterator(self):
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "Samuel"
def test_classes(self):
@ -618,7 +619,7 @@ def var_args_method(self, *args, **kwargs):
var_args_method(1, 2, 3, x=4, y=5)
"""
state = {}
result = evaluate_python_code(code, {"sum": sum}, state=state)
result, _ = evaluate_python_code(code, {"sum": sum}, state=state)
assert result == 15
def test_exceptions(self):
@ -648,7 +649,7 @@ except ValueError as e:
def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int"
state = {}
result = evaluate_python_code(
result, is_final_answer = evaluate_python_code(
code, {"float": float, "str": str, "int": int}, state=state
)
assert result is int
@ -659,7 +660,7 @@ food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
"""
state = {}
result = evaluate_python_code(code, {}, state=state)
result, is_final_answer = evaluate_python_code(code, {}, state=state)
assert result == ["orange", "pear"]
def test_nonsimple_augassign(self):
@ -742,8 +743,9 @@ def f(a, b=333, n=1000):
return b + n
n = f(1, n=667)
"""
res = evaluate_python_code(code, {}, {})
res, is_final_answer = evaluate_python_code(code, {}, {})
assert res == 1000
assert not is_final_answer
def test_set(self):
code = """
@ -767,8 +769,11 @@ while True:
break
i"""
result = evaluate_python_code(code, {"print": print, "round": round}, state={})
result, is_final_answer = evaluate_python_code(
code, {"print": print, "round": round}, state={}
)
assert result == 3
assert not is_final_answer
def test_return(self):
# test early returns
@ -781,7 +786,7 @@ def add_one(n, shift):
add_one(1, 1)
"""
state = {}
result = evaluate_python_code(
result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result == 2
@ -794,7 +799,7 @@ def returns_none(a):
returns_none(1)
"""
state = {}
result = evaluate_python_code(
result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result is None
@ -812,7 +817,7 @@ out = [i for sublist in all_res for i in sublist]
out[:10]
"""
state = {}
result = evaluate_python_code(
result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range}, state=state
)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
@ -829,7 +834,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
state = {}
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, {}, state=state, authorized_imports=["pandas"]
)
assert np.array_equal(result, [-1, 5])
@ -842,7 +847,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert np.array_equal(result.values[0], [104, 1])
@ -855,7 +860,9 @@ data = pd.DataFrame.from_dict([
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
result, _ = evaluate_python_code(
code, {}, state={}, authorized_imports=["pandas"]
)
assert result.values[1] == 0.5
def test_starred(self):
@ -877,7 +884,7 @@ coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
result = evaluate_python_code(
result, _ = evaluate_python_code(
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
)
assert round(result, 1) == 622395.4
@ -894,5 +901,42 @@ for worker, (start, end) in shifts.items():
shift_intervals[worker] = end
shift_intervals
"""
result = evaluate_python_code(code, {"print": print, "map": map}, state={})
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
def test_fix_final_answer_code(self):
test_cases = [
(
"final_answer = 3.21\nfinal_answer(final_answer)",
"final_answer_variable = 3.21\nfinal_answer(final_answer_variable)",
),
(
"x = final_answer(5)\nfinal_answer = x + 1\nfinal_answer(final_answer)",
"x = final_answer(5)\nfinal_answer_variable = x + 1\nfinal_answer(final_answer_variable)",
),
(
"def func():\n final_answer = 42\n return final_answer(final_answer)",
"def func():\n final_answer_variable = 42\n return final_answer(final_answer_variable)",
),
(
"final_answer(5) # Should not change function calls",
"final_answer(5) # Should not change function calls",
),
(
"obj.final_answer = 5 # Should not change object attributes",
"obj.final_answer = 5 # Should not change object attributes",
),
(
"final_answer=3.21;final_answer(final_answer)",
"final_answer_variable=3.21;final_answer(final_answer_variable)",
),
]
for i, (input_code, expected) in enumerate(test_cases, 1):
result = fix_final_answer_code(input_code)
assert result == expected, f"""
Test case {i} failed:
Input: {input_code}
Expected: {expected}
Got: {result}
"""

View File

@ -1,86 +1,39 @@
import os
import shutil
import tempfile
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
import pytest
from smolagents.utils import parse_code_blob
def str_to_bool(value) -> int:
"""
Converts a string representation of truth to `True` (1) or `False` (0).
class AgentTextTests(unittest.TestCase):
def test_parse_code_blob(self):
with pytest.raises(ValueError):
parse_code_blob("Wrong blob!")
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
"""
value = value.lower()
if value in ("y", "yes", "t", "true", "on", "1"):
return 1
elif value in ("n", "no", "f", "false", "off", "0"):
return 0
else:
raise ValueError(f"invalid truth value {value}")
# Parsing markdown with code blobs should work
output = parse_code_blob("""
Here is how to solve the problem:
Code:
```py
import numpy as np
```<end_code>
""")
assert output == "import numpy as np"
def get_int_from_env(env_keys, default):
"""Returns the first positive env value found in the `env_keys` list or the default."""
for e in env_keys:
val = int(os.environ.get(e, -1))
if val >= 0:
return val
return default
def parse_flag_from_env(key, default=False):
"""Returns truthy value for `key` from the env if available else the default."""
value = os.environ.get(key, str(default))
return (
str_to_bool(value) == 1
) # As its name indicates `str_to_bool` actually returns an int...
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def skip(test_case):
"Decorator that skips a test unconditionally"
return unittest.skip("Test was skipped")(test_case)
def slow(test_case):
"""
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
data at the start of a test, and then destroyes it at the end of the TestCase.
Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases
The temporary directory location will be stored in `self.tmpdir`
"""
clear_on_setup = True
@classmethod
def setUpClass(cls):
"Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`"
cls.tmpdir = Path(tempfile.mkdtemp())
@classmethod
def tearDownClass(cls):
"Remove `cls.tmpdir` after test suite has finished"
if os.path.exists(cls.tmpdir):
shutil.rmtree(cls.tmpdir)
def setUp(self):
"Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
if self.clear_on_setup:
for path in self.tmpdir.glob("**/*"):
if path.is_file():
path.unlink()
elif path.is_dir():
shutil.rmtree(path)
# Parsing code blobs should work
code_blob = "import numpy as np"
output = parse_code_blob(code_blob)
assert output == code_blob