Fix get safe module (#405)

* expose get_safe_module lazy loading issue

* fix get_safe_module lazy imports

* add six as explicit tests dependency

* log dangerous attributes

* refactor get_safe_module test
This commit is contained in:
Antoine Jeannot 2025-01-31 08:25:22 +01:00 committed by GitHub
parent 181a500c5d
commit 10407813e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 2 deletions

View File

@ -18,6 +18,7 @@ import ast
import builtins
import difflib
import inspect
import logging
import math
import re
from collections.abc import Mapping
@ -31,6 +32,9 @@ import pandas as pd
from .utils import BASE_BUILTIN_MODULES, truncate_content
logger = logging.getLogger(__name__)
class InterpreterError(ValueError):
"""
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
@ -960,10 +964,17 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
for pattern in dangerous_patterns
):
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
continue
attr_value = getattr(raw_module, attr_name)
try:
attr_value = getattr(raw_module, attr_name)
except ImportError as e:
# lazy / dynamic loading module -> INFO log and skip
logger.info(
f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}"
)
continue
# Recursively process nested modules, passing visited set
if isinstance(attr_value, ModuleType):
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import unittest
from textwrap import dedent
@ -24,6 +25,7 @@ from smolagents.local_python_executor import (
InterpreterError,
evaluate_python_code,
fix_final_answer_code,
get_safe_module,
)
@ -1069,3 +1071,23 @@ def test_evaluate_augassign_custom(operator, expected_result):
state = {}
result, _ = evaluate_python_code(code, {}, state=state)
assert result == expected_result
def test_get_safe_module_handle_lazy_imports():
class FakeModule(types.ModuleType):
def __init__(self, name):
super().__init__(name)
self.non_lazy_attribute = "ok"
def __getattr__(self, name):
if name == "lazy_attribute":
raise ImportError("lazy import failure")
return super().__getattr__(name)
def __dir__(self):
return super().__dir__() + ["lazy_attribute"]
fake_module = FakeModule("fake_module")
safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set())
assert not hasattr(safe_module, "lazy_attribute")
assert getattr(safe_module, "non_lazy_attribute") == "ok"