Fix subpackage import vulnerability (#238)

* Fix subpackage import vulnerability
This commit is contained in:
Aymeric Roucher 2025-01-17 11:40:49 +01:00 committed by GitHub
parent d5c2ef48e7
commit c255c1ff84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 91 additions and 13 deletions

View File

@ -120,21 +120,15 @@ class GradioUI:
"""
if file is None:
return gr.Textbox(
"No file uploaded", visible=True
), file_uploads_log
return gr.Textbox("No file uploaded", visible=True), file_uploads_log
try:
mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e:
return gr.Textbox(
f"Error: {e}", visible=True
), file_uploads_log
return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
if mime_type not in allowed_file_types:
return gr.Textbox(
"File type disallowed", visible=True
), file_uploads_log
return gr.Textbox("File type disallowed", visible=True), file_uploads_log
# Sanitize file name
original_name = os.path.basename(file.name)

View File

@ -21,6 +21,7 @@ import math
import re
from collections.abc import Mapping
from importlib import import_module
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
@ -1046,12 +1047,75 @@ def evaluate_with(
context.__exit__(None, None, None)
def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
    """Return a sanitized copy of a module, or the object itself if it is not a module.

    A fresh ``ModuleType`` with the same ``__name__`` is created and populated with
    every attribute of ``unsafe_module`` except those whose dotted path
    ``"<module name>.<attribute name>"`` contains one of ``dangerous_patterns``.
    Attributes that are themselves modules are sanitized recursively.

    Args:
        unsafe_module: Object to sanitize. Anything that is not a ``ModuleType``
            instance (functions, classes, values) is returned unchanged.
        dangerous_patterns: Collection of strings; an attribute is dropped when any
            of them occurs as a substring of its dotted path.
        visited: Internal memo mapping ``id(module)`` to the sanitized copy already
            built for it. Callers normally leave this as ``None``; any value that
            is not a dict is replaced by a fresh memo.

    Returns:
        The sanitized module copy, or ``unsafe_module`` itself for non-modules.
    """
    # If it's a function or non-module object, return it directly
    if not isinstance(unsafe_module, ModuleType):
        return unsafe_module

    # First call (or a legacy non-dict value): start a fresh memo of sanitized copies
    if not isinstance(visited, dict):
        visited = {}

    module_id = id(unsafe_module)
    if module_id in visited:
        # Circular or repeated reference: hand back the sanitized copy, never the
        # original module, otherwise the filtered attributes would leak back in.
        return visited[module_id]

    module_name = unsafe_module.__name__
    safe_module = ModuleType(module_name)
    # Register before recursing so cycles resolve to this (still filling) copy
    visited[module_id] = safe_module

    # Copy all attributes by reference, recursively sanitizing nested modules
    for attr_name in dir(unsafe_module):
        qualified_name = f"{module_name}.{attr_name}"
        # NOTE: this is a substring test, so a pattern such as "os" also drops
        # attributes that merely contain it (e.g. an attribute named "cos").
        if any(pattern in qualified_name for pattern in dangerous_patterns):
            continue

        try:
            attr_value = getattr(unsafe_module, attr_name)
        except (AttributeError, ImportError):
            # dir() may advertise names that fail to resolve (lazy/dynamic
            # modules); skip them instead of aborting the whole copy.
            continue

        if isinstance(attr_value, ModuleType):
            attr_value = get_safe_module(
                attr_value, dangerous_patterns, visited=visited
            )

        setattr(safe_module, attr_name, attr_value)

    return safe_module
def import_modules(expression, state, authorized_imports):
dangerous_patterns = (
"_os",
"os",
"subprocess",
"_subprocess",
"pty",
"system",
"popen",
"spawn",
"shutil",
"glob",
"pathlib",
"io",
"socket",
"compile",
"eval",
"exec",
"multiprocessing",
)
def check_module_authorized(module_name):
if "*" in authorized_imports:
return True
else:
module_path = module_name.split(".")
if any([module in dangerous_patterns for module in module_path]):
return False
module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
@ -1060,8 +1124,10 @@ def import_modules(expression, state, authorized_imports):
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
module = import_module(alias.name)
state[alias.asname or alias.name] = module
raw_module = import_module(alias.name)
state[alias.asname or alias.name] = get_safe_module(
raw_module, dangerous_patterns
)
else:
raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
@ -1069,11 +1135,13 @@ def import_modules(expression, state, authorized_imports):
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
module = __import__(
raw_module = __import__(
expression.module, fromlist=[alias.name for alias in expression.names]
)
for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name)
state[alias.asname or alias.name] = get_safe_module(
getattr(raw_module, alias.name), dangerous_patterns
)
else:
raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None

View File

@ -920,3 +920,19 @@ shift_intervals
Expected: {expected}
Got: {result}
"""
def test_dangerous_subpackage_access_blocked(self):
    """Dangerous submodules must be neither importable nor reachable as attributes."""
    # Direct imports with dangerous patterns should fail
    code = "import random._os"
    with pytest.raises(InterpreterError):
        evaluate_python_code(code)

    # Import of whitelisted modules should succeed but dangerous submodules should not exist
    code = "import random;random._os.system('echo bad command passed')"
    with pytest.raises(AttributeError) as e:
        evaluate_python_code(code)
    # Check the raised exception's own message (e.value), not the string form
    # of pytest's ExceptionInfo wrapper.
    assert "module 'random' has no attribute '_os'" in str(e.value)

    # Dangerous modules reachable through a nested module attribute are blocked too
    code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
    with pytest.raises(AttributeError):
        evaluate_python_code(code, authorized_imports=["doctest"])