Move check_module_authorized out of import_module for use in get_safe_module (#507)
Co-authored-by: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
This commit is contained in:
parent
8c6f90cc11
commit
9318c8cae3
|
@ -984,7 +984,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
|||
for attr_name in dir(raw_module):
|
||||
# Skip dangerous patterns at any level
|
||||
if any(
|
||||
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
|
||||
pattern in raw_module.__name__.split(".") + [attr_name]
|
||||
and not check_module_authorized(pattern, authorized_imports, dangerous_patterns)
|
||||
for pattern in dangerous_patterns
|
||||
):
|
||||
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
|
||||
|
@ -1007,6 +1008,18 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
|||
return safe_module
|
||||
|
||||
|
||||
def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
|
||||
if "*" in authorized_imports:
|
||||
return True
|
||||
else:
|
||||
module_path = module_name.split(".")
|
||||
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
|
||||
return False
|
||||
# ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
|
||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||
|
||||
|
||||
def import_modules(expression, state, authorized_imports):
|
||||
dangerous_patterns = (
|
||||
"_os",
|
||||
|
@ -1028,19 +1041,9 @@ def import_modules(expression, state, authorized_imports):
|
|||
"multiprocessing",
|
||||
)
|
||||
|
||||
def check_module_authorized(module_name):
|
||||
if "*" in authorized_imports:
|
||||
return True
|
||||
else:
|
||||
module_path = module_name.split(".")
|
||||
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
|
||||
return False
|
||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||
|
||||
if isinstance(expression, ast.Import):
|
||||
for alias in expression.names:
|
||||
if check_module_authorized(alias.name):
|
||||
if check_module_authorized(alias.name, authorized_imports, dangerous_patterns):
|
||||
raw_module = import_module(alias.name)
|
||||
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
||||
else:
|
||||
|
@ -1049,7 +1052,7 @@ def import_modules(expression, state, authorized_imports):
|
|||
)
|
||||
return None
|
||||
elif isinstance(expression, ast.ImportFrom):
|
||||
if check_module_authorized(expression.module):
|
||||
if check_module_authorized(expression.module, authorized_imports, dangerous_patterns):
|
||||
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
||||
if expression.names[0].name == "*": # Handle "from module import *"
|
||||
|
|
|
@ -25,6 +25,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS
|
|||
from smolagents.local_python_executor import (
|
||||
InterpreterError,
|
||||
PrintContainer,
|
||||
check_module_authorized,
|
||||
evaluate_delete,
|
||||
evaluate_python_code,
|
||||
fix_final_answer_code,
|
||||
|
@ -975,6 +976,10 @@ texec(tcompile("1 + 1", "no filename", "exec"))
|
|||
dangerous_code = "import os; os.listdir('./')"
|
||||
evaluate_python_code(dangerous_code, authorized_imports=["os"])
|
||||
|
||||
def test_can_import_os_if_all_imports_authorized(self):
|
||||
dangerous_code = "import os; os.listdir('./')"
|
||||
evaluate_python_code(dangerous_code, authorized_imports=["*"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"code, expected_result",
|
||||
|
@ -1205,3 +1210,39 @@ class TestPrintContainer:
|
|||
pc = PrintContainer()
|
||||
pc.append("Hello")
|
||||
assert len(pc) == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module,authorized_imports,expected",
|
||||
[
|
||||
("os", ["*"], True),
|
||||
("AnyModule", ["*"], True),
|
||||
("os", ["os"], True),
|
||||
("AnyModule", ["AnyModule"], True),
|
||||
("Module.os", ["Module"], False),
|
||||
("Module.os", ["Module", "os"], True),
|
||||
("os.path", ["os"], True),
|
||||
("os", ["os.path"], False),
|
||||
],
|
||||
)
|
||||
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
|
||||
dangerous_patterns = (
|
||||
"_os",
|
||||
"os",
|
||||
"subprocess",
|
||||
"_subprocess",
|
||||
"pty",
|
||||
"system",
|
||||
"popen",
|
||||
"spawn",
|
||||
"shutil",
|
||||
"sys",
|
||||
"pathlib",
|
||||
"io",
|
||||
"socket",
|
||||
"compile",
|
||||
"eval",
|
||||
"exec",
|
||||
"multiprocessing",
|
||||
)
|
||||
assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected
|
||||
|
|
Loading…
Reference in New Issue