Add support for non-bool comparison operators. (#612)

This commit is contained in:
Billy Neuson 2025-02-12 13:04:31 -05:00 committed by GitHub
parent 9b96199d00
commit 5fd0a2e86e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 1 deletions

View File

@ -713,7 +713,7 @@ def evaluate_condition(
static_tools: Dict[str, Callable], static_tools: Dict[str, Callable],
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> bool: ) -> bool | object:
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
comparators = [ comparators = [
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
@ -747,6 +747,9 @@ def evaluate_condition(
else: else:
raise InterpreterError(f"Operator not supported: {op}") raise InterpreterError(f"Operator not supported: {op}")
if not isinstance(current_result, bool):
return current_result
result = result & current_result result = result & current_result
current_left = comparator current_left = comparator

View File

@ -1181,6 +1181,31 @@ def test_get_safe_module_handle_lazy_imports():
assert getattr(safe_module, "non_lazy_attribute") == "ok" assert getattr(safe_module, "non_lazy_attribute") == "ok"
def test_non_standard_comparisons():
code = """
class NonStdEqualsResult:
def __init__(self, left:object, right:object):
self._left = left
self._right = right
def __str__(self) -> str:
return f'{self._left}=={self._right}'
class NonStdComparisonClass:
def __init__(self, value: str ):
self._value = value
def __str__(self):
return self._value
def __eq__(self, other):
return NonStdEqualsResult(self, other)
a = NonStdComparisonClass("a")
b = NonStdComparisonClass("b")
result = a == b
"""
result, _ = evaluate_python_code(code, state={})
assert not isinstance(result, bool)
assert str(result) == "a==b"
class TestPrintContainer: class TestPrintContainer:
def test_initial_value(self): def test_initial_value(self):
pc = PrintContainer() pc = PrintContainer()