Auto correct wrong assignments to final_answer (#123)
* Auto correct wrong assignments to final_answer
This commit is contained in:
parent
e5d879feab
commit
d3cd0f9e09
|
@ -57,3 +57,6 @@ jobs:
|
||||||
- name: Types tests
|
- name: Types tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_types.py
|
uv run pytest -sv ./tests/test_types.py
|
||||||
|
- name: Utils tests
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv ./tests/test_utils.py
|
File diff suppressed because one or more lines are too long
|
@ -26,7 +26,11 @@ from rich.text import Text
|
||||||
|
|
||||||
from .default_tools import FinalAnswerTool
|
from .default_tools import FinalAnswerTool
|
||||||
from .e2b_executor import E2BExecutor
|
from .e2b_executor import E2BExecutor
|
||||||
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
|
from .local_python_executor import (
|
||||||
|
BASE_BUILTIN_MODULES,
|
||||||
|
LocalPythonInterpreter,
|
||||||
|
fix_final_answer_code,
|
||||||
|
)
|
||||||
from .models import MessageRole
|
from .models import MessageRole
|
||||||
from .monitoring import Monitor
|
from .monitoring import Monitor
|
||||||
from .prompts import (
|
from .prompts import (
|
||||||
|
@ -895,7 +899,6 @@ class CodeAgent(MultiStepAgent):
|
||||||
)
|
)
|
||||||
log_entry.llm_output = llm_output
|
log_entry.llm_output = llm_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print_exception()
|
|
||||||
raise AgentGenerationError(f"Error in generating model output:\n{e}")
|
raise AgentGenerationError(f"Error in generating model output:\n{e}")
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -917,10 +920,11 @@ class CodeAgent(MultiStepAgent):
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
try:
|
try:
|
||||||
code_action = parse_code_blob(llm_output)
|
code_action = fix_final_answer_code(parse_code_blob(llm_output))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print_exception()
|
error_msg = (
|
||||||
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||||
|
)
|
||||||
raise AgentParsingError(error_msg)
|
raise AgentParsingError(error_msg)
|
||||||
|
|
||||||
log_entry.tool_call = ToolCall(
|
log_entry.tool_call = ToolCall(
|
||||||
|
@ -944,8 +948,9 @@ class CodeAgent(MultiStepAgent):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
observation = ""
|
observation = ""
|
||||||
|
is_final_answer = False
|
||||||
try:
|
try:
|
||||||
output, execution_logs = self.python_executor(
|
output, execution_logs, is_final_answer = self.python_executor(
|
||||||
code_action,
|
code_action,
|
||||||
self.state,
|
self.state,
|
||||||
)
|
)
|
||||||
|
@ -976,12 +981,6 @@ class CodeAgent(MultiStepAgent):
|
||||||
observation += "Last output from code snippet:\n" + truncated_output
|
observation += "Last output from code snippet:\n" + truncated_output
|
||||||
log_entry.observations = observation
|
log_entry.observations = observation
|
||||||
|
|
||||||
is_final_answer = False
|
|
||||||
for line in code_action.split("\n"):
|
|
||||||
if line[: len("final_answer")] == "final_answer":
|
|
||||||
is_final_answer = True
|
|
||||||
break
|
|
||||||
|
|
||||||
execution_outputs_console += [
|
execution_outputs_console += [
|
||||||
Text(
|
Text(
|
||||||
f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",
|
f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",
|
||||||
|
|
|
@ -112,7 +112,7 @@ class PythonInterpreterTool(Tool):
|
||||||
state=state,
|
state=state,
|
||||||
static_tools=self.base_python_tools,
|
static_tools=self.base_python_tools,
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)[0] # The second element is boolean is_final_answer
|
||||||
)
|
)
|
||||||
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
|
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -18,6 +18,7 @@ import ast
|
||||||
import builtins
|
import builtins
|
||||||
import difflib
|
import difflib
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
@ -129,6 +130,34 @@ def get_iterable(obj):
|
||||||
raise InterpreterError("Object is not iterable")
|
raise InterpreterError("Object is not iterable")
|
||||||
|
|
||||||
|
|
||||||
|
def fix_final_answer_code(code: str) -> str:
    """Rename variables called ``final_answer`` so they stop shadowing the tool.

    Sometimes an LLM assigns a value to a variable named ``final_answer``,
    which would break the later ``final_answer()`` tool call. When the blob
    both contains ``final_answer(`` and assigns to ``final_answer``, every use
    of the bare name as a variable is renamed to ``final_answer_variable``.
    Calls (``final_answer(...)``) and attribute accesses
    (``obj.final_answer``) are preserved.

    Args:
        code: Source code of the action to sanitize.

    Returns:
        The code with the variable renamed, or the code unchanged when it does
        not contain ``final_answer(`` or never assigns to ``final_answer``.

    Note:
        The rewrite is purely textual (regular expressions), so matching text
        inside string literals or comments is renamed as well.
    """
    # A direct assignment is `final_answer =` that is NOT followed by a second
    # `=`: `final_answer == x` is a comparison and must not trigger the rename,
    # otherwise the code would reference an undefined `final_answer_variable`.
    # The negative lookbehinds exclude object attributes and longer identifiers.
    assignment_pattern = r"(?<!\.)(?<!\w)\bfinal_answer\s*=(?!=)"
    if "final_answer(" not in code or not re.search(assignment_pattern, code):
        # If the final_answer tool is not called in this blob, doing the
        # replacement is hazardous because it could corrupt the model's memory
        # for the next steps. Leave the code untouched and let the subsequent
        # assignment error happen.
        return code

    # Rename every variable usage (assignment targets included) in one pass.
    # The negative lookahead (?!\s*\() keeps function calls intact; the
    # lookbehinds keep attributes and other identifiers intact.
    variable_pattern = r"(?<!\.)(?<!\w)\bfinal_answer\b(?!\s*\()"
    return re.sub(variable_pattern, "final_answer_variable", code)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_unaryop(expression, state, static_tools, custom_tools):
|
def evaluate_unaryop(expression, state, static_tools, custom_tools):
|
||||||
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
|
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
|
||||||
if isinstance(expression.op, ast.USub):
|
if isinstance(expression.op, ast.USub):
|
||||||
|
@ -224,6 +253,10 @@ def create_function(func_def, state, static_tools, custom_tools):
|
||||||
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
|
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
|
||||||
except ReturnException as e:
|
except ReturnException as e:
|
||||||
result = e.value
|
result = e.value
|
||||||
|
|
||||||
|
if func_def.name == "__init__":
|
||||||
|
return None
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return new_func
|
return new_func
|
||||||
|
@ -484,41 +517,31 @@ def evaluate_call(call, state, static_tools, custom_tools):
|
||||||
for keyword in call.keywords
|
for keyword in call.keywords
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if func_name == "super":
|
||||||
isinstance(func, type) and len(func.__module__.split(".")) > 1
|
if not args:
|
||||||
): # Check for user-defined classes
|
if "__class__" in state and "self" in state:
|
||||||
# Instantiate the class using its constructor
|
return super(state["__class__"], state["self"])
|
||||||
obj = func.__new__(func) # Create a new instance of the class
|
|
||||||
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
|
|
||||||
obj.__init__(*args, **kwargs) # Call the __init__ method correctly
|
|
||||||
return obj
|
|
||||||
else:
|
|
||||||
if func_name == "super":
|
|
||||||
if not args:
|
|
||||||
if "__class__" in state and "self" in state:
|
|
||||||
return super(state["__class__"], state["self"])
|
|
||||||
else:
|
|
||||||
raise InterpreterError("super() needs at least one argument")
|
|
||||||
cls = args[0]
|
|
||||||
if not isinstance(cls, type):
|
|
||||||
raise InterpreterError("super() argument 1 must be type")
|
|
||||||
if len(args) == 1:
|
|
||||||
return super(cls)
|
|
||||||
elif len(args) == 2:
|
|
||||||
instance = args[1]
|
|
||||||
return super(cls, instance)
|
|
||||||
else:
|
else:
|
||||||
raise InterpreterError("super() takes at most 2 arguments")
|
raise InterpreterError("super() needs at least one argument")
|
||||||
|
cls = args[0]
|
||||||
|
if not isinstance(cls, type):
|
||||||
|
raise InterpreterError("super() argument 1 must be type")
|
||||||
|
if len(args) == 1:
|
||||||
|
return super(cls)
|
||||||
|
elif len(args) == 2:
|
||||||
|
instance = args[1]
|
||||||
|
return super(cls, instance)
|
||||||
else:
|
else:
|
||||||
if func_name == "print":
|
raise InterpreterError("super() takes at most 2 arguments")
|
||||||
output = " ".join(map(str, args))
|
else:
|
||||||
global PRINT_OUTPUTS
|
if func_name == "print":
|
||||||
PRINT_OUTPUTS += output + "\n"
|
output = " ".join(map(str, args))
|
||||||
# cap the number of lines
|
global PRINT_OUTPUTS
|
||||||
return None
|
PRINT_OUTPUTS += output + "\n"
|
||||||
else: # Assume it's a callable object
|
# cap the number of lines
|
||||||
output = func(*args, **kwargs)
|
return None
|
||||||
return output
|
else: # Assume it's a callable object
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
||||||
|
@ -990,6 +1013,11 @@ def truncate_print_outputs(
|
||||||
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
|
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
|
||||||
|
|
||||||
|
|
||||||
|
class FinalAnswerException(Exception):
    """Exception that carries a final answer value.

    The payload is available both as ``value`` and as the first element of
    ``args``.

    Args:
        value: The object to carry as the final answer.
    """

    def __init__(self, value):
        # Forward the payload to Exception so that `args`, `str()`/`repr()` and
        # copy/pickle reconstruction (which re-calls the class with `args`) work.
        super().__init__(value)
        # Answer payload, read by whoever catches the exception.
        self.value = value
|
||||||
|
|
||||||
|
|
||||||
def evaluate_python_code(
|
def evaluate_python_code(
|
||||||
code: str,
|
code: str,
|
||||||
static_tools: Optional[Dict[str, Callable]] = None,
|
static_tools: Optional[Dict[str, Callable]] = None,
|
||||||
|
@ -1029,6 +1057,12 @@ def evaluate_python_code(
|
||||||
PRINT_OUTPUTS = ""
|
PRINT_OUTPUTS = ""
|
||||||
global OPERATIONS_COUNT
|
global OPERATIONS_COUNT
|
||||||
OPERATIONS_COUNT = 0
|
OPERATIONS_COUNT = 0
|
||||||
|
|
||||||
|
def final_answer(value):
|
||||||
|
raise FinalAnswerException(value)
|
||||||
|
|
||||||
|
static_tools["final_answer"] = final_answer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for node in expression.body:
|
for node in expression.body:
|
||||||
result = evaluate_ast(
|
result = evaluate_ast(
|
||||||
|
@ -1037,7 +1071,14 @@ def evaluate_python_code(
|
||||||
state["print_outputs"] = truncate_content(
|
state["print_outputs"] = truncate_content(
|
||||||
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
||||||
)
|
)
|
||||||
return result
|
is_final_answer = False
|
||||||
|
return result, is_final_answer
|
||||||
|
except FinalAnswerException as e:
|
||||||
|
state["print_outputs"] = truncate_content(
|
||||||
|
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
||||||
|
)
|
||||||
|
is_final_answer = True
|
||||||
|
return e.value, is_final_answer
|
||||||
except InterpreterError as e:
|
except InterpreterError as e:
|
||||||
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
||||||
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||||
|
@ -1059,9 +1100,11 @@ class LocalPythonInterpreter:
|
||||||
}
|
}
|
||||||
# TODO: assert self.authorized imports are all installed locally
|
# TODO: assert self.authorized imports are all installed locally
|
||||||
|
|
||||||
def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]:
|
def __call__(
|
||||||
|
self, code_action: str, additional_variables: Dict
|
||||||
|
) -> Tuple[Any, str, bool]:
|
||||||
self.state.update(additional_variables)
|
self.state.update(additional_variables)
|
||||||
output = evaluate_python_code(
|
output, is_final_answer = evaluate_python_code(
|
||||||
code_action,
|
code_action,
|
||||||
static_tools=self.static_tools,
|
static_tools=self.static_tools,
|
||||||
custom_tools=self.custom_tools,
|
custom_tools=self.custom_tools,
|
||||||
|
@ -1069,7 +1112,7 @@ class LocalPythonInterpreter:
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
logs = self.state["print_outputs"]
|
logs = self.state["print_outputs"]
|
||||||
return output, logs
|
return output, logs, is_final_answer
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["evaluate_python_code", "LocalPythonInterpreter"]
|
__all__ = ["evaluate_python_code", "LocalPythonInterpreter"]
|
||||||
|
|
|
@ -373,7 +373,7 @@ Here are the rules you should always follow to solve your task:
|
||||||
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
||||||
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
||||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||||
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
7. Never create any notional variables in our code, as having these in your logs will derail you from the true variables.
|
||||||
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
|
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
|
||||||
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||||
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||||
|
|
|
@ -106,26 +106,35 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||||
|
|
||||||
|
|
||||||
def parse_code_blob(code_blob: str) -> str:
|
def parse_code_blob(code_blob: str) -> str:
|
||||||
try:
|
"""Parses the LLM's output to get any code blob inside. Will return the code directly if it's code."""
|
||||||
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
||||||
match = re.search(pattern, code_blob, re.DOTALL)
|
match = re.search(pattern, code_blob, re.DOTALL)
|
||||||
if match is None:
|
if match is None:
|
||||||
raise ValueError(
|
try: # Maybe the LLM outputted a code blob directly
|
||||||
f"No match ground for regex pattern {pattern} in {code_blob=}."
|
ast.parse(code_blob)
|
||||||
)
|
return code_blob
|
||||||
return match.group(1).strip()
|
except SyntaxError:
|
||||||
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
if "final" in code_blob and "answer" in code_blob:
|
||||||
|
raise ValueError(
|
||||||
|
f"""
|
||||||
|
The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. It seems like you're trying to return the final answer, you can do it as follows:
|
||||||
|
Code:
|
||||||
|
```py
|
||||||
|
final_answer("YOUR FINAL ANSWER HERE")
|
||||||
|
```<end_action>""".strip()
|
||||||
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"""
|
f"""
|
||||||
The code blob you used is invalid: due to the following error: {e}
|
The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. Make sure to include code with the correct pattern, for instance:
|
||||||
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
|
|
||||||
Thoughts: Your thoughts
|
Thoughts: Your thoughts
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
# Your python code here
|
# Your python code here
|
||||||
```<end_action>"""
|
```<end_action>""".strip()
|
||||||
)
|
)
|
||||||
|
return match.group(1).strip()
|
||||||
|
|
||||||
|
|
||||||
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
|
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
|
||||||
|
|
|
@ -444,3 +444,18 @@ final_answer("Final report.")
|
||||||
|
|
||||||
report = manager_toolcalling_agent.run("Fake question.")
|
report = manager_toolcalling_agent.run("Fake question.")
|
||||||
assert report == "Final report."
|
assert report == "Final report."
|
||||||
|
|
||||||
|
def test_code_nontrivial_final_answer_works(self):
|
||||||
|
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
|
||||||
|
return """Code:
|
||||||
|
```py
|
||||||
|
def nested_answer():
|
||||||
|
final_answer("Correct!")
|
||||||
|
|
||||||
|
nested_answer()
|
||||||
|
```<end_code>"""
|
||||||
|
|
||||||
|
agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
|
||||||
|
|
||||||
|
output = agent.run("Count to 3")
|
||||||
|
assert output == "Correct!"
|
||||||
|
|
|
@ -23,6 +23,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||||
from smolagents.local_python_executor import (
|
from smolagents.local_python_executor import (
|
||||||
InterpreterError,
|
InterpreterError,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
|
fix_final_answer_code,
|
||||||
)
|
)
|
||||||
from smolagents.types import AGENT_TYPE_MAPPING
|
from smolagents.types import AGENT_TYPE_MAPPING
|
||||||
|
|
||||||
|
@ -79,19 +80,19 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_assign(self):
|
def test_evaluate_assign(self):
|
||||||
code = "x = 3"
|
code = "x = 3"
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
assert result == 3
|
assert result == 3
|
||||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||||
|
|
||||||
code = "x = y"
|
code = "x = y"
|
||||||
state = {"y": 5}
|
state = {"y": 5}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
||||||
|
|
||||||
code = "a=1;b=None"
|
code = "a=1;b=None"
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
@ -107,7 +108,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_call(self):
|
def test_evaluate_call(self):
|
||||||
code = "y = add_two(x)"
|
code = "y = add_two(x)"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||||
|
|
||||||
|
@ -119,14 +120,14 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_constant(self):
|
def test_evaluate_constant(self):
|
||||||
code = "x = 3"
|
code = "x = 3"
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
assert result == 3
|
assert result == 3
|
||||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||||
|
|
||||||
def test_evaluate_dict(self):
|
def test_evaluate_dict(self):
|
||||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(
|
||||||
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
||||||
|
@ -135,7 +136,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_expression(self):
|
def test_evaluate_expression(self):
|
||||||
code = "x = 3\ny = 5"
|
code = "x = 3\ny = 5"
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||||
|
@ -143,7 +144,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_f_string(self):
|
def test_evaluate_f_string(self):
|
||||||
code = "text = f'This is x: {x}.'"
|
code = "text = f'This is x: {x}.'"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result == "This is x: 3."
|
assert result == "This is x: 3."
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(
|
||||||
|
@ -153,13 +154,13 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_if(self):
|
def test_evaluate_if(self):
|
||||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result == 2
|
assert result == 2
|
||||||
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
|
||||||
|
|
||||||
state = {"x": 8}
|
state = {"x": 8}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
|
||||||
|
@ -167,27 +168,27 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_list(self):
|
def test_evaluate_list(self):
|
||||||
code = "test_list = [x, add_two(x)]"
|
code = "test_list = [x, add_two(x)]"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
self.assertListEqual(result, [3, 5])
|
self.assertListEqual(result, [3, 5])
|
||||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||||
|
|
||||||
def test_evaluate_name(self):
|
def test_evaluate_name(self):
|
||||||
code = "y = x"
|
code = "y = x"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
assert result == 3
|
assert result == 3
|
||||||
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
|
||||||
|
|
||||||
def test_evaluate_subscript(self):
|
def test_evaluate_subscript(self):
|
||||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||||
|
|
||||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(
|
||||||
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
||||||
|
@ -215,14 +216,14 @@ for result in search_results:
|
||||||
def test_evaluate_for(self):
|
def test_evaluate_for(self):
|
||||||
code = "x = 0\nfor i in range(3):\n x = i"
|
code = "x = 0\nfor i in range(3):\n x = i"
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"range": range}, state=state)
|
result, _ = evaluate_python_code(code, {"range": range}, state=state)
|
||||||
assert result == 2
|
assert result == 2
|
||||||
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
|
||||||
|
|
||||||
def test_evaluate_binop(self):
|
def test_evaluate_binop(self):
|
||||||
code = "y + x"
|
code = "y + x"
|
||||||
state = {"x": 3, "y": 6}
|
state = {"x": 3, "y": 6}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
assert result == 9
|
assert result == 9
|
||||||
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
|
||||||
|
|
||||||
|
@ -234,27 +235,27 @@ def recur_fibo(n):
|
||||||
else:
|
else:
|
||||||
return(recur_fibo(n-1) + recur_fibo(n-2))
|
return(recur_fibo(n-1) + recur_fibo(n-2))
|
||||||
recur_fibo(6)"""
|
recur_fibo(6)"""
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result == 8
|
assert result == 8
|
||||||
|
|
||||||
def test_evaluate_string_methods(self):
|
def test_evaluate_string_methods(self):
|
||||||
code = "'hello'.replace('h', 'o').split('e')"
|
code = "'hello'.replace('h', 'o').split('e')"
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result == ["o", "llo"]
|
assert result == ["o", "llo"]
|
||||||
|
|
||||||
def test_evaluate_slicing(self):
|
def test_evaluate_slicing(self):
|
||||||
code = "'hello'[1:3][::-1]"
|
code = "'hello'[1:3][::-1]"
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result == "le"
|
assert result == "le"
|
||||||
|
|
||||||
def test_access_attributes(self):
|
def test_access_attributes(self):
|
||||||
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
|
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result is int
|
assert result is int
|
||||||
|
|
||||||
def test_list_comprehension(self):
|
def test_list_comprehension(self):
|
||||||
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
|
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result == "t-h-e-s-e-a-g-u-l-l"
|
assert result == "t-h-e-s-e-a-g-u-l-l"
|
||||||
|
|
||||||
def test_string_indexing(self):
|
def test_string_indexing(self):
|
||||||
|
@ -267,12 +268,12 @@ for block in text_block:
|
||||||
for col in range(len(text_block[0])):
|
for col in range(len(text_block[0])):
|
||||||
sentence += block[col]
|
sentence += block[col]
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, {"len": len, "range": range}, state={})
|
result, _ = evaluate_python_code(code, {"len": len, "range": range}, state={})
|
||||||
assert result == "THESEAGULL"
|
assert result == "THESEAGULL"
|
||||||
|
|
||||||
def test_tuples(self):
|
def test_tuples(self):
|
||||||
code = "x = (1, 2, 3)\nx[1]"
|
code = "x = (1, 2, 3)\nx[1]"
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result == 2
|
assert result == 2
|
||||||
|
|
||||||
code = """
|
code = """
|
||||||
|
@ -325,35 +326,35 @@ print(check_digits)
|
||||||
|
|
||||||
def test_listcomp(self):
|
def test_listcomp(self):
|
||||||
code = "x = [i for i in range(3)]"
|
code = "x = [i for i in range(3)]"
|
||||||
result = evaluate_python_code(code, {"range": range}, state={})
|
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
||||||
assert result == [0, 1, 2]
|
assert result == [0, 1, 2]
|
||||||
|
|
||||||
def test_break_continue(self):
|
def test_break_continue(self):
|
||||||
code = "for i in range(10):\n if i == 5:\n break\ni"
|
code = "for i in range(10):\n if i == 5:\n break\ni"
|
||||||
result = evaluate_python_code(code, {"range": range}, state={})
|
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
||||||
assert result == 5
|
assert result == 5
|
||||||
|
|
||||||
code = "for i in range(10):\n if i == 5:\n continue\ni"
|
code = "for i in range(10):\n if i == 5:\n continue\ni"
|
||||||
result = evaluate_python_code(code, {"range": range}, state={})
|
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
||||||
assert result == 9
|
assert result == 9
|
||||||
|
|
||||||
def test_call_int(self):
|
def test_call_int(self):
|
||||||
code = "import math\nstr(math.ceil(149))"
|
code = "import math\nstr(math.ceil(149))"
|
||||||
result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
|
result, _ = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
|
||||||
assert result == "149"
|
assert result == "149"
|
||||||
|
|
||||||
def test_lambda(self):
|
def test_lambda(self):
|
||||||
code = "f = lambda x: x + 2\nf(3)"
|
code = "f = lambda x: x + 2\nf(3)"
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result == 5
|
assert result == 5
|
||||||
|
|
||||||
def test_dictcomp(self):
|
def test_dictcomp(self):
|
||||||
code = "x = {i: i**2 for i in range(3)}"
|
code = "x = {i: i**2 for i in range(3)}"
|
||||||
result = evaluate_python_code(code, {"range": range}, state={})
|
result, _ = evaluate_python_code(code, {"range": range}, state={})
|
||||||
assert result == {0: 0, 1: 1, 2: 4}
|
assert result == {0: 0, 1: 1, 2: 4}
|
||||||
|
|
||||||
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
||||||
)
|
)
|
||||||
assert result == {102: "b"}
|
assert result == {102: "b"}
|
||||||
|
@ -362,17 +363,17 @@ print(check_digits)
|
||||||
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
|
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
|
||||||
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
|
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, {}, state={})
|
result, _ = evaluate_python_code(code, {}, state={})
|
||||||
assert result == {"A": ("a", "b"), "B": ("a", "b")}
|
assert result == {"A": ("a", "b"), "B": ("a", "b")}
|
||||||
|
|
||||||
def test_tuple_assignment(self):
|
def test_tuple_assignment(self):
|
||||||
code = "a, b = 0, 1\nb"
|
code = "a, b = 0, 1\nb"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == 1
|
assert result == 1
|
||||||
|
|
||||||
def test_while(self):
|
def test_while(self):
|
||||||
code = "i = 0\nwhile i < 3:\n i += 1\ni"
|
code = "i = 0\nwhile i < 3:\n i += 1\ni"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == 3
|
assert result == 3
|
||||||
|
|
||||||
# test infinite loop
|
# test infinite loop
|
||||||
|
@ -393,7 +394,7 @@ while i < n and house_positions[i] <= loc:
|
||||||
|
|
||||||
def test_generator(self):
|
def test_generator(self):
|
||||||
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == [1, 4, 9, 16, 25]
|
assert result == [1, 4, 9, 16, 25]
|
||||||
|
|
||||||
def test_boolops(self):
|
def test_boolops(self):
|
||||||
|
@ -403,7 +404,7 @@ else:
|
||||||
best_city = "Manhattan"
|
best_city = "Manhattan"
|
||||||
best_city
|
best_city
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
||||||
)
|
)
|
||||||
assert result == "Brooklyn"
|
assert result == "Brooklyn"
|
||||||
|
@ -416,7 +417,7 @@ else:
|
||||||
best_city = "Manhattan"
|
best_city = "Manhattan"
|
||||||
best_city
|
best_city
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
||||||
)
|
)
|
||||||
assert result == "Sacramento"
|
assert result == "Sacramento"
|
||||||
|
@ -431,51 +432,51 @@ if char.isalpha():
|
||||||
|
|
||||||
def test_imports(self):
|
def test_imports(self):
|
||||||
code = "import math\nmath.sqrt(4)"
|
code = "import math\nmath.sqrt(4)"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == 2.0
|
assert result == 2.0
|
||||||
|
|
||||||
code = (
|
code = (
|
||||||
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
||||||
)
|
)
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "lose"
|
assert result == "lose"
|
||||||
|
|
||||||
code = "import time, re\ntime.sleep(0.1)"
|
code = "import time, re\ntime.sleep(0.1)"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
|
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == 1
|
assert result == 1
|
||||||
|
|
||||||
code = "import itertools\nlist(itertools.islice(range(10), 3))"
|
code = "import itertools\nlist(itertools.islice(range(10), 3))"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == [0, 1, 2]
|
assert result == [0, 1, 2]
|
||||||
|
|
||||||
code = "import re\nre.search('a', 'abc').group()"
|
code = "import re\nre.search('a', 'abc').group()"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "a"
|
assert result == "a"
|
||||||
|
|
||||||
code = "import stat\nstat.S_ISREG(0o100644)"
|
code = "import stat\nstat.S_ISREG(0o100644)"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result
|
assert result
|
||||||
|
|
||||||
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
|
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == 2.8
|
assert result == 2.8
|
||||||
|
|
||||||
code = "import unicodedata\nunicodedata.name('A')"
|
code = "import unicodedata\nunicodedata.name('A')"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "LATIN CAPITAL LETTER A"
|
assert result == "LATIN CAPITAL LETTER A"
|
||||||
|
|
||||||
# Test submodules are handled properly, thus not raising error
|
# Test submodules are handled properly, thus not raising error
|
||||||
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
||||||
)
|
)
|
||||||
|
|
||||||
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -491,25 +492,25 @@ if char.isalpha():
|
||||||
|
|
||||||
def test_multiple_comparators(self):
|
def test_multiple_comparators(self):
|
||||||
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert not result
|
assert not result
|
||||||
|
|
||||||
code = "0 <= 1 < 4 and 0 <= -5 < 4"
|
code = "0 <= 1 < 4 and 0 <= -5 < 4"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert not result
|
assert not result
|
||||||
|
|
||||||
code = "0 <= 4 < 4 and 0 <= 3 < 4"
|
code = "0 <= 4 < 4 and 0 <= 3 < 4"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert not result
|
assert not result
|
||||||
|
|
||||||
code = "0 <= 3 < 4 and 0 <= 3 < 4"
|
code = "0 <= 3 < 4 and 0 <= 3 < 4"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result
|
assert result
|
||||||
|
|
||||||
def test_print_output(self):
|
def test_print_output(self):
|
||||||
code = "print('Hello world!')\nprint('Ok no one cares')"
|
code = "print('Hello world!')\nprint('Ok no one cares')"
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||||
assert result is None
|
assert result is None
|
||||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||||
|
|
||||||
|
@ -525,7 +526,7 @@ function()"""
|
||||||
|
|
||||||
def test_tuple_target_in_iterator(self):
|
def test_tuple_target_in_iterator(self):
|
||||||
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "Samuel"
|
assert result == "Samuel"
|
||||||
|
|
||||||
def test_classes(self):
|
def test_classes(self):
|
||||||
|
@ -618,7 +619,7 @@ def var_args_method(self, *args, **kwargs):
|
||||||
var_args_method(1, 2, 3, x=4, y=5)
|
var_args_method(1, 2, 3, x=4, y=5)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"sum": sum}, state=state)
|
result, _ = evaluate_python_code(code, {"sum": sum}, state=state)
|
||||||
assert result == 15
|
assert result == 15
|
||||||
|
|
||||||
def test_exceptions(self):
|
def test_exceptions(self):
|
||||||
|
@ -648,7 +649,7 @@ except ValueError as e:
|
||||||
def test_types_as_objects(self):
|
def test_types_as_objects(self):
|
||||||
code = "type_a = float(2); type_b = str; type_c = int"
|
code = "type_a = float(2); type_b = str; type_c = int"
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(
|
result, is_final_answer = evaluate_python_code(
|
||||||
code, {"float": float, "str": str, "int": int}, state=state
|
code, {"float": float, "str": str, "int": int}, state=state
|
||||||
)
|
)
|
||||||
assert result is int
|
assert result is int
|
||||||
|
@ -659,7 +660,7 @@ food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
|
||||||
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
|
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result, is_final_answer = evaluate_python_code(code, {}, state=state)
|
||||||
assert result == ["orange", "pear"]
|
assert result == ["orange", "pear"]
|
||||||
|
|
||||||
def test_nonsimple_augassign(self):
|
def test_nonsimple_augassign(self):
|
||||||
|
@ -742,8 +743,9 @@ def f(a, b=333, n=1000):
|
||||||
return b + n
|
return b + n
|
||||||
n = f(1, n=667)
|
n = f(1, n=667)
|
||||||
"""
|
"""
|
||||||
res = evaluate_python_code(code, {}, {})
|
res, is_final_answer = evaluate_python_code(code, {}, {})
|
||||||
assert res == 1000
|
assert res == 1000
|
||||||
|
assert not is_final_answer
|
||||||
|
|
||||||
def test_set(self):
|
def test_set(self):
|
||||||
code = """
|
code = """
|
||||||
|
@ -767,8 +769,11 @@ while True:
|
||||||
break
|
break
|
||||||
|
|
||||||
i"""
|
i"""
|
||||||
result = evaluate_python_code(code, {"print": print, "round": round}, state={})
|
result, is_final_answer = evaluate_python_code(
|
||||||
|
code, {"print": print, "round": round}, state={}
|
||||||
|
)
|
||||||
assert result == 3
|
assert result == 3
|
||||||
|
assert not is_final_answer
|
||||||
|
|
||||||
def test_return(self):
|
def test_return(self):
|
||||||
# test early returns
|
# test early returns
|
||||||
|
@ -781,7 +786,7 @@ def add_one(n, shift):
|
||||||
add_one(1, 1)
|
add_one(1, 1)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(
|
result, is_final_answer = evaluate_python_code(
|
||||||
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
||||||
)
|
)
|
||||||
assert result == 2
|
assert result == 2
|
||||||
|
@ -794,7 +799,7 @@ def returns_none(a):
|
||||||
returns_none(1)
|
returns_none(1)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(
|
result, is_final_answer = evaluate_python_code(
|
||||||
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
||||||
)
|
)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
@ -812,7 +817,7 @@ out = [i for sublist in all_res for i in sublist]
|
||||||
out[:10]
|
out[:10]
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(
|
result, is_final_answer = evaluate_python_code(
|
||||||
code, {"print": print, "range": range}, state=state
|
code, {"print": print, "range": range}, state=state
|
||||||
)
|
)
|
||||||
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
||||||
|
@ -829,7 +834,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
|
||||||
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, {}, state=state, authorized_imports=["pandas"]
|
code, {}, state=state, authorized_imports=["pandas"]
|
||||||
)
|
)
|
||||||
assert np.array_equal(result, [-1, 5])
|
assert np.array_equal(result, [-1, 5])
|
||||||
|
@ -842,7 +847,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
|
||||||
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
||||||
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
||||||
)
|
)
|
||||||
assert np.array_equal(result.values[0], [104, 1])
|
assert np.array_equal(result.values[0], [104, 1])
|
||||||
|
@ -855,7 +860,9 @@ data = pd.DataFrame.from_dict([
|
||||||
])
|
])
|
||||||
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
|
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
|
result, _ = evaluate_python_code(
|
||||||
|
code, {}, state={}, authorized_imports=["pandas"]
|
||||||
|
)
|
||||||
assert result.values[1] == 0.5
|
assert result.values[1] == 0.5
|
||||||
|
|
||||||
def test_starred(self):
|
def test_starred(self):
|
||||||
|
@ -877,7 +884,7 @@ coords_barcelona = (41.3869, 2.1660)
|
||||||
|
|
||||||
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(
|
result, _ = evaluate_python_code(
|
||||||
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
|
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
|
||||||
)
|
)
|
||||||
assert round(result, 1) == 622395.4
|
assert round(result, 1) == 622395.4
|
||||||
|
@ -894,5 +901,42 @@ for worker, (start, end) in shifts.items():
|
||||||
shift_intervals[worker] = end
|
shift_intervals[worker] = end
|
||||||
shift_intervals
|
shift_intervals
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, {"print": print, "map": map}, state={})
|
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
|
||||||
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
|
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
|
||||||
|
|
||||||
|
def test_fix_final_answer_code(self):
|
||||||
|
test_cases = [
|
||||||
|
(
|
||||||
|
"final_answer = 3.21\nfinal_answer(final_answer)",
|
||||||
|
"final_answer_variable = 3.21\nfinal_answer(final_answer_variable)",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"x = final_answer(5)\nfinal_answer = x + 1\nfinal_answer(final_answer)",
|
||||||
|
"x = final_answer(5)\nfinal_answer_variable = x + 1\nfinal_answer(final_answer_variable)",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"def func():\n final_answer = 42\n return final_answer(final_answer)",
|
||||||
|
"def func():\n final_answer_variable = 42\n return final_answer(final_answer_variable)",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"final_answer(5) # Should not change function calls",
|
||||||
|
"final_answer(5) # Should not change function calls",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"obj.final_answer = 5 # Should not change object attributes",
|
||||||
|
"obj.final_answer = 5 # Should not change object attributes",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"final_answer=3.21;final_answer(final_answer)",
|
||||||
|
"final_answer_variable=3.21;final_answer(final_answer_variable)",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, (input_code, expected) in enumerate(test_cases, 1):
|
||||||
|
result = fix_final_answer_code(input_code)
|
||||||
|
assert result == expected, f"""
|
||||||
|
Test case {i} failed:
|
||||||
|
Input: {input_code}
|
||||||
|
Expected: {expected}
|
||||||
|
Got: {result}
|
||||||
|
"""
|
||||||
|
|
|
@ -1,86 +1,39 @@
|
||||||
import os
|
# coding=utf-8
|
||||||
import shutil
|
# Copyright 2024 HuggingFace Inc.
|
||||||
import tempfile
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
import pytest
|
||||||
|
|
||||||
|
from smolagents.utils import parse_code_blob
|
||||||
|
|
||||||
|
|
||||||
def str_to_bool(value) -> int:
|
class AgentTextTests(unittest.TestCase):
|
||||||
"""
|
def test_parse_code_blob(self):
|
||||||
Converts a string representation of truth to `True` (1) or `False` (0).
|
with pytest.raises(ValueError):
|
||||||
|
parse_code_blob("Wrong blob!")
|
||||||
|
|
||||||
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`;
|
# Parsing markdown with code blobs should work
|
||||||
"""
|
output = parse_code_blob("""
|
||||||
value = value.lower()
|
Here is how to solve the problem:
|
||||||
if value in ("y", "yes", "t", "true", "on", "1"):
|
Code:
|
||||||
return 1
|
```py
|
||||||
elif value in ("n", "no", "f", "false", "off", "0"):
|
import numpy as np
|
||||||
return 0
|
```<end_code>
|
||||||
else:
|
""")
|
||||||
raise ValueError(f"invalid truth value {value}")
|
assert output == "import numpy as np"
|
||||||
|
|
||||||
|
# Parsing code blobs should work
|
||||||
def get_int_from_env(env_keys, default):
|
code_blob = "import numpy as np"
|
||||||
"""Returns the first positive env value found in the `env_keys` list or the default."""
|
output = parse_code_blob(code_blob)
|
||||||
for e in env_keys:
|
assert output == code_blob
|
||||||
val = int(os.environ.get(e, -1))
|
|
||||||
if val >= 0:
|
|
||||||
return val
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def parse_flag_from_env(key, default=False):
|
|
||||||
"""Returns truthy value for `key` from the env if available else the default."""
|
|
||||||
value = os.environ.get(key, str(default))
|
|
||||||
return (
|
|
||||||
str_to_bool(value) == 1
|
|
||||||
) # As its name indicates `str_to_bool` actually returns an int...
|
|
||||||
|
|
||||||
|
|
||||||
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
|
||||||
|
|
||||||
|
|
||||||
def skip(test_case):
|
|
||||||
"Decorator that skips a test unconditionally"
|
|
||||||
return unittest.skip("Test was skipped")(test_case)
|
|
||||||
|
|
||||||
|
|
||||||
def slow(test_case):
|
|
||||||
"""
|
|
||||||
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
|
|
||||||
truthy value to run them.
|
|
||||||
"""
|
|
||||||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
|
||||||
|
|
||||||
|
|
||||||
class TempDirTestCase(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
|
|
||||||
data at the start of a test, and then destroys it at the end of the TestCase.
|
|
||||||
|
|
||||||
Useful for when a class or API requires a single constant folder throughout its use, such as Weights and Biases
|
|
||||||
|
|
||||||
The temporary directory location will be stored in `self.tmpdir`
|
|
||||||
"""
|
|
||||||
|
|
||||||
clear_on_setup = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
"Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`"
|
|
||||||
cls.tmpdir = Path(tempfile.mkdtemp())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
"Remove `cls.tmpdir` after test suite has finished"
|
|
||||||
if os.path.exists(cls.tmpdir):
|
|
||||||
shutil.rmtree(cls.tmpdir)
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
|
|
||||||
if self.clear_on_setup:
|
|
||||||
for path in self.tmpdir.glob("**/*"):
|
|
||||||
if path.is_file():
|
|
||||||
path.unlink()
|
|
||||||
elif path.is_dir():
|
|
||||||
shutil.rmtree(path)
|
|
||||||
|
|
Loading…
Reference in New Issue