Fix evaluate_condition for non-bool result (#638)
This commit is contained in:
parent
d02093dc24
commit
94371331bb
|
@ -714,49 +714,39 @@ def evaluate_condition(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> bool | object:
|
) -> bool | object:
|
||||||
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
|
|
||||||
comparators = [
|
|
||||||
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
|
|
||||||
]
|
|
||||||
ops = [type(op) for op in condition.ops]
|
|
||||||
|
|
||||||
result = True
|
result = True
|
||||||
current_left = left
|
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
|
||||||
|
for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)):
|
||||||
for op, comparator in zip(ops, comparators):
|
op = type(op)
|
||||||
|
right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports)
|
||||||
if op == ast.Eq:
|
if op == ast.Eq:
|
||||||
current_result = current_left == comparator
|
current_result = left == right
|
||||||
elif op == ast.NotEq:
|
elif op == ast.NotEq:
|
||||||
current_result = current_left != comparator
|
current_result = left != right
|
||||||
elif op == ast.Lt:
|
elif op == ast.Lt:
|
||||||
current_result = current_left < comparator
|
current_result = left < right
|
||||||
elif op == ast.LtE:
|
elif op == ast.LtE:
|
||||||
current_result = current_left <= comparator
|
current_result = left <= right
|
||||||
elif op == ast.Gt:
|
elif op == ast.Gt:
|
||||||
current_result = current_left > comparator
|
current_result = left > right
|
||||||
elif op == ast.GtE:
|
elif op == ast.GtE:
|
||||||
current_result = current_left >= comparator
|
current_result = left >= right
|
||||||
elif op == ast.Is:
|
elif op == ast.Is:
|
||||||
current_result = current_left is comparator
|
current_result = left is right
|
||||||
elif op == ast.IsNot:
|
elif op == ast.IsNot:
|
||||||
current_result = current_left is not comparator
|
current_result = left is not right
|
||||||
elif op == ast.In:
|
elif op == ast.In:
|
||||||
current_result = current_left in comparator
|
current_result = left in right
|
||||||
elif op == ast.NotIn:
|
elif op == ast.NotIn:
|
||||||
current_result = current_left not in comparator
|
current_result = left not in right
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(f"Operator not supported: {op}")
|
raise InterpreterError(f"Unsupported comparison operator: {op}")
|
||||||
|
|
||||||
if not isinstance(current_result, bool):
|
if current_result is False:
|
||||||
return current_result
|
return False
|
||||||
|
result = current_result if i == 0 else (result and current_result)
|
||||||
result = result & current_result
|
left = right
|
||||||
current_left = comparator
|
return result
|
||||||
|
|
||||||
if isinstance(result, bool) and not result:
|
|
||||||
break
|
|
||||||
|
|
||||||
return result if isinstance(result, (bool, pd.Series)) else result.all()
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_if(
|
def evaluate_if(
|
||||||
|
|
|
@ -19,6 +19,7 @@ import unittest
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||||
|
@ -1188,11 +1189,11 @@ def test_evaluate_delete(code, state, expectation):
|
||||||
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
|
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
|
||||||
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
|
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
|
||||||
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
|
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
|
||||||
# Composite conditions:
|
# Chained conditions:
|
||||||
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
|
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
|
||||||
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
|
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
|
||||||
("a == b < c", {"a": 1, "b": 1, "c": 1}, False),
|
("a == b < c", {"a": 2, "b": 2, "c": 2}, False),
|
||||||
("a == b < c", {"a": 1, "b": 1, "c": 2}, True),
|
("a == b < c", {"a": 0, "b": 0, "c": 1}, True),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_evaluate_condition(condition, state, expected_result):
|
def test_evaluate_condition(condition, state, expected_result):
|
||||||
|
@ -1201,6 +1202,91 @@ def test_evaluate_condition(condition, state, expected_result):
|
||||||
assert result == expected_result
|
assert result == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"condition, state, expected_result",
|
||||||
|
[
|
||||||
|
("a == b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, False])),
|
||||||
|
("a != b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, True])),
|
||||||
|
("a < b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, False])),
|
||||||
|
("a <= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, True, False])),
|
||||||
|
("a > b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, False, True])),
|
||||||
|
("a >= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, True])),
|
||||||
|
(
|
||||||
|
"a == b",
|
||||||
|
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
|
||||||
|
pd.DataFrame({"x": [True, True], "y": [True, False]}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"a != b",
|
||||||
|
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
|
||||||
|
pd.DataFrame({"x": [False, False], "y": [False, True]}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"a < b",
|
||||||
|
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||||
|
pd.DataFrame({"x": [True, False], "y": [False, False]}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"a <= b",
|
||||||
|
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||||
|
pd.DataFrame({"x": [True, True], "y": [False, False]}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"a > b",
|
||||||
|
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||||
|
pd.DataFrame({"x": [False, False], "y": [True, True]}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"a >= b",
|
||||||
|
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||||
|
pd.DataFrame({"x": [False, True], "y": [True, True]}),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_evaluate_condition_with_pandas(condition, state, expected_result):
|
||||||
|
condition_ast = ast.parse(condition, mode="eval").body
|
||||||
|
result = evaluate_condition(condition_ast, state, {}, {}, [])
|
||||||
|
if isinstance(result, pd.Series):
|
||||||
|
pd.testing.assert_series_equal(result, expected_result)
|
||||||
|
else:
|
||||||
|
pd.testing.assert_frame_equal(result, expected_result)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"condition, state, expected_exception",
|
||||||
|
[
|
||||||
|
# Chained conditions:
|
||||||
|
(
|
||||||
|
"a == b == c",
|
||||||
|
{
|
||||||
|
"a": pd.Series([1, 2, 3]),
|
||||||
|
"b": pd.Series([2, 2, 2]),
|
||||||
|
"c": pd.Series([3, 3, 3]),
|
||||||
|
},
|
||||||
|
ValueError(
|
||||||
|
"The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"a == b == c",
|
||||||
|
{
|
||||||
|
"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
|
||||||
|
"b": pd.DataFrame({"x": [2, 2], "y": [2, 2]}),
|
||||||
|
"c": pd.DataFrame({"x": [3, 3], "y": [3, 3]}),
|
||||||
|
},
|
||||||
|
ValueError(
|
||||||
|
"The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_evaluate_condition_with_pandas_exceptions(condition, state, expected_exception):
|
||||||
|
condition_ast = ast.parse(condition, mode="eval").body
|
||||||
|
with pytest.raises(type(expected_exception)) as exception_info:
|
||||||
|
_ = evaluate_condition(condition_ast, state, {}, {}, [])
|
||||||
|
assert str(expected_exception) in str(exception_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_get_safe_module_handle_lazy_imports():
|
def test_get_safe_module_handle_lazy_imports():
|
||||||
class FakeModule(types.ModuleType):
|
class FakeModule(types.ModuleType):
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
|
@ -1222,7 +1308,7 @@ def test_get_safe_module_handle_lazy_imports():
|
||||||
|
|
||||||
|
|
||||||
def test_non_standard_comparisons():
|
def test_non_standard_comparisons():
|
||||||
code = """
|
code = dedent("""\
|
||||||
class NonStdEqualsResult:
|
class NonStdEqualsResult:
|
||||||
def __init__(self, left:object, right:object):
|
def __init__(self, left:object, right:object):
|
||||||
self._left = left
|
self._left = left
|
||||||
|
@ -1240,7 +1326,7 @@ class NonStdComparisonClass:
|
||||||
a = NonStdComparisonClass("a")
|
a = NonStdComparisonClass("a")
|
||||||
b = NonStdComparisonClass("b")
|
b = NonStdComparisonClass("b")
|
||||||
result = a == b
|
result = a == b
|
||||||
"""
|
""")
|
||||||
result, _ = evaluate_python_code(code, state={})
|
result, _ = evaluate_python_code(code, state={})
|
||||||
assert not isinstance(result, bool)
|
assert not isinstance(result, bool)
|
||||||
assert str(result) == "a == b"
|
assert str(result) == "a == b"
|
||||||
|
|
Loading…
Reference in New Issue