Refactor evaluate_augassign and test all operators (#313)

* Test all augassign operators

* Refactor evaluate_augassign
This commit is contained in:
Albert Villanova del Moral 2025-01-22 14:59:09 +01:00 committed by GitHub
parent 6196958deb
commit 0ead477263
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 129 additions and 93 deletions

View File

@ -369,84 +369,45 @@ def evaluate_augassign(
if isinstance(current_value, list): if isinstance(current_value, list):
if not isinstance(value_to_add, list): if not isinstance(value_to_add, list):
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.") raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
try: current_value += value_to_add
updated_value = current_value.__iadd__(value_to_add)
except AttributeError:
updated_value = current_value + value_to_add
else: else:
try: current_value += value_to_add
updated_value = current_value.__iadd__(value_to_add)
except AttributeError:
updated_value = current_value + value_to_add
elif isinstance(expression.op, ast.Sub): elif isinstance(expression.op, ast.Sub):
try: current_value -= value_to_add
updated_value = current_value.__isub__(value_to_add)
except AttributeError:
updated_value = current_value - value_to_add
elif isinstance(expression.op, ast.Mult): elif isinstance(expression.op, ast.Mult):
try: current_value *= value_to_add
updated_value = current_value.__imul__(value_to_add)
except AttributeError:
updated_value = current_value * value_to_add
elif isinstance(expression.op, ast.Div): elif isinstance(expression.op, ast.Div):
try: current_value /= value_to_add
updated_value = current_value.__itruediv__(value_to_add)
except AttributeError:
updated_value = current_value / value_to_add
elif isinstance(expression.op, ast.Mod): elif isinstance(expression.op, ast.Mod):
try: current_value %= value_to_add
updated_value = current_value.__imod__(value_to_add)
except AttributeError:
updated_value = current_value % value_to_add
elif isinstance(expression.op, ast.Pow): elif isinstance(expression.op, ast.Pow):
try: current_value **= value_to_add
updated_value = current_value.__ipow__(value_to_add)
except AttributeError:
updated_value = current_value**value_to_add
elif isinstance(expression.op, ast.FloorDiv): elif isinstance(expression.op, ast.FloorDiv):
try: current_value //= value_to_add
updated_value = current_value.__ifloordiv__(value_to_add)
except AttributeError:
updated_value = current_value // value_to_add
elif isinstance(expression.op, ast.BitAnd): elif isinstance(expression.op, ast.BitAnd):
try: current_value &= value_to_add
updated_value = current_value.__iand__(value_to_add)
except AttributeError:
updated_value = current_value & value_to_add
elif isinstance(expression.op, ast.BitOr): elif isinstance(expression.op, ast.BitOr):
try: current_value |= value_to_add
updated_value = current_value.__ior__(value_to_add)
except AttributeError:
updated_value = current_value | value_to_add
elif isinstance(expression.op, ast.BitXor): elif isinstance(expression.op, ast.BitXor):
try: current_value ^= value_to_add
updated_value = current_value.__ixor__(value_to_add)
except AttributeError:
updated_value = current_value ^ value_to_add
elif isinstance(expression.op, ast.LShift): elif isinstance(expression.op, ast.LShift):
try: current_value <<= value_to_add
updated_value = current_value.__ilshift__(value_to_add)
except AttributeError:
updated_value = current_value << value_to_add
elif isinstance(expression.op, ast.RShift): elif isinstance(expression.op, ast.RShift):
try: current_value >>= value_to_add
updated_value = current_value.__irshift__(value_to_add)
except AttributeError:
updated_value = current_value >> value_to_add
else: else:
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
# Update the state # Update the state: current_value has been updated in-place
set_value( set_value(
expression.target, expression.target,
updated_value, current_value,
state, state,
static_tools, static_tools,
custom_tools, custom_tools,
authorized_imports, authorized_imports,
) )
return updated_value return current_value
def evaluate_boolop( def evaluate_boolop(

View File

@ -951,42 +951,117 @@ texec(tcompile("1 + 1", "no filename", "exec"))
) )
class TestPythonInterpreter: @pytest.mark.parametrize(
@pytest.mark.parametrize( "code, expected_result",
"code, expected_result", [
[ (
( dedent("""\
dedent("""\ x = 1
x = 1 x += 2
x += 2 """),
"""), 3,
3, ),
), (
( dedent("""\
dedent("""\ x = "a"
x = "a" x += "b"
x += "b" """),
"""), "ab",
"ab", ),
), (
( dedent("""\
dedent("""\ class Custom:
class Custom: def __init__(self, value):
def __init__(self, value): self.value = value
self.value = value def __iadd__(self, other):
def __iadd__(self, other): self.value += other * 10
self.value += other * 10 return self
return self
x = Custom(1) x = Custom(1)
x += 2 x += 2
x.value x.value
"""), """),
21, 21,
), ),
], ],
) )
def test_evaluate_augassign(self, code, expected_result): def test_evaluate_augassign(code, expected_result):
state = {} state = {}
result, _ = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
assert result == expected_result assert result == expected_result
@pytest.mark.parametrize(
"operator, expected_result",
[
("+=", 7),
("-=", 3),
("*=", 10),
("/=", 2.5),
("//=", 2),
("%=", 1),
("**=", 25),
("&=", 0),
("|=", 7),
("^=", 7),
(">>=", 1),
("<<=", 20),
],
)
def test_evaluate_augassign_number(operator, expected_result):
code = dedent("""\
x = 5
x {operator} 2
""").format(operator=operator)
state = {}
result, _ = evaluate_python_code(code, {}, state=state)
assert result == expected_result
@pytest.mark.parametrize(
"operator, expected_result",
[
("+=", 7),
("-=", 3),
("*=", 10),
("/=", 2.5),
("//=", 2),
("%=", 1),
("**=", 25),
("&=", 0),
("|=", 7),
("^=", 7),
(">>=", 1),
("<<=", 20),
],
)
def test_evaluate_augassign_custom(operator, expected_result):
operator_names = {
"+=": "iadd",
"-=": "isub",
"*=": "imul",
"/=": "itruediv",
"//=": "ifloordiv",
"%=": "imod",
"**=": "ipow",
"&=": "iand",
"|=": "ior",
"^=": "ixor",
">>=": "irshift",
"<<=": "ilshift",
}
code = dedent("""\
class Custom:
def __init__(self, value):
self.value = value
def __{operator_name}__(self, other):
self.value {operator} other
return self
x = Custom(5)
x {operator} 2
x.value
""").format(operator=operator, operator_name=operator_names[operator])
state = {}
result, _ = evaluate_python_code(code, {}, state=state)
assert result == expected_result