Auto correct wrong assignments to final_answer (#123)

* Auto correct wrong assignments to final_answer
This commit is contained in:
Aymeric Roucher 2025-01-08 19:04:11 +01:00 committed by GitHub
parent e5d879feab
commit d3cd0f9e09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 694 additions and 303 deletions

View File

@ -57,3 +57,6 @@ jobs:
- name: Types tests - name: Types tests
run: | run: |
uv run pytest -sv ./tests/test_types.py uv run pytest -sv ./tests/test_types.py
- name: Utils tests
run: |
uv run pytest -sv ./tests/test_utils.py

File diff suppressed because one or more lines are too long

View File

@ -26,7 +26,11 @@ from rich.text import Text
from .default_tools import FinalAnswerTool from .default_tools import FinalAnswerTool
from .e2b_executor import E2BExecutor from .e2b_executor import E2BExecutor
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter from .local_python_executor import (
BASE_BUILTIN_MODULES,
LocalPythonInterpreter,
fix_final_answer_code,
)
from .models import MessageRole from .models import MessageRole
from .monitoring import Monitor from .monitoring import Monitor
from .prompts import ( from .prompts import (
@ -895,7 +899,6 @@ class CodeAgent(MultiStepAgent):
) )
log_entry.llm_output = llm_output log_entry.llm_output = llm_output
except Exception as e: except Exception as e:
console.print_exception()
raise AgentGenerationError(f"Error in generating model output:\n{e}") raise AgentGenerationError(f"Error in generating model output:\n{e}")
if self.verbose: if self.verbose:
@ -917,10 +920,11 @@ class CodeAgent(MultiStepAgent):
# Parse # Parse
try: try:
code_action = parse_code_blob(llm_output) code_action = fix_final_answer_code(parse_code_blob(llm_output))
except Exception as e: except Exception as e:
console.print_exception() error_msg = (
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
)
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall( log_entry.tool_call = ToolCall(
@ -944,8 +948,9 @@ class CodeAgent(MultiStepAgent):
) )
) )
observation = "" observation = ""
is_final_answer = False
try: try:
output, execution_logs = self.python_executor( output, execution_logs, is_final_answer = self.python_executor(
code_action, code_action,
self.state, self.state,
) )
@ -976,12 +981,6 @@ class CodeAgent(MultiStepAgent):
observation += "Last output from code snippet:\n" + truncated_output observation += "Last output from code snippet:\n" + truncated_output
log_entry.observations = observation log_entry.observations = observation
is_final_answer = False
for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer":
is_final_answer = True
break
execution_outputs_console += [ execution_outputs_console += [
Text( Text(
f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}", f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",

View File

@ -112,7 +112,7 @@ class PythonInterpreterTool(Tool):
state=state, state=state,
static_tools=self.base_python_tools, static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
) )[0] # The second element is boolean is_final_answer
) )
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
except Exception as e: except Exception as e:

View File

@ -18,6 +18,7 @@ import ast
import builtins import builtins
import difflib import difflib
import math import math
import re
from collections.abc import Mapping from collections.abc import Mapping
from importlib import import_module from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
@ -129,6 +130,34 @@ def get_iterable(obj):
raise InterpreterError("Object is not iterable") raise InterpreterError("Object is not iterable")
def fix_final_answer_code(code: str) -> str:
"""
Sometimes an LLM can try to assign a variable to final_answer, which would break the final_answer() tool.
This function fixes this behaviour by replacing variable assignments to final_answer with final_answer_variable,
while preserving function calls to final_answer().
"""
# First, find if there's a direct assignment to final_answer
# Use word boundary and negative lookbehind to ensure it's not an object attribute
assignment_pattern = r"(?<!\.)(?<!\w)\bfinal_answer\s*="
if "final_answer(" not in code or not re.search(assignment_pattern, code):
# If final_answer tool is not called in this blob, then doing the replacement is hazardous because it could false the model's memory for next steps.
# Let's not modify the code and leave the subsequent assignment error happen.
return code
# Pattern for replacing variable assignments
# Looks for 'final_answer' followed by '=' with optional whitespace
# Negative lookbehind ensures we don't match object attributes
assignment_regex = r"(?<!\.)(?<!\w)(\bfinal_answer)(\s*=)"
code = re.sub(assignment_regex, r"final_answer_variable\2", code)
# Pattern for replacing variable usage but not function calls
# Negative lookahead (?!\s*\() ensures we don't match function calls
# Negative lookbehind (?<!\.|\w) ensures we don't match object methods or other variables
variable_regex = r"(?<!\.)(?<!\w)(\bfinal_answer\b)(?!\s*\()"
code = re.sub(variable_regex, "final_answer_variable", code)
return code
def evaluate_unaryop(expression, state, static_tools, custom_tools): def evaluate_unaryop(expression, state, static_tools, custom_tools):
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools) operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
if isinstance(expression.op, ast.USub): if isinstance(expression.op, ast.USub):
@ -224,6 +253,10 @@ def create_function(func_def, state, static_tools, custom_tools):
result = evaluate_ast(stmt, func_state, static_tools, custom_tools) result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
except ReturnException as e: except ReturnException as e:
result = e.value result = e.value
if func_def.name == "__init__":
return None
return result return result
return new_func return new_func
@ -484,41 +517,31 @@ def evaluate_call(call, state, static_tools, custom_tools):
for keyword in call.keywords for keyword in call.keywords
} }
if ( if func_name == "super":
isinstance(func, type) and len(func.__module__.split(".")) > 1 if not args:
): # Check for user-defined classes if "__class__" in state and "self" in state:
# Instantiate the class using its constructor return super(state["__class__"], state["self"])
obj = func.__new__(func) # Create a new instance of the class
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
obj.__init__(*args, **kwargs) # Call the __init__ method correctly
return obj
else:
if func_name == "super":
if not args:
if "__class__" in state and "self" in state:
return super(state["__class__"], state["self"])
else:
raise InterpreterError("super() needs at least one argument")
cls = args[0]
if not isinstance(cls, type):
raise InterpreterError("super() argument 1 must be type")
if len(args) == 1:
return super(cls)
elif len(args) == 2:
instance = args[1]
return super(cls, instance)
else: else:
raise InterpreterError("super() takes at most 2 arguments") raise InterpreterError("super() needs at least one argument")
cls = args[0]
if not isinstance(cls, type):
raise InterpreterError("super() argument 1 must be type")
if len(args) == 1:
return super(cls)
elif len(args) == 2:
instance = args[1]
return super(cls, instance)
else: else:
if func_name == "print": raise InterpreterError("super() takes at most 2 arguments")
output = " ".join(map(str, args)) else:
global PRINT_OUTPUTS if func_name == "print":
PRINT_OUTPUTS += output + "\n" output = " ".join(map(str, args))
# cap the number of lines global PRINT_OUTPUTS
return None PRINT_OUTPUTS += output + "\n"
else: # Assume it's a callable object # cap the number of lines
output = func(*args, **kwargs) return None
return output else: # Assume it's a callable object
return func(*args, **kwargs)
def evaluate_subscript(subscript, state, static_tools, custom_tools): def evaluate_subscript(subscript, state, static_tools, custom_tools):
@ -990,6 +1013,11 @@ def truncate_print_outputs(
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n" return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
class FinalAnswerException(Exception):
def __init__(self, value):
self.value = value
def evaluate_python_code( def evaluate_python_code(
code: str, code: str,
static_tools: Optional[Dict[str, Callable]] = None, static_tools: Optional[Dict[str, Callable]] = None,
@ -1029,6 +1057,12 @@ def evaluate_python_code(
PRINT_OUTPUTS = "" PRINT_OUTPUTS = ""
global OPERATIONS_COUNT global OPERATIONS_COUNT
OPERATIONS_COUNT = 0 OPERATIONS_COUNT = 0
def final_answer(value):
raise FinalAnswerException(value)
static_tools["final_answer"] = final_answer
try: try:
for node in expression.body: for node in expression.body:
result = evaluate_ast( result = evaluate_ast(
@ -1037,7 +1071,14 @@ def evaluate_python_code(
state["print_outputs"] = truncate_content( state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
) )
return result is_final_answer = False
return result, is_final_answer
except FinalAnswerException as e:
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
)
is_final_answer = True
return e.value, is_final_answer
except InterpreterError as e: except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
@ -1059,9 +1100,11 @@ class LocalPythonInterpreter:
} }
# TODO: assert self.authorized imports are all installed locally # TODO: assert self.authorized imports are all installed locally
def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]: def __call__(
self, code_action: str, additional_variables: Dict
) -> Tuple[Any, str, bool]:
self.state.update(additional_variables) self.state.update(additional_variables)
output = evaluate_python_code( output, is_final_answer = evaluate_python_code(
code_action, code_action,
static_tools=self.static_tools, static_tools=self.static_tools,
custom_tools=self.custom_tools, custom_tools=self.custom_tools,
@ -1069,7 +1112,7 @@ class LocalPythonInterpreter:
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
) )
logs = self.state["print_outputs"] logs = self.state["print_outputs"]
return output, logs return output, logs, is_final_answer
__all__ = ["evaluate_python_code", "LocalPythonInterpreter"] __all__ = ["evaluate_python_code", "LocalPythonInterpreter"]

View File

@ -373,7 +373,7 @@ Here are the rules you should always follow to solve your task:
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables. 7. Never create any notional variables in our code, as having these in your logs will derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}} 8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it. 10. Don't give up! You're in charge of solving the task, not providing directions to solve it.

View File

@ -106,26 +106,35 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
def parse_code_blob(code_blob: str) -> str: def parse_code_blob(code_blob: str) -> str:
try: """Parses the LLM's output to get any code blob inside. Will retrun the code directly if it's code."""
pattern = r"```(?:py|python)?\n(.*?)\n```" pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL) match = re.search(pattern, code_blob, re.DOTALL)
if match is None: if match is None:
raise ValueError( try: # Maybe the LLM outputted a code blob directly
f"No match ground for regex pattern {pattern} in {code_blob=}." ast.parse(code_blob)
) return code_blob
return match.group(1).strip() except SyntaxError:
pass
except Exception as e: if "final" in code_blob and "answer" in code_blob:
raise ValueError(
f"""
The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. It seems like you're trying to return the final answer, you can do it as follows:
Code:
```py
final_answer("YOUR FINAL ANSWER HERE")
```<end_action>""".strip()
)
raise ValueError( raise ValueError(
f""" f"""
The code blob you used is invalid: due to the following error: {e} The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. Make sure to include code with the correct pattern, for instance:
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts Thoughts: Your thoughts
Code: Code:
```py ```py
# Your python code here # Your python code here
```<end_action>""" ```<end_action>""".strip()
) )
return match.group(1).strip()
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:

View File

@ -444,3 +444,18 @@ final_answer("Final report.")
report = manager_toolcalling_agent.run("Fake question.") report = manager_toolcalling_agent.run("Fake question.")
assert report == "Final report." assert report == "Final report."
def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
return """Code:
```py
def nested_answer():
final_answer("Correct!")
nested_answer()
```<end_code>"""
agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
output = agent.run("Count to 3")
assert output == "Correct!"

View File

@ -23,6 +23,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import ( from smolagents.local_python_executor import (
InterpreterError, InterpreterError,
evaluate_python_code, evaluate_python_code,
fix_final_answer_code,
) )
from smolagents.types import AGENT_TYPE_MAPPING from smolagents.types import AGENT_TYPE_MAPPING
@ -79,19 +80,19 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self): def test_evaluate_assign(self):
code = "x = 3" code = "x = 3"
state = {} state = {}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3 assert result == 3
self.assertDictEqual(state, {"x": 3, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
code = "x = y" code = "x = y"
state = {"y": 5} state = {"y": 5}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""}) self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
code = "a=1;b=None" code = "a=1;b=None"
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result is None assert result is None
@ -107,7 +108,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_call(self): def test_evaluate_call(self):
code = "y = add_two(x)" code = "y = add_two(x)"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state) result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
@ -119,14 +120,14 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_constant(self): def test_evaluate_constant(self):
code = "x = 3" code = "x = 3"
state = {} state = {}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3 assert result == 3
self.assertDictEqual(state, {"x": 3, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
def test_evaluate_dict(self): def test_evaluate_dict(self):
code = "test_dict = {'x': x, 'y': add_two(x)}" code = "test_dict = {'x': x, 'y': add_two(x)}"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state) result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5}) self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual( self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
@ -135,7 +136,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_expression(self): def test_evaluate_expression(self):
code = "x = 3\ny = 5" code = "x = 3\ny = 5"
state = {} state = {}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
@ -143,7 +144,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_f_string(self): def test_evaluate_f_string(self):
code = "text = f'This is x: {x}.'" code = "text = f'This is x: {x}.'"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result == "This is x: 3." assert result == "This is x: 3."
self.assertDictEqual( self.assertDictEqual(
@ -153,13 +154,13 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_if(self): def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5" code = "if x <= 3:\n y = 2\nelse:\n y = 5"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result == 2 assert result == 2
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
state = {"x": 8} state = {"x": 8}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""}) self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
@ -167,27 +168,27 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_list(self): def test_evaluate_list(self):
code = "test_list = [x, add_two(x)]" code = "test_list = [x, add_two(x)]"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state) result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertListEqual(result, [3, 5]) self.assertListEqual(result, [3, 5])
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
def test_evaluate_name(self): def test_evaluate_name(self):
code = "y = x" code = "y = x"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3 assert result == 3
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
def test_evaluate_subscript(self): def test_evaluate_subscript(self):
code = "test_list = [x, add_two(x)]\ntest_list[1]" code = "test_list = [x, add_two(x)]\ntest_list[1]"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state) result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']" code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state) result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5 assert result == 5
self.assertDictEqual( self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
@ -215,14 +216,14 @@ for result in search_results:
def test_evaluate_for(self): def test_evaluate_for(self):
code = "x = 0\nfor i in range(3):\n x = i" code = "x = 0\nfor i in range(3):\n x = i"
state = {} state = {}
result = evaluate_python_code(code, {"range": range}, state=state) result, _ = evaluate_python_code(code, {"range": range}, state=state)
assert result == 2 assert result == 2
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""}) self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
def test_evaluate_binop(self): def test_evaluate_binop(self):
code = "y + x" code = "y + x"
state = {"x": 3, "y": 6} state = {"x": 3, "y": 6}
result = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
assert result == 9 assert result == 9
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
@ -234,27 +235,27 @@ def recur_fibo(n):
else: else:
return(recur_fibo(n-1) + recur_fibo(n-2)) return(recur_fibo(n-1) + recur_fibo(n-2))
recur_fibo(6)""" recur_fibo(6)"""
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result == 8 assert result == 8
def test_evaluate_string_methods(self): def test_evaluate_string_methods(self):
code = "'hello'.replace('h', 'o').split('e')" code = "'hello'.replace('h', 'o').split('e')"
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result == ["o", "llo"] assert result == ["o", "llo"]
def test_evaluate_slicing(self): def test_evaluate_slicing(self):
code = "'hello'[1:3][::-1]" code = "'hello'[1:3][::-1]"
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result == "le" assert result == "le"
def test_access_attributes(self): def test_access_attributes(self):
code = "integer = 1\nobj_class = integer.__class__\nobj_class" code = "integer = 1\nobj_class = integer.__class__\nobj_class"
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result is int assert result is int
def test_list_comprehension(self): def test_list_comprehension(self):
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])" code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result == "t-h-e-s-e-a-g-u-l-l" assert result == "t-h-e-s-e-a-g-u-l-l"
def test_string_indexing(self): def test_string_indexing(self):
@ -267,12 +268,12 @@ for block in text_block:
for col in range(len(text_block[0])): for col in range(len(text_block[0])):
sentence += block[col] sentence += block[col]
""" """
result = evaluate_python_code(code, {"len": len, "range": range}, state={}) result, _ = evaluate_python_code(code, {"len": len, "range": range}, state={})
assert result == "THESEAGULL" assert result == "THESEAGULL"
def test_tuples(self): def test_tuples(self):
code = "x = (1, 2, 3)\nx[1]" code = "x = (1, 2, 3)\nx[1]"
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result == 2 assert result == 2
code = """ code = """
@ -325,35 +326,35 @@ print(check_digits)
def test_listcomp(self): def test_listcomp(self):
code = "x = [i for i in range(3)]" code = "x = [i for i in range(3)]"
result = evaluate_python_code(code, {"range": range}, state={}) result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == [0, 1, 2] assert result == [0, 1, 2]
def test_break_continue(self): def test_break_continue(self):
code = "for i in range(10):\n if i == 5:\n break\ni" code = "for i in range(10):\n if i == 5:\n break\ni"
result = evaluate_python_code(code, {"range": range}, state={}) result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == 5 assert result == 5
code = "for i in range(10):\n if i == 5:\n continue\ni" code = "for i in range(10):\n if i == 5:\n continue\ni"
result = evaluate_python_code(code, {"range": range}, state={}) result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == 9 assert result == 9
def test_call_int(self): def test_call_int(self):
code = "import math\nstr(math.ceil(149))" code = "import math\nstr(math.ceil(149))"
result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={}) result, _ = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
assert result == "149" assert result == "149"
def test_lambda(self): def test_lambda(self):
code = "f = lambda x: x + 2\nf(3)" code = "f = lambda x: x + 2\nf(3)"
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result == 5 assert result == 5
def test_dictcomp(self): def test_dictcomp(self):
code = "x = {i: i**2 for i in range(3)}" code = "x = {i: i**2 for i in range(3)}"
result = evaluate_python_code(code, {"range": range}, state={}) result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == {0: 0, 1: 1, 2: 4} assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code( result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"] code, {"print": print}, state={}, authorized_imports=["pandas"]
) )
assert result == {102: "b"} assert result == {102: "b"}
@ -362,17 +363,17 @@ print(check_digits)
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')} shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()} shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
""" """
result = evaluate_python_code(code, {}, state={}) result, _ = evaluate_python_code(code, {}, state={})
assert result == {"A": ("a", "b"), "B": ("a", "b")} assert result == {"A": ("a", "b"), "B": ("a", "b")}
def test_tuple_assignment(self): def test_tuple_assignment(self):
code = "a, b = 0, 1\nb" code = "a, b = 0, 1\nb"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 1 assert result == 1
def test_while(self): def test_while(self):
code = "i = 0\nwhile i < 3:\n i += 1\ni" code = "i = 0\nwhile i < 3:\n i += 1\ni"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 3 assert result == 3
# test infinite loop # test infinite loop
@ -393,7 +394,7 @@ while i < n and house_positions[i] <= loc:
def test_generator(self): def test_generator(self):
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)" code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == [1, 4, 9, 16, 25] assert result == [1, 4, 9, 16, 25]
def test_boolops(self): def test_boolops(self):
@ -403,7 +404,7 @@ else:
best_city = "Manhattan" best_city = "Manhattan"
best_city best_city
""" """
result = evaluate_python_code( result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
) )
assert result == "Brooklyn" assert result == "Brooklyn"
@ -416,7 +417,7 @@ else:
best_city = "Manhattan" best_city = "Manhattan"
best_city best_city
""" """
result = evaluate_python_code( result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
) )
assert result == "Sacramento" assert result == "Sacramento"
@ -431,51 +432,51 @@ if char.isalpha():
def test_imports(self): def test_imports(self):
code = "import math\nmath.sqrt(4)" code = "import math\nmath.sqrt(4)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0 assert result == 2.0
code = ( code = (
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
) )
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose" assert result == "lose"
code = "import time, re\ntime.sleep(0.1)" code = "import time, re\ntime.sleep(0.1)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result is None assert result is None
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()" code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 1 assert result == 1
code = "import itertools\nlist(itertools.islice(range(10), 3))" code = "import itertools\nlist(itertools.islice(range(10), 3))"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == [0, 1, 2] assert result == [0, 1, 2]
code = "import re\nre.search('a', 'abc').group()" code = "import re\nre.search('a', 'abc').group()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "a" assert result == "a"
code = "import stat\nstat.S_ISREG(0o100644)" code = "import stat\nstat.S_ISREG(0o100644)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result assert result
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])" code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.8 assert result == 2.8
code = "import unicodedata\nunicodedata.name('A')" code = "import unicodedata\nunicodedata.name('A')"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "LATIN CAPITAL LETTER A" assert result == "LATIN CAPITAL LETTER A"
# Test submodules are handled properly, thus not raising error # Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code( result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
) )
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code( result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
) )
@ -491,25 +492,25 @@ if char.isalpha():
def test_multiple_comparators(self): def test_multiple_comparators(self):
code = "0 <= -1 < 4 and 0 <= -5 < 4" code = "0 <= -1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result assert not result
code = "0 <= 1 < 4 and 0 <= -5 < 4" code = "0 <= 1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result assert not result
code = "0 <= 4 < 4 and 0 <= 3 < 4" code = "0 <= 4 < 4 and 0 <= 3 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result assert not result
code = "0 <= 3 < 4 and 0 <= 3 < 4" code = "0 <= 3 < 4 and 0 <= 3 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result assert result
def test_print_output(self): def test_print_output(self):
code = "print('Hello world!')\nprint('Ok no one cares')" code = "print('Hello world!')\nprint('Ok no one cares')"
state = {} state = {}
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
assert result is None assert result is None
assert state["print_outputs"] == "Hello world!\nOk no one cares\n" assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
@ -525,7 +526,7 @@ function()"""
def test_tuple_target_in_iterator(self): def test_tuple_target_in_iterator(self):
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]" code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "Samuel" assert result == "Samuel"
def test_classes(self): def test_classes(self):
@ -618,7 +619,7 @@ def var_args_method(self, *args, **kwargs):
var_args_method(1, 2, 3, x=4, y=5) var_args_method(1, 2, 3, x=4, y=5)
""" """
state = {} state = {}
result = evaluate_python_code(code, {"sum": sum}, state=state) result, _ = evaluate_python_code(code, {"sum": sum}, state=state)
assert result == 15 assert result == 15
def test_exceptions(self): def test_exceptions(self):
@ -648,7 +649,7 @@ except ValueError as e:
def test_types_as_objects(self): def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int" code = "type_a = float(2); type_b = str; type_c = int"
state = {} state = {}
result = evaluate_python_code( result, is_final_answer = evaluate_python_code(
code, {"float": float, "str": str, "int": int}, state=state code, {"float": float, "str": str, "int": int}, state=state
) )
assert result is int assert result is int
@ -659,7 +660,7 @@ food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
unique_food_items = [item for item, count in food_item_counts.items() if count == 1] unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
""" """
state = {} state = {}
result = evaluate_python_code(code, {}, state=state) result, is_final_answer = evaluate_python_code(code, {}, state=state)
assert result == ["orange", "pear"] assert result == ["orange", "pear"]
def test_nonsimple_augassign(self): def test_nonsimple_augassign(self):
@ -742,8 +743,9 @@ def f(a, b=333, n=1000):
return b + n return b + n
n = f(1, n=667) n = f(1, n=667)
""" """
res = evaluate_python_code(code, {}, {}) res, is_final_answer = evaluate_python_code(code, {}, {})
assert res == 1000 assert res == 1000
assert not is_final_answer
def test_set(self): def test_set(self):
code = """ code = """
@ -767,8 +769,11 @@ while True:
break break
i""" i"""
result = evaluate_python_code(code, {"print": print, "round": round}, state={}) result, is_final_answer = evaluate_python_code(
code, {"print": print, "round": round}, state={}
)
assert result == 3 assert result == 3
assert not is_final_answer
def test_return(self): def test_return(self):
# test early returns # test early returns
@ -781,7 +786,7 @@ def add_one(n, shift):
add_one(1, 1) add_one(1, 1)
""" """
state = {} state = {}
result = evaluate_python_code( result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
) )
assert result == 2 assert result == 2
@ -794,7 +799,7 @@ def returns_none(a):
returns_none(1) returns_none(1)
""" """
state = {} state = {}
result = evaluate_python_code( result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
) )
assert result is None assert result is None
@ -812,7 +817,7 @@ out = [i for sublist in all_res for i in sublist]
out[:10] out[:10]
""" """
state = {} state = {}
result = evaluate_python_code( result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range}, state=state code, {"print": print, "range": range}, state=state
) )
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
@ -829,7 +834,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1] parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
""" """
state = {} state = {}
result = evaluate_python_code( result, _ = evaluate_python_code(
code, {}, state=state, authorized_imports=["pandas"] code, {}, state=state, authorized_imports=["pandas"]
) )
assert np.array_equal(result, [-1, 5]) assert np.array_equal(result, [-1, 5])
@ -842,7 +847,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
# Filter the DataFrame to get only the rows with outdated atomic numbers # Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])] filtered_df = df.loc[df['AtomicNumber'].isin([104])]
""" """
result = evaluate_python_code( result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"] code, {"print": print}, state={}, authorized_imports=["pandas"]
) )
assert np.array_equal(result.values[0], [104, 1]) assert np.array_equal(result.values[0], [104, 1])
@ -855,7 +860,9 @@ data = pd.DataFrame.from_dict([
]) ])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean() survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
""" """
result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"]) result, _ = evaluate_python_code(
code, {}, state={}, authorized_imports=["pandas"]
)
assert result.values[1] == 0.5 assert result.values[1] == 0.5
def test_starred(self): def test_starred(self):
@ -877,7 +884,7 @@ coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
""" """
result = evaluate_python_code( result, _ = evaluate_python_code(
code, {"print": print, "map": map}, state={}, authorized_imports=["math"] code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
) )
assert round(result, 1) == 622395.4 assert round(result, 1) == 622395.4
@ -894,5 +901,42 @@ for worker, (start, end) in shifts.items():
shift_intervals[worker] = end shift_intervals[worker] = end
shift_intervals shift_intervals
""" """
result = evaluate_python_code(code, {"print": print, "map": map}, state={}) result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"} assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
def test_fix_final_answer_code(self):
test_cases = [
(
"final_answer = 3.21\nfinal_answer(final_answer)",
"final_answer_variable = 3.21\nfinal_answer(final_answer_variable)",
),
(
"x = final_answer(5)\nfinal_answer = x + 1\nfinal_answer(final_answer)",
"x = final_answer(5)\nfinal_answer_variable = x + 1\nfinal_answer(final_answer_variable)",
),
(
"def func():\n final_answer = 42\n return final_answer(final_answer)",
"def func():\n final_answer_variable = 42\n return final_answer(final_answer_variable)",
),
(
"final_answer(5) # Should not change function calls",
"final_answer(5) # Should not change function calls",
),
(
"obj.final_answer = 5 # Should not change object attributes",
"obj.final_answer = 5 # Should not change object attributes",
),
(
"final_answer=3.21;final_answer(final_answer)",
"final_answer_variable=3.21;final_answer(final_answer_variable)",
),
]
for i, (input_code, expected) in enumerate(test_cases, 1):
result = fix_final_answer_code(input_code)
assert result == expected, f"""
Test case {i} failed:
Input: {input_code}
Expected: {expected}
Got: {result}
"""

View File

@ -1,86 +1,39 @@
import os # coding=utf-8
import shutil # Copyright 2024 HuggingFace Inc.
import tempfile #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest import unittest
from pathlib import Path import pytest
from smolagents.utils import parse_code_blob
def str_to_bool(value) -> int: class AgentTextTests(unittest.TestCase):
""" def test_parse_code_blob(self):
Converts a string representation of truth to `True` (1) or `False` (0). with pytest.raises(ValueError):
parse_code_blob("Wrong blob!")
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; # Parsing mardkwon with code blobs should work
""" output = parse_code_blob("""
value = value.lower() Here is how to solve the problem:
if value in ("y", "yes", "t", "true", "on", "1"): Code:
return 1 ```py
elif value in ("n", "no", "f", "false", "off", "0"): import numpy as np
return 0 ```<end_code>
else: """)
raise ValueError(f"invalid truth value {value}") assert output == "import numpy as np"
# Parsing code blobs should work
def get_int_from_env(env_keys, default): code_blob = "import numpy as np"
"""Returns the first positive env value found in the `env_keys` list or the default.""" output = parse_code_blob(code_blob)
for e in env_keys: assert output == code_blob
val = int(os.environ.get(e, -1))
if val >= 0:
return val
return default
def parse_flag_from_env(key, default=False):
"""Returns truthy value for `key` from the env if available else the default."""
value = os.environ.get(key, str(default))
return (
str_to_bool(value) == 1
) # As its name indicates `str_to_bool` actually returns an int...
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def skip(test_case):
"Decorator that skips a test unconditionally"
return unittest.skip("Test was skipped")(test_case)
def slow(test_case):
"""
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
data at the start of a test, and then destroyes it at the end of the TestCase.
Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases
The temporary directory location will be stored in `self.tmpdir`
"""
clear_on_setup = True
@classmethod
def setUpClass(cls):
"Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`"
cls.tmpdir = Path(tempfile.mkdtemp())
@classmethod
def tearDownClass(cls):
"Remove `cls.tmpdir` after test suite has finished"
if os.path.exists(cls.tmpdir):
shutil.rmtree(cls.tmpdir)
def setUp(self):
"Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
if self.clear_on_setup:
for path in self.tmpdir.glob("**/*"):
if path.is_file():
path.unlink()
elif path.is_dir():
shutil.rmtree(path)