Fix blocking of os in authorized imports (#386)
This commit is contained in:
parent
a4f89b68b2
commit
ad5f84b101
|
@ -934,36 +934,39 @@ def evaluate_with(
|
||||||
context.__exit__(None, None, None)
|
context.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
|
def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=None):
|
||||||
"""Creates a safe copy of a module or returns the original if it's a function"""
|
"""Creates a safe copy of a module or returns the original if it's a function"""
|
||||||
# If it's a function or non-module object, return it directly
|
# If it's a function or non-module object, return it directly
|
||||||
if not isinstance(unsafe_module, ModuleType):
|
if not isinstance(raw_module, ModuleType):
|
||||||
return unsafe_module
|
return raw_module
|
||||||
|
|
||||||
# Handle circular references: Initialize visited set for the first call
|
# Handle circular references: Initialize visited set for the first call
|
||||||
if visited is None:
|
if visited is None:
|
||||||
visited = set()
|
visited = set()
|
||||||
|
|
||||||
module_id = id(unsafe_module)
|
module_id = id(raw_module)
|
||||||
if module_id in visited:
|
if module_id in visited:
|
||||||
return unsafe_module # Return original for circular refs
|
return raw_module # Return original for circular refs
|
||||||
|
|
||||||
visited.add(module_id)
|
visited.add(module_id)
|
||||||
|
|
||||||
# Create new module for actual modules
|
# Create new module for actual modules
|
||||||
safe_module = ModuleType(unsafe_module.__name__)
|
safe_module = ModuleType(raw_module.__name__)
|
||||||
|
|
||||||
# Copy all attributes by reference, recursively checking modules
|
# Copy all attributes by reference, recursively checking modules
|
||||||
for attr_name in dir(unsafe_module):
|
for attr_name in dir(raw_module):
|
||||||
# Skip dangerous patterns at any level
|
# Skip dangerous patterns at any level
|
||||||
if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns):
|
if any(
|
||||||
|
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
|
||||||
|
for pattern in dangerous_patterns
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
attr_value = getattr(unsafe_module, attr_name)
|
attr_value = getattr(raw_module, attr_name)
|
||||||
|
|
||||||
# Recursively process nested modules, passing visited set
|
# Recursively process nested modules, passing visited set
|
||||||
if isinstance(attr_value, ModuleType):
|
if isinstance(attr_value, ModuleType):
|
||||||
attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited)
|
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)
|
||||||
|
|
||||||
setattr(safe_module, attr_name, attr_value)
|
setattr(safe_module, attr_name, attr_value)
|
||||||
|
|
||||||
|
@ -996,7 +999,7 @@ def import_modules(expression, state, authorized_imports):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
module_path = module_name.split(".")
|
module_path = module_name.split(".")
|
||||||
if any([module in dangerous_patterns for module in module_path]):
|
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
|
||||||
return False
|
return False
|
||||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||||
|
@ -1005,7 +1008,7 @@ def import_modules(expression, state, authorized_imports):
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
if check_module_authorized(alias.name):
|
if check_module_authorized(alias.name):
|
||||||
raw_module = import_module(alias.name)
|
raw_module = import_module(alias.name)
|
||||||
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns)
|
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(
|
||||||
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
||||||
|
@ -1013,7 +1016,8 @@ def import_modules(expression, state, authorized_imports):
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
if check_module_authorized(expression.module):
|
if check_module_authorized(expression.module):
|
||||||
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||||
|
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
|
||||||
if expression.names[0].name == "*": # Handle "from module import *"
|
if expression.names[0].name == "*": # Handle "from module import *"
|
||||||
if hasattr(module, "__all__"): # If module has __all__, import only those names
|
if hasattr(module, "__all__"): # If module has __all__, import only those names
|
||||||
for name in module.__all__:
|
for name in module.__all__:
|
||||||
|
|
|
@ -950,6 +950,10 @@ texec(tcompile("1 + 1", "no filename", "exec"))
|
||||||
dangerous_code, static_tools={"tcompile": compile, "teval": eval, "texec": exec} | BASE_PYTHON_TOOLS
|
dangerous_code, static_tools={"tcompile": compile, "teval": eval, "texec": exec} | BASE_PYTHON_TOOLS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_can_import_os_if_explicitly_authorized(self):
|
||||||
|
dangerous_code = "import os; os.listdir('./')"
|
||||||
|
evaluate_python_code(dangerous_code, authorized_imports=["os"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"code, expected_result",
|
"code, expected_result",
|
||||||
|
|
Loading…
Reference in New Issue