Try first dunder method in evaluate_augassign (#285)

* Test evaluate_augassign
This commit is contained in:
Albert Villanova del Moral 2025-01-21 10:41:25 +01:00 committed by GitHub
parent a2b37caff1
commit 0e0d73b096
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 95 additions and 14 deletions

View File

@ -368,30 +368,69 @@ def evaluate_augassign(
if isinstance(current_value, list):
if not isinstance(value_to_add, list):
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
try:
updated_value = current_value.__iadd__(value_to_add)
except AttributeError:
updated_value = current_value + value_to_add
else:
try:
updated_value = current_value.__iadd__(value_to_add)
except AttributeError:
updated_value = current_value + value_to_add
elif isinstance(expression.op, ast.Sub):
try:
updated_value = current_value.__isub__(value_to_add)
except AttributeError:
updated_value = current_value - value_to_add
elif isinstance(expression.op, ast.Mult):
try:
updated_value = current_value.__imul__(value_to_add)
except AttributeError:
updated_value = current_value * value_to_add
elif isinstance(expression.op, ast.Div):
try:
updated_value = current_value.__itruediv__(value_to_add)
except AttributeError:
updated_value = current_value / value_to_add
elif isinstance(expression.op, ast.Mod):
try:
updated_value = current_value.__imod__(value_to_add)
except AttributeError:
updated_value = current_value % value_to_add
elif isinstance(expression.op, ast.Pow):
try:
updated_value = current_value.__ipow__(value_to_add)
except AttributeError:
updated_value = current_value**value_to_add
elif isinstance(expression.op, ast.FloorDiv):
try:
updated_value = current_value.__ifloordiv__(value_to_add)
except AttributeError:
updated_value = current_value // value_to_add
elif isinstance(expression.op, ast.BitAnd):
try:
updated_value = current_value.__iand__(value_to_add)
except AttributeError:
updated_value = current_value & value_to_add
elif isinstance(expression.op, ast.BitOr):
try:
updated_value = current_value.__ior__(value_to_add)
except AttributeError:
updated_value = current_value | value_to_add
elif isinstance(expression.op, ast.BitXor):
try:
updated_value = current_value.__ixor__(value_to_add)
except AttributeError:
updated_value = current_value ^ value_to_add
elif isinstance(expression.op, ast.LShift):
try:
updated_value = current_value.__ilshift__(value_to_add)
except AttributeError:
updated_value = current_value << value_to_add
elif isinstance(expression.op, ast.RShift):
try:
updated_value = current_value.__irshift__(value_to_add)
except AttributeError:
updated_value = current_value >> value_to_add
else:
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
@ -1069,7 +1108,7 @@ def evaluate_ast(
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
set of functions.
This function will recurse trough the nodes of the tree provided.
This function will recurse through the nodes of the tree provided.
Args:
expression (`ast.AST`):

View File

@ -14,6 +14,7 @@
# limitations under the License.
import unittest
from textwrap import dedent
import numpy as np
import pytest
@ -948,3 +949,44 @@ texec(tcompile("1 + 1", "no filename", "exec"))
evaluate_python_code(
dangerous_code, static_tools={"tcompile": compile, "teval": eval, "texec": exec} | BASE_PYTHON_TOOLS
)
class TestPythonInterpreter:
@pytest.mark.parametrize(
"code, expected_result",
[
(
dedent("""\
x = 1
x += 2
"""),
3,
),
(
dedent("""\
x = "a"
x += "b"
"""),
"ab",
),
(
dedent("""\
class Custom:
def __init__(self, value):
self.value = value
def __iadd__(self, other):
self.value += other * 10
return self
x = Custom(1)
x += 2
x.value
"""),
21,
),
],
)
def test_evaluate_augassign(self, code, expected_result):
state = {}
result, _ = evaluate_python_code(code, {}, state=state)
assert result == expected_result