Move check_module_authorized out of import_modules for use in get_safe_module (#507)
Co-authored-by: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
This commit is contained in:
parent
8c6f90cc11
commit
9318c8cae3
|
@ -984,7 +984,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
||||||
for attr_name in dir(raw_module):
|
for attr_name in dir(raw_module):
|
||||||
# Skip dangerous patterns at any level
|
# Skip dangerous patterns at any level
|
||||||
if any(
|
if any(
|
||||||
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
|
pattern in raw_module.__name__.split(".") + [attr_name]
|
||||||
|
and not check_module_authorized(pattern, authorized_imports, dangerous_patterns)
|
||||||
for pattern in dangerous_patterns
|
for pattern in dangerous_patterns
|
||||||
):
|
):
|
||||||
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
|
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
|
||||||
|
@ -1007,6 +1008,18 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
||||||
return safe_module
|
return safe_module
|
||||||
|
|
||||||
|
|
||||||
|
def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
    """Decide whether ``module_name`` may be imported.

    Args:
        module_name: Dotted module path such as ``"os.path"``. ``None`` is
            accepted (e.g. ``ast.ImportFrom.module`` for ``from . import x``)
            and is refused unless everything is authorized.
        authorized_imports: Collection of authorized module names. The entry
            ``"*"`` authorizes every module.
        dangerous_patterns: Collection of path components that are refused
            unless that exact component is itself in ``authorized_imports``.

    Returns:
        ``True`` if ``"*"`` is authorized, or if no component of the path is
        an unauthorized dangerous pattern and the module or one of its parent
        packages is listed in ``authorized_imports``; ``False`` otherwise.
    """
    if "*" in authorized_imports:
        return True
    if module_name is None:
        # There is no dotted name to match against the authorized list.
        return False

    module_path = module_name.split(".")
    # A dangerous component anywhere in the path blocks the import unless that
    # component has been explicitly authorized.
    if any(part in dangerous_patterns and part not in authorized_imports for part in module_path):
        return False

    # ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
    module_subpaths = (".".join(module_path[:i]) for i in range(1, len(module_path) + 1))
    # Authorizing a package also authorizes its submodules, not the reverse.
    return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||||
|
|
||||||
|
|
||||||
def import_modules(expression, state, authorized_imports):
|
def import_modules(expression, state, authorized_imports):
|
||||||
dangerous_patterns = (
|
dangerous_patterns = (
|
||||||
"_os",
|
"_os",
|
||||||
|
@ -1028,19 +1041,9 @@ def import_modules(expression, state, authorized_imports):
|
||||||
"multiprocessing",
|
"multiprocessing",
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_module_authorized(module_name):
|
|
||||||
if "*" in authorized_imports:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
module_path = module_name.split(".")
|
|
||||||
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
|
|
||||||
return False
|
|
||||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
|
||||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
|
||||||
|
|
||||||
if isinstance(expression, ast.Import):
|
if isinstance(expression, ast.Import):
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
if check_module_authorized(alias.name):
|
if check_module_authorized(alias.name, authorized_imports, dangerous_patterns):
|
||||||
raw_module = import_module(alias.name)
|
raw_module = import_module(alias.name)
|
||||||
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
||||||
else:
|
else:
|
||||||
|
@ -1049,7 +1052,7 @@ def import_modules(expression, state, authorized_imports):
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
if check_module_authorized(expression.module):
|
if check_module_authorized(expression.module, authorized_imports, dangerous_patterns):
|
||||||
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||||
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
||||||
if expression.names[0].name == "*": # Handle "from module import *"
|
if expression.names[0].name == "*": # Handle "from module import *"
|
||||||
|
|
|
@ -25,6 +25,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||||
from smolagents.local_python_executor import (
|
from smolagents.local_python_executor import (
|
||||||
InterpreterError,
|
InterpreterError,
|
||||||
PrintContainer,
|
PrintContainer,
|
||||||
|
check_module_authorized,
|
||||||
evaluate_delete,
|
evaluate_delete,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
fix_final_answer_code,
|
fix_final_answer_code,
|
||||||
|
@ -975,6 +976,10 @@ texec(tcompile("1 + 1", "no filename", "exec"))
|
||||||
dangerous_code = "import os; os.listdir('./')"
|
dangerous_code = "import os; os.listdir('./')"
|
||||||
evaluate_python_code(dangerous_code, authorized_imports=["os"])
|
evaluate_python_code(dangerous_code, authorized_imports=["os"])
|
||||||
|
|
||||||
|
def test_can_import_os_if_all_imports_authorized(self):
    """Importing and using ``os`` must not raise when ``"*"`` is authorized."""
    snippet = "import os; os.listdir('./')"
    evaluate_python_code(snippet, authorized_imports=["*"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"code, expected_result",
|
"code, expected_result",
|
||||||
|
@ -1205,3 +1210,39 @@ class TestPrintContainer:
|
||||||
pc = PrintContainer()
|
pc = PrintContainer()
|
||||||
pc.append("Hello")
|
pc.append("Hello")
|
||||||
assert len(pc) == 5
|
assert len(pc) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "module,authorized_imports,expected",
    [
        ("os", ["*"], True),
        ("AnyModule", ["*"], True),
        ("os", ["os"], True),
        ("AnyModule", ["AnyModule"], True),
        ("Module.os", ["Module"], False),
        ("Module.os", ["Module", "os"], True),
        ("os.path", ["os"], True),
        ("os", ["os.path"], False),
    ],
)
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
    """Check the authorization verdict for each (module, authorized list) pair."""
    # Path components treated as dangerous for the purpose of this test.
    patterns = (
        "_os",
        "os",
        "subprocess",
        "_subprocess",
        "pty",
        "system",
        "popen",
        "spawn",
        "shutil",
        "sys",
        "pathlib",
        "io",
        "socket",
        "compile",
        "eval",
        "exec",
        "multiprocessing",
    )
    verdict = check_module_authorized(module, authorized_imports, patterns)
    assert verdict == expected
|
||||||
|
|
Loading…
Reference in New Issue