Fix subpackage import vulnerability (#238)
* Fix subpackage import vulnerability
This commit is contained in:
parent
d5c2ef48e7
commit
c255c1ff84
|
@ -120,21 +120,15 @@ class GradioUI:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if file is None:
|
if file is None:
|
||||||
return gr.Textbox(
|
return gr.Textbox("No file uploaded", visible=True), file_uploads_log
|
||||||
"No file uploaded", visible=True
|
|
||||||
), file_uploads_log
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mime_type, _ = mimetypes.guess_type(file.name)
|
mime_type, _ = mimetypes.guess_type(file.name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return gr.Textbox(
|
return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
|
||||||
f"Error: {e}", visible=True
|
|
||||||
), file_uploads_log
|
|
||||||
|
|
||||||
if mime_type not in allowed_file_types:
|
if mime_type not in allowed_file_types:
|
||||||
return gr.Textbox(
|
return gr.Textbox("File type disallowed", visible=True), file_uploads_log
|
||||||
"File type disallowed", visible=True
|
|
||||||
), file_uploads_log
|
|
||||||
|
|
||||||
# Sanitize file name
|
# Sanitize file name
|
||||||
original_name = os.path.basename(file.name)
|
original_name = os.path.basename(file.name)
|
||||||
|
|
|
@ -21,6 +21,7 @@ import math
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from types import ModuleType
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -1046,12 +1047,75 @@ def evaluate_with(
|
||||||
context.__exit__(None, None, None)
|
context.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
|
||||||
|
"""Creates a safe copy of a module or returns the original if it's a function"""
|
||||||
|
# If it's a function or non-module object, return it directly
|
||||||
|
if not isinstance(unsafe_module, ModuleType):
|
||||||
|
return unsafe_module
|
||||||
|
|
||||||
|
# Handle circular references: Initialize visited set for the first call
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
|
||||||
|
module_id = id(unsafe_module)
|
||||||
|
if module_id in visited:
|
||||||
|
return unsafe_module # Return original for circular refs
|
||||||
|
|
||||||
|
visited.add(module_id)
|
||||||
|
|
||||||
|
# Create new module for actual modules
|
||||||
|
safe_module = ModuleType(unsafe_module.__name__)
|
||||||
|
|
||||||
|
# Copy all attributes by reference, recursively checking modules
|
||||||
|
for attr_name in dir(unsafe_module):
|
||||||
|
# Skip dangerous patterns at any level
|
||||||
|
if any(
|
||||||
|
pattern in f"{unsafe_module.__name__}.{attr_name}"
|
||||||
|
for pattern in dangerous_patterns
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
attr_value = getattr(unsafe_module, attr_name)
|
||||||
|
|
||||||
|
# Recursively process nested modules, passing visited set
|
||||||
|
if isinstance(attr_value, ModuleType):
|
||||||
|
attr_value = get_safe_module(
|
||||||
|
attr_value, dangerous_patterns, visited=visited
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(safe_module, attr_name, attr_value)
|
||||||
|
|
||||||
|
return safe_module
|
||||||
|
|
||||||
|
|
||||||
def import_modules(expression, state, authorized_imports):
|
def import_modules(expression, state, authorized_imports):
|
||||||
|
dangerous_patterns = (
|
||||||
|
"_os",
|
||||||
|
"os",
|
||||||
|
"subprocess",
|
||||||
|
"_subprocess",
|
||||||
|
"pty",
|
||||||
|
"system",
|
||||||
|
"popen",
|
||||||
|
"spawn",
|
||||||
|
"shutil",
|
||||||
|
"glob",
|
||||||
|
"pathlib",
|
||||||
|
"io",
|
||||||
|
"socket",
|
||||||
|
"compile",
|
||||||
|
"eval",
|
||||||
|
"exec",
|
||||||
|
"multiprocessing",
|
||||||
|
)
|
||||||
|
|
||||||
def check_module_authorized(module_name):
|
def check_module_authorized(module_name):
|
||||||
if "*" in authorized_imports:
|
if "*" in authorized_imports:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
module_path = module_name.split(".")
|
module_path = module_name.split(".")
|
||||||
|
if any([module in dangerous_patterns for module in module_path]):
|
||||||
|
return False
|
||||||
module_subpaths = [
|
module_subpaths = [
|
||||||
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
|
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
|
||||||
]
|
]
|
||||||
|
@ -1060,8 +1124,10 @@ def import_modules(expression, state, authorized_imports):
|
||||||
if isinstance(expression, ast.Import):
|
if isinstance(expression, ast.Import):
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
if check_module_authorized(alias.name):
|
if check_module_authorized(alias.name):
|
||||||
module = import_module(alias.name)
|
raw_module = import_module(alias.name)
|
||||||
state[alias.asname or alias.name] = module
|
state[alias.asname or alias.name] = get_safe_module(
|
||||||
|
raw_module, dangerous_patterns
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(
|
||||||
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
||||||
|
@ -1069,11 +1135,13 @@ def import_modules(expression, state, authorized_imports):
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
if check_module_authorized(expression.module):
|
if check_module_authorized(expression.module):
|
||||||
module = __import__(
|
raw_module = __import__(
|
||||||
expression.module, fromlist=[alias.name for alias in expression.names]
|
expression.module, fromlist=[alias.name for alias in expression.names]
|
||||||
)
|
)
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
state[alias.asname or alias.name] = get_safe_module(
|
||||||
|
getattr(raw_module, alias.name), dangerous_patterns
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -920,3 +920,19 @@ shift_intervals
|
||||||
Expected: {expected}
|
Expected: {expected}
|
||||||
Got: {result}
|
Got: {result}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def test_dangerous_subpackage_access_blocked(self):
|
||||||
|
# Direct imports with dangerous patterns should fail
|
||||||
|
code = "import random._os"
|
||||||
|
with pytest.raises(InterpreterError):
|
||||||
|
evaluate_python_code(code)
|
||||||
|
|
||||||
|
# Import of whitelisted modules should succeed but dangerous submodules should not exist
|
||||||
|
code = "import random;random._os.system('echo bad command passed')"
|
||||||
|
with pytest.raises(AttributeError) as e:
|
||||||
|
evaluate_python_code(code)
|
||||||
|
assert "module 'random' has no attribute '_os'" in str(e)
|
||||||
|
|
||||||
|
code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
evaluate_python_code(code, authorized_imports=["doctest"])
|
||||||
|
|
Loading…
Reference in New Issue