Fix get safe module (#405)
* expose get_safe_module lazy loading issue * fix get_safe_module lazy imports * add six as explicit tests dependency * log dangerous attributes * refactor get_safe_module test
This commit is contained in:
parent
181a500c5d
commit
10407813e8
|
@ -18,6 +18,7 @@ import ast
|
|||
import builtins
|
||||
import difflib
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
|
@ -31,6 +32,9 @@ import pandas as pd
|
|||
from .utils import BASE_BUILTIN_MODULES, truncate_content
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterpreterError(ValueError):
|
||||
"""
|
||||
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
|
||||
|
@ -960,10 +964,17 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
|||
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
|
||||
for pattern in dangerous_patterns
|
||||
):
|
||||
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
|
||||
continue
|
||||
|
||||
attr_value = getattr(raw_module, attr_name)
|
||||
|
||||
try:
|
||||
attr_value = getattr(raw_module, attr_name)
|
||||
except ImportError as e:
|
||||
# lazy / dynamic loading module -> INFO log and skip
|
||||
logger.info(
|
||||
f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}"
|
||||
)
|
||||
continue
|
||||
# Recursively process nested modules, passing visited set
|
||||
if isinstance(attr_value, ModuleType):
|
||||
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import types
|
||||
import unittest
|
||||
from textwrap import dedent
|
||||
|
||||
|
@ -24,6 +25,7 @@ from smolagents.local_python_executor import (
|
|||
InterpreterError,
|
||||
evaluate_python_code,
|
||||
fix_final_answer_code,
|
||||
get_safe_module,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1069,3 +1071,23 @@ def test_evaluate_augassign_custom(operator, expected_result):
|
|||
state = {}
|
||||
result, _ = evaluate_python_code(code, {}, state=state)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_get_safe_module_handle_lazy_imports():
|
||||
class FakeModule(types.ModuleType):
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
self.non_lazy_attribute = "ok"
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == "lazy_attribute":
|
||||
raise ImportError("lazy import failure")
|
||||
return super().__getattr__(name)
|
||||
|
||||
def __dir__(self):
|
||||
return super().__dir__() + ["lazy_attribute"]
|
||||
|
||||
fake_module = FakeModule("fake_module")
|
||||
safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set())
|
||||
assert not hasattr(safe_module, "lazy_attribute")
|
||||
assert getattr(safe_module, "non_lazy_attribute") == "ok"
|
||||
|
|
Loading…
Reference in New Issue