Move check_module_authorized out of import_modules for use in get_safe_module (#507)

Co-authored-by: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
This commit is contained in:
CalOmnie 2025-02-11 11:21:43 +01:00 committed by GitHub
parent 8c6f90cc11
commit 9318c8cae3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 13 deletions

View File

@ -984,7 +984,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
for attr_name in dir(raw_module):
# Skip dangerous patterns at any level
if any(
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
pattern in raw_module.__name__.split(".") + [attr_name]
and not check_module_authorized(pattern, authorized_imports, dangerous_patterns)
for pattern in dangerous_patterns
):
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
@ -1007,6 +1008,18 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
return safe_module
def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
    """Return whether importing ``module_name`` is permitted.

    Args:
        module_name: Dotted module path such as ``"os.path"``. May be ``None``
            (the ``module`` attribute of ``ast.ImportFrom`` is ``None`` for
            ``from . import x``).
        authorized_imports: Collection of authorized module names. The
            wildcard entry ``"*"`` authorizes every module.
        dangerous_patterns: Collection of path components that are rejected
            unless that exact component also appears in ``authorized_imports``.

    Returns:
        ``True`` if ``"*"`` is authorized, or if no component of the dotted
        path is an unauthorized dangerous pattern and the module itself or one
        of its parent packages is listed in ``authorized_imports``; ``False``
        otherwise.
    """
    if "*" in authorized_imports:
        return True
    if module_name is None:
        # Without a wildcard there is no name to match against the allow-list;
        # report "not authorized" instead of failing on ``None.split``.
        return False
    module_path = module_name.split(".")
    # A single dangerous component anywhere in the path vetoes the import,
    # unless that component was itself explicitly authorized.
    if any(part in dangerous_patterns and part not in authorized_imports for part in module_path):
        return False
    # ["A", "B", "C"] -> "A", "A.B", "A.B.C": authorizing a package also
    # authorizes its submodules, but not the other way round.
    module_subpaths = (".".join(module_path[:i]) for i in range(1, len(module_path) + 1))
    return any(subpath in authorized_imports for subpath in module_subpaths)
def import_modules(expression, state, authorized_imports):
dangerous_patterns = (
"_os",
@ -1028,19 +1041,9 @@ def import_modules(expression, state, authorized_imports):
"multiprocessing",
)
def check_module_authorized(module_name):
if "*" in authorized_imports:
return True
else:
module_path = module_name.split(".")
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
return False
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
if check_module_authorized(alias.name, authorized_imports, dangerous_patterns):
raw_module = import_module(alias.name)
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
else:
@ -1049,7 +1052,7 @@ def import_modules(expression, state, authorized_imports):
)
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
if check_module_authorized(expression.module, authorized_imports, dangerous_patterns):
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
if expression.names[0].name == "*": # Handle "from module import *"

View File

@ -25,6 +25,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import (
InterpreterError,
PrintContainer,
check_module_authorized,
evaluate_delete,
evaluate_python_code,
fix_final_answer_code,
@ -975,6 +976,10 @@ texec(tcompile("1 + 1", "no filename", "exec"))
dangerous_code = "import os; os.listdir('./')"
evaluate_python_code(dangerous_code, authorized_imports=["os"])
def test_can_import_os_if_all_imports_authorized(self):
    """Importing and using ``os`` must not raise when ``"*"`` is authorized."""
    evaluate_python_code("import os; os.listdir('./')", authorized_imports=["*"])
@pytest.mark.parametrize(
"code, expected_result",
@ -1205,3 +1210,39 @@ class TestPrintContainer:
pc = PrintContainer()
pc.append("Hello")
assert len(pc) == 5
@pytest.mark.parametrize(
    "module,authorized_imports,expected",
    [
        # Wildcard authorizes everything, dangerous or not.
        ("os", ["*"], True),
        ("AnyModule", ["*"], True),
        # Exact matches are authorized.
        ("os", ["os"], True),
        ("AnyModule", ["AnyModule"], True),
        # A dangerous component must itself be authorized.
        ("Module.os", ["Module"], False),
        ("Module.os", ["Module", "os"], True),
        # Authorizing a package covers its submodules, not the reverse.
        ("os.path", ["os"], True),
        ("os", ["os.path"], False),
    ],
)
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
    """Check ``check_module_authorized`` against the table of cases above."""
    dangerous_patterns = (
        "_os", "os",
        "subprocess", "_subprocess",
        "pty", "system", "popen", "spawn",
        "shutil", "sys", "pathlib", "io", "socket",
        "compile", "eval", "exec",
        "multiprocessing",
    )
    outcome = check_module_authorized(module, authorized_imports, dangerous_patterns)
    assert outcome == expected