Fix subpackage import vulnerability (#238)

* Fix subpackage import vulnerability
This commit is contained in:
Aymeric Roucher 2025-01-17 11:40:49 +01:00 committed by GitHub
parent d5c2ef48e7
commit c255c1ff84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 91 additions and 13 deletions

View File

@ -120,21 +120,15 @@ class GradioUI:
""" """
if file is None: if file is None:
return gr.Textbox( return gr.Textbox("No file uploaded", visible=True), file_uploads_log
"No file uploaded", visible=True
), file_uploads_log
try: try:
mime_type, _ = mimetypes.guess_type(file.name) mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e: except Exception as e:
return gr.Textbox( return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
f"Error: {e}", visible=True
), file_uploads_log
if mime_type not in allowed_file_types: if mime_type not in allowed_file_types:
return gr.Textbox( return gr.Textbox("File type disallowed", visible=True), file_uploads_log
"File type disallowed", visible=True
), file_uploads_log
# Sanitize file name # Sanitize file name
original_name = os.path.basename(file.name) original_name = os.path.basename(file.name)

View File

@ -21,6 +21,7 @@ import math
import re import re
from collections.abc import Mapping from collections.abc import Mapping
from importlib import import_module from importlib import import_module
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np import numpy as np
@ -1046,12 +1047,75 @@ def evaluate_with(
context.__exit__(None, None, None) context.__exit__(None, None, None)
def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
"""Creates a safe copy of a module or returns the original if it's a function"""
# If it's a function or non-module object, return it directly
if not isinstance(unsafe_module, ModuleType):
return unsafe_module
# Handle circular references: Initialize visited set for the first call
if visited is None:
visited = set()
module_id = id(unsafe_module)
if module_id in visited:
return unsafe_module # Return original for circular refs
visited.add(module_id)
# Create new module for actual modules
safe_module = ModuleType(unsafe_module.__name__)
# Copy all attributes by reference, recursively checking modules
for attr_name in dir(unsafe_module):
# Skip dangerous patterns at any level
if any(
pattern in f"{unsafe_module.__name__}.{attr_name}"
for pattern in dangerous_patterns
):
continue
attr_value = getattr(unsafe_module, attr_name)
# Recursively process nested modules, passing visited set
if isinstance(attr_value, ModuleType):
attr_value = get_safe_module(
attr_value, dangerous_patterns, visited=visited
)
setattr(safe_module, attr_name, attr_value)
return safe_module
def import_modules(expression, state, authorized_imports): def import_modules(expression, state, authorized_imports):
dangerous_patterns = (
"_os",
"os",
"subprocess",
"_subprocess",
"pty",
"system",
"popen",
"spawn",
"shutil",
"glob",
"pathlib",
"io",
"socket",
"compile",
"eval",
"exec",
"multiprocessing",
)
def check_module_authorized(module_name): def check_module_authorized(module_name):
if "*" in authorized_imports: if "*" in authorized_imports:
return True return True
else: else:
module_path = module_name.split(".") module_path = module_name.split(".")
if any([module in dangerous_patterns for module in module_path]):
return False
module_subpaths = [ module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1) ".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
] ]
@ -1060,8 +1124,10 @@ def import_modules(expression, state, authorized_imports):
if isinstance(expression, ast.Import): if isinstance(expression, ast.Import):
for alias in expression.names: for alias in expression.names:
if check_module_authorized(alias.name): if check_module_authorized(alias.name):
module = import_module(alias.name) raw_module = import_module(alias.name)
state[alias.asname or alias.name] = module state[alias.asname or alias.name] = get_safe_module(
raw_module, dangerous_patterns
)
else: else:
raise InterpreterError( raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
@ -1069,11 +1135,13 @@ def import_modules(expression, state, authorized_imports):
return None return None
elif isinstance(expression, ast.ImportFrom): elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module): if check_module_authorized(expression.module):
module = __import__( raw_module = __import__(
expression.module, fromlist=[alias.name for alias in expression.names] expression.module, fromlist=[alias.name for alias in expression.names]
) )
for alias in expression.names: for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name) state[alias.asname or alias.name] = get_safe_module(
getattr(raw_module, alias.name), dangerous_patterns
)
else: else:
raise InterpreterError(f"Import from {expression.module} is not allowed.") raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None return None

View File

@ -920,3 +920,19 @@ shift_intervals
Expected: {expected} Expected: {expected}
Got: {result} Got: {result}
""" """
def test_dangerous_subpackage_access_blocked(self):
# Direct imports with dangerous patterns should fail
code = "import random._os"
with pytest.raises(InterpreterError):
evaluate_python_code(code)
# Import of whitelisted modules should succeed but dangerous submodules should not exist
code = "import random;random._os.system('echo bad command passed')"
with pytest.raises(AttributeError) as e:
evaluate_python_code(code)
assert "module 'random' has no attribute '_os'" in str(e)
code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
with pytest.raises(AttributeError):
evaluate_python_code(code, authorized_imports=["doctest"])