From ad5f84b101ad5adb53b53ee2ee84209a664beb29 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Tue, 28 Jan 2025 11:00:43 +0100 Subject: [PATCH] Fix blocking of os in authorized imports (#386) --- src/smolagents/local_python_executor.py | 30 ++++++++++++++----------- tests/test_python_interpreter.py | 4 ++++ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index a4f046d..6f3e1eb 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -934,36 +934,39 @@ def evaluate_with( context.__exit__(None, None, None) -def get_safe_module(unsafe_module, dangerous_patterns, visited=None): +def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=None): """Creates a safe copy of a module or returns the original if it's a function""" # If it's a function or non-module object, return it directly - if not isinstance(unsafe_module, ModuleType): - return unsafe_module + if not isinstance(raw_module, ModuleType): + return raw_module # Handle circular references: Initialize visited set for the first call if visited is None: visited = set() - module_id = id(unsafe_module) + module_id = id(raw_module) if module_id in visited: - return unsafe_module # Return original for circular refs + return raw_module # Return original for circular refs visited.add(module_id) # Create new module for actual modules - safe_module = ModuleType(unsafe_module.__name__) + safe_module = ModuleType(raw_module.__name__) # Copy all attributes by reference, recursively checking modules - for attr_name in dir(unsafe_module): + for attr_name in dir(raw_module): # Skip dangerous patterns at any level - if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns): + if any( + pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports + for pattern in dangerous_patterns + ): continue - attr_value = getattr(unsafe_module, attr_name) + attr_value = getattr(raw_module, attr_name) # Recursively process nested modules, passing visited set if isinstance(attr_value, ModuleType): - attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited) + attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) setattr(safe_module, attr_name, attr_value) @@ -996,7 +999,7 @@ def import_modules(expression, state, authorized_imports): return True else: module_path = module_name.split(".") - if any([module in dangerous_patterns for module in module_path]): + if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]): return False module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] return any(subpath in authorized_imports for subpath in module_subpaths) @@ -1005,7 +1008,7 @@ def import_modules(expression, state, authorized_imports): for alias in expression.names: if check_module_authorized(alias.name): raw_module = import_module(alias.name) - state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns) + state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports) else: raise InterpreterError( f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" @@ -1013,7 +1016,8 @@ def import_modules(expression, state, authorized_imports): return None elif isinstance(expression, ast.ImportFrom): if check_module_authorized(expression.module): - module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) + raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) + module = get_safe_module(raw_module, dangerous_patterns, authorized_imports) if expression.names[0].name == "*": # Handle "from module import *" if hasattr(module, "__all__"): # If module has __all__, import only those names for name in module.__all__: diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 8aec8fe..a44fbea 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -950,6 +950,10 @@ texec(tcompile("1 + 1", "no filename", "exec")) dangerous_code, static_tools={"tcompile": compile, "teval": eval, "texec": exec} | BASE_PYTHON_TOOLS ) + def test_can_import_os_if_explicitly_authorized(self): + dangerous_code = "import os; os.listdir('./')" + evaluate_python_code(dangerous_code, authorized_imports=["os"]) + @pytest.mark.parametrize( "code, expected_result",