Make dangerous_patterns a module variable (#505)
This commit is contained in:
parent
2982e88409
commit
1df9dca33a
|
@ -114,6 +114,26 @@ BASE_PYTHON_TOOLS = {
|
||||||
"complex": complex,
|
"complex": complex,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DANGEROUS_PATTERNS = (
|
||||||
|
"_os",
|
||||||
|
"os",
|
||||||
|
"subprocess",
|
||||||
|
"_subprocess",
|
||||||
|
"pty",
|
||||||
|
"system",
|
||||||
|
"popen",
|
||||||
|
"spawn",
|
||||||
|
"shutil",
|
||||||
|
"sys",
|
||||||
|
"pathlib",
|
||||||
|
"io",
|
||||||
|
"socket",
|
||||||
|
"compile",
|
||||||
|
"eval",
|
||||||
|
"exec",
|
||||||
|
"multiprocessing",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PrintContainer:
|
class PrintContainer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -954,7 +974,7 @@ def evaluate_with(
|
||||||
context.__exit__(None, None, None)
|
context.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=None):
|
def get_safe_module(raw_module, authorized_imports, visited=None):
|
||||||
"""Creates a safe copy of a module or returns the original if it's a function"""
|
"""Creates a safe copy of a module or returns the original if it's a function"""
|
||||||
# If it's a function or non-module object, return it directly
|
# If it's a function or non-module object, return it directly
|
||||||
if not isinstance(raw_module, ModuleType):
|
if not isinstance(raw_module, ModuleType):
|
||||||
|
@ -978,8 +998,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
||||||
# Skip dangerous patterns at any level
|
# Skip dangerous patterns at any level
|
||||||
if any(
|
if any(
|
||||||
pattern in raw_module.__name__.split(".") + [attr_name]
|
pattern in raw_module.__name__.split(".") + [attr_name]
|
||||||
and not check_module_authorized(pattern, authorized_imports, dangerous_patterns)
|
and not check_module_authorized(pattern, authorized_imports)
|
||||||
for pattern in dangerous_patterns
|
for pattern in DANGEROUS_PATTERNS
|
||||||
):
|
):
|
||||||
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
|
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
|
||||||
continue
|
continue
|
||||||
|
@ -994,19 +1014,19 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
||||||
continue
|
continue
|
||||||
# Recursively process nested modules, passing visited set
|
# Recursively process nested modules, passing visited set
|
||||||
if isinstance(attr_value, ModuleType):
|
if isinstance(attr_value, ModuleType):
|
||||||
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)
|
attr_value = get_safe_module(attr_value, authorized_imports, visited=visited)
|
||||||
|
|
||||||
setattr(safe_module, attr_name, attr_value)
|
setattr(safe_module, attr_name, attr_value)
|
||||||
|
|
||||||
return safe_module
|
return safe_module
|
||||||
|
|
||||||
|
|
||||||
def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
|
def check_module_authorized(module_name, authorized_imports):
|
||||||
if "*" in authorized_imports:
|
if "*" in authorized_imports:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
module_path = module_name.split(".")
|
module_path = module_name.split(".")
|
||||||
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
|
if any([module in DANGEROUS_PATTERNS and module not in authorized_imports for module in module_path]):
|
||||||
return False
|
return False
|
||||||
# ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
|
# ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
|
||||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||||
|
@ -1014,40 +1034,20 @@ def check_module_authorized(module_name, authorized_imports, dangerous_patterns)
|
||||||
|
|
||||||
|
|
||||||
def import_modules(expression, state, authorized_imports):
|
def import_modules(expression, state, authorized_imports):
|
||||||
dangerous_patterns = (
|
|
||||||
"_os",
|
|
||||||
"os",
|
|
||||||
"subprocess",
|
|
||||||
"_subprocess",
|
|
||||||
"pty",
|
|
||||||
"system",
|
|
||||||
"popen",
|
|
||||||
"spawn",
|
|
||||||
"shutil",
|
|
||||||
"sys",
|
|
||||||
"pathlib",
|
|
||||||
"io",
|
|
||||||
"socket",
|
|
||||||
"compile",
|
|
||||||
"eval",
|
|
||||||
"exec",
|
|
||||||
"multiprocessing",
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(expression, ast.Import):
|
if isinstance(expression, ast.Import):
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
if check_module_authorized(alias.name, authorized_imports, dangerous_patterns):
|
if check_module_authorized(alias.name, authorized_imports):
|
||||||
raw_module = import_module(alias.name)
|
raw_module = import_module(alias.name)
|
||||||
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
state[alias.asname or alias.name] = get_safe_module(raw_module, authorized_imports)
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(
|
||||||
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
if check_module_authorized(expression.module, authorized_imports, dangerous_patterns):
|
if check_module_authorized(expression.module, authorized_imports):
|
||||||
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||||
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
module = get_safe_module(raw_module, authorized_imports)
|
||||||
if expression.names[0].name == "*": # Handle "from module import *"
|
if expression.names[0].name == "*": # Handle "from module import *"
|
||||||
if hasattr(module, "__all__"): # If module has __all__, import only those names
|
if hasattr(module, "__all__"): # If module has __all__, import only those names
|
||||||
for name in module.__all__:
|
for name in module.__all__:
|
||||||
|
|
|
@ -1302,7 +1302,7 @@ def test_get_safe_module_handle_lazy_imports():
|
||||||
return super().__dir__() + ["lazy_attribute"]
|
return super().__dir__() + ["lazy_attribute"]
|
||||||
|
|
||||||
fake_module = FakeModule("fake_module")
|
fake_module = FakeModule("fake_module")
|
||||||
safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set())
|
safe_module = get_safe_module(fake_module, authorized_imports=set())
|
||||||
assert not hasattr(safe_module, "lazy_attribute")
|
assert not hasattr(safe_module, "lazy_attribute")
|
||||||
assert getattr(safe_module, "non_lazy_attribute") == "ok"
|
assert getattr(safe_module, "non_lazy_attribute") == "ok"
|
||||||
|
|
||||||
|
@ -1377,23 +1377,4 @@ class TestPrintContainer:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
|
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
|
||||||
dangerous_patterns = (
|
assert check_module_authorized(module, authorized_imports) == expected
|
||||||
"_os",
|
|
||||||
"os",
|
|
||||||
"subprocess",
|
|
||||||
"_subprocess",
|
|
||||||
"pty",
|
|
||||||
"system",
|
|
||||||
"popen",
|
|
||||||
"spawn",
|
|
||||||
"shutil",
|
|
||||||
"sys",
|
|
||||||
"pathlib",
|
|
||||||
"io",
|
|
||||||
"socket",
|
|
||||||
"compile",
|
|
||||||
"eval",
|
|
||||||
"exec",
|
|
||||||
"multiprocessing",
|
|
||||||
)
|
|
||||||
assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected
|
|
||||||
|
|
Loading…
Reference in New Issue