diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index 5359cf5..27179a5 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -369,84 +369,45 @@ def evaluate_augassign( if isinstance(current_value, list): if not isinstance(value_to_add, list): raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.") - try: - updated_value = current_value.__iadd__(value_to_add) - except AttributeError: - updated_value = current_value + value_to_add + current_value += value_to_add else: - try: - updated_value = current_value.__iadd__(value_to_add) - except AttributeError: - updated_value = current_value + value_to_add + current_value += value_to_add elif isinstance(expression.op, ast.Sub): - try: - updated_value = current_value.__isub__(value_to_add) - except AttributeError: - updated_value = current_value - value_to_add + current_value -= value_to_add elif isinstance(expression.op, ast.Mult): - try: - updated_value = current_value.__imul__(value_to_add) - except AttributeError: - updated_value = current_value * value_to_add + current_value *= value_to_add elif isinstance(expression.op, ast.Div): - try: - updated_value = current_value.__itruediv__(value_to_add) - except AttributeError: - updated_value = current_value / value_to_add + current_value /= value_to_add elif isinstance(expression.op, ast.Mod): - try: - updated_value = current_value.__imod__(value_to_add) - except AttributeError: - updated_value = current_value % value_to_add + current_value %= value_to_add elif isinstance(expression.op, ast.Pow): - try: - updated_value = current_value.__ipow__(value_to_add) - except AttributeError: - updated_value = current_value**value_to_add + current_value **= value_to_add elif isinstance(expression.op, ast.FloorDiv): - try: - updated_value = current_value.__ifloordiv__(value_to_add) - except AttributeError: - updated_value = current_value // value_to_add + current_value //= value_to_add elif isinstance(expression.op, ast.BitAnd): - try: - updated_value = current_value.__iand__(value_to_add) - except AttributeError: - updated_value = current_value & value_to_add + current_value &= value_to_add elif isinstance(expression.op, ast.BitOr): - try: - updated_value = current_value.__ior__(value_to_add) - except AttributeError: - updated_value = current_value | value_to_add + current_value |= value_to_add elif isinstance(expression.op, ast.BitXor): - try: - updated_value = current_value.__ixor__(value_to_add) - except AttributeError: - updated_value = current_value ^ value_to_add + current_value ^= value_to_add elif isinstance(expression.op, ast.LShift): - try: - updated_value = current_value.__ilshift__(value_to_add) - except AttributeError: - updated_value = current_value << value_to_add + current_value <<= value_to_add elif isinstance(expression.op, ast.RShift): - try: - updated_value = current_value.__irshift__(value_to_add) - except AttributeError: - updated_value = current_value >> value_to_add + current_value >>= value_to_add else: raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") - # Update the state + # Update the state: current_value has been updated in-place set_value( expression.target, - updated_value, + current_value, state, static_tools, custom_tools, authorized_imports, ) - return updated_value + return current_value def evaluate_boolop( diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 870abe5..8aec8fe 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -951,42 +951,117 @@ texec(tcompile("1 + 1", "no filename", "exec")) ) -class TestPythonInterpreter: - @pytest.mark.parametrize( - "code, expected_result", - [ - ( - dedent("""\ - x = 1 - x += 2 - """), - 3, - ), - ( - dedent("""\ - x = "a" - x += "b" - """), - "ab", - ), - ( - dedent("""\ - class Custom: - def __init__(self, value): - self.value = value - def __iadd__(self, other): - self.value += other * 10 - return self +@pytest.mark.parametrize( + "code, expected_result", + [ + ( + dedent("""\ + x = 1 + x += 2 + """), + 3, + ), + ( + dedent("""\ + x = "a" + x += "b" + """), + "ab", + ), + ( + dedent("""\ + class Custom: + def __init__(self, value): + self.value = value + def __iadd__(self, other): + self.value += other * 10 + return self - x = Custom(1) - x += 2 - x.value - """), - 21, - ), - ], - ) - def test_evaluate_augassign(self, code, expected_result): - state = {} - result, _ = evaluate_python_code(code, {}, state=state) - assert result == expected_result + x = Custom(1) + x += 2 + x.value + """), + 21, + ), + ], +) +def test_evaluate_augassign(code, expected_result): + state = {} + result, _ = evaluate_python_code(code, {}, state=state) + assert result == expected_result + + +@pytest.mark.parametrize( + "operator, expected_result", + [ + ("+=", 7), + ("-=", 3), + ("*=", 10), + ("/=", 2.5), + ("//=", 2), + ("%=", 1), + ("**=", 25), + ("&=", 0), + ("|=", 7), + ("^=", 7), + (">>=", 1), + ("<<=", 20), + ], +) +def test_evaluate_augassign_number(operator, expected_result): + code = dedent("""\ + x = 5 + x {operator} 2 + """).format(operator=operator) + state = {} + result, _ = evaluate_python_code(code, {}, state=state) + assert result == expected_result + + +@pytest.mark.parametrize( + "operator, expected_result", + [ + ("+=", 7), + ("-=", 3), + ("*=", 10), + ("/=", 2.5), + ("//=", 2), + ("%=", 1), + ("**=", 25), + ("&=", 0), + ("|=", 7), + ("^=", 7), + (">>=", 1), + ("<<=", 20), + ], +) +def test_evaluate_augassign_custom(operator, expected_result): + operator_names = { + "+=": "iadd", + "-=": "isub", + "*=": "imul", + "/=": "itruediv", + "//=": "ifloordiv", + "%=": "imod", + "**=": "ipow", + "&=": "iand", + "|=": "ior", + "^=": "ixor", + ">>=": "irshift", + "<<=": "ilshift", + } + code = dedent("""\ + class Custom: + def __init__(self, value): + self.value = value + def __{operator_name}__(self, other): + self.value {operator} other + return self + + x = Custom(5) + x {operator} 2 + x.value + """).format(operator=operator, operator_name=operator_names[operator]) + state = {} + result, _ = evaluate_python_code(code, {}, state=state) + assert result == expected_result