From 3c18d4d588a9ae3c83c4c9fc603bda308e6de9ff Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Mon, 20 Jan 2025 11:40:43 +0100 Subject: [PATCH] Python interpreter: improve suggestions for possible mappings (#266) --- src/smolagents/local_python_executor.py | 12 +++++++----- tests/test_python_interpreter.py | 8 ++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index 9477be9..e46d87a 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -641,11 +641,13 @@ def evaluate_subscript( return value[index] elif index in value: return value[index] - elif isinstance(index, str) and isinstance(value, Mapping): - close_matches = difflib.get_close_matches(index, list(value.keys())) - if len(close_matches) > 0: - return value[close_matches[0]] - raise InterpreterError(f"Could not index {value} with '{index}'.") + else: + error_message = f"Could not index {value} with '{index}'." + if isinstance(index, str) and isinstance(value, Mapping): + close_matches = difflib.get_close_matches(index, list(value.keys())) + if len(close_matches) > 0: + error_message += f" Maybe you meant one of these indexes instead: {str(close_matches)}" + raise InterpreterError(error_message) def evaluate_name( diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 0e83177..540720f 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -897,3 +897,11 @@ shift_intervals code = "import doctest;doctest.inspect.os.system('echo bad command passed')" with pytest.raises(AttributeError): evaluate_python_code(code, authorized_imports=["doctest"]) + + def test_close_matches_subscript(self): + code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]' + with pytest.raises(Exception) as e: + evaluate_python_code(code) + assert "Maybe you meant one of these indexes instead" in str( + e + ) and "['Bhutan']" in str(e).replace("\\", "")