From c255c1ff8442a6da0c83282c59a5490390935c98 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:40:49 +0100 Subject: [PATCH] Fix subpackage import vulnerability (#238) * Fix subpackage import vulnerability --- src/smolagents/gradio_ui.py | 12 +--- src/smolagents/local_python_executor.py | 76 +++++++++++++++++++++++-- tests/test_python_interpreter.py | 16 ++++++ 3 files changed, 91 insertions(+), 13 deletions(-) diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index 884a75f..e04056d 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -120,21 +120,15 @@ class GradioUI: """ if file is None: - return gr.Textbox( - "No file uploaded", visible=True - ), file_uploads_log + return gr.Textbox("No file uploaded", visible=True), file_uploads_log try: mime_type, _ = mimetypes.guess_type(file.name) except Exception as e: - return gr.Textbox( - f"Error: {e}", visible=True - ), file_uploads_log + return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log if mime_type not in allowed_file_types: - return gr.Textbox( - "File type disallowed", visible=True - ), file_uploads_log + return gr.Textbox("File type disallowed", visible=True), file_uploads_log # Sanitize file name original_name = os.path.basename(file.name) diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index 1b9b35d..069d526 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -21,6 +21,7 @@ import math import re from collections.abc import Mapping from importlib import import_module +from types import ModuleType from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np @@ -1046,12 +1047,75 @@ def evaluate_with( context.__exit__(None, None, None) +def get_safe_module(unsafe_module, dangerous_patterns, visited=None): + """Creates a safe copy of a module or returns the original if it's a function""" + # If it's a function or non-module object, return it directly + if not isinstance(unsafe_module, ModuleType): + return unsafe_module + + # Handle circular references: Initialize visited set for the first call + if visited is None: + visited = set() + + module_id = id(unsafe_module) + if module_id in visited: + return unsafe_module # Return original for circular refs + + visited.add(module_id) + + # Create new module for actual modules + safe_module = ModuleType(unsafe_module.__name__) + + # Copy all attributes by reference, recursively checking modules + for attr_name in dir(unsafe_module): + # Skip dangerous patterns at any level + if any( + pattern in f"{unsafe_module.__name__}.{attr_name}" + for pattern in dangerous_patterns + ): + continue + + attr_value = getattr(unsafe_module, attr_name) + + # Recursively process nested modules, passing visited set + if isinstance(attr_value, ModuleType): + attr_value = get_safe_module( + attr_value, dangerous_patterns, visited=visited + ) + + setattr(safe_module, attr_name, attr_value) + + return safe_module + + def import_modules(expression, state, authorized_imports): + dangerous_patterns = ( + "_os", + "os", + "subprocess", + "_subprocess", + "pty", + "system", + "popen", + "spawn", + "shutil", + "glob", + "pathlib", + "io", + "socket", + "compile", + "eval", + "exec", + "multiprocessing", + ) + def check_module_authorized(module_name): if "*" in authorized_imports: return True else: module_path = module_name.split(".") + if any([module in dangerous_patterns for module in module_path]): + return False module_subpaths = [ ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) ] @@ -1060,8 +1124,10 @@ def import_modules(expression, state, authorized_imports): if isinstance(expression, ast.Import): for alias in expression.names: if check_module_authorized(alias.name): - module = import_module(alias.name) - state[alias.asname or alias.name] = module + raw_module = import_module(alias.name) + state[alias.asname or alias.name] = get_safe_module( + raw_module, dangerous_patterns + ) else: raise InterpreterError( f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" @@ -1069,11 +1135,13 @@ def import_modules(expression, state, authorized_imports): return None elif isinstance(expression, ast.ImportFrom): if check_module_authorized(expression.module): - module = __import__( + raw_module = __import__( expression.module, fromlist=[alias.name for alias in expression.names] ) for alias in expression.names: - state[alias.asname or alias.name] = getattr(module, alias.name) + state[alias.asname or alias.name] = get_safe_module( + getattr(raw_module, alias.name), dangerous_patterns + ) else: raise InterpreterError(f"Import from {expression.module} is not allowed.") return None diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 6e5b3c2..4976c56 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -920,3 +920,19 @@ shift_intervals Expected: {expected} Got: {result} """ + + def test_dangerous_subpackage_access_blocked(self): + # Direct imports with dangerous patterns should fail + code = "import random._os" + with pytest.raises(InterpreterError): + evaluate_python_code(code) + + # Import of whitelisted modules should succeed but dangerous submodules should not exist + code = "import random;random._os.system('echo bad command passed')" + with pytest.raises(AttributeError) as e: + evaluate_python_code(code) + assert "module 'random' has no attribute '_os'" in str(e) + + code = "import doctest;doctest.inspect.os.system('echo bad command passed')" + with pytest.raises(AttributeError): + evaluate_python_code(code, authorized_imports=["doctest"])