diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index a7b2fb3..e2df293 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -114,6 +114,26 @@ BASE_PYTHON_TOOLS = { "complex": complex, } +DANGEROUS_PATTERNS = ( + "_os", + "os", + "subprocess", + "_subprocess", + "pty", + "system", + "popen", + "spawn", + "shutil", + "sys", + "pathlib", + "io", + "socket", + "compile", + "eval", + "exec", + "multiprocessing", +) + class PrintContainer: def __init__(self): @@ -954,7 +974,7 @@ def evaluate_with( context.__exit__(None, None, None) -def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=None): +def get_safe_module(raw_module, authorized_imports, visited=None): """Creates a safe copy of a module or returns the original if it's a function""" # If it's a function or non-module object, return it directly if not isinstance(raw_module, ModuleType): @@ -978,8 +998,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited= # Skip dangerous patterns at any level if any( pattern in raw_module.__name__.split(".") + [attr_name] - and not check_module_authorized(pattern, authorized_imports, dangerous_patterns) - for pattern in dangerous_patterns + and not check_module_authorized(pattern, authorized_imports) + for pattern in DANGEROUS_PATTERNS ): logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}") continue @@ -994,19 +1014,19 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited= continue # Recursively process nested modules, passing visited set if isinstance(attr_value, ModuleType): - attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) + attr_value = get_safe_module(attr_value, authorized_imports, visited=visited) setattr(safe_module, attr_name, attr_value) return safe_module -def check_module_authorized(module_name, authorized_imports, dangerous_patterns): +def check_module_authorized(module_name, authorized_imports): if "*" in authorized_imports: return True else: module_path = module_name.split(".") - if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]): + if any([module in DANGEROUS_PATTERNS and module not in authorized_imports for module in module_path]): return False # ["A", "B", "C"] -> ["A", "A.B", "A.B.C"] module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] @@ -1014,40 +1034,20 @@ def check_module_authorized(module_name, authorized_imports, dangerous_patterns) def import_modules(expression, state, authorized_imports): - dangerous_patterns = ( - "_os", - "os", - "subprocess", - "_subprocess", - "pty", - "system", - "popen", - "spawn", - "shutil", - "sys", - "pathlib", - "io", - "socket", - "compile", - "eval", - "exec", - "multiprocessing", - ) - if isinstance(expression, ast.Import): for alias in expression.names: - if check_module_authorized(alias.name, authorized_imports, dangerous_patterns): + if check_module_authorized(alias.name, authorized_imports): raw_module = import_module(alias.name) - state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports) + state[alias.asname or alias.name] = get_safe_module(raw_module, authorized_imports) else: raise InterpreterError( f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" ) return None elif isinstance(expression, ast.ImportFrom): - if check_module_authorized(expression.module, authorized_imports, dangerous_patterns): + if check_module_authorized(expression.module, authorized_imports): raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) - module = get_safe_module(raw_module, dangerous_patterns, authorized_imports) + module = get_safe_module(raw_module, authorized_imports) if expression.names[0].name == "*": # Handle "from module import *" if hasattr(module, "__all__"): # If module has __all__, import only those names for name in module.__all__: diff --git a/tests/test_local_python_executor.py b/tests/test_local_python_executor.py index f7ecc91..7777631 100644 --- a/tests/test_local_python_executor.py +++ b/tests/test_local_python_executor.py @@ -1302,7 +1302,7 @@ def test_get_safe_module_handle_lazy_imports(): return super().__dir__() + ["lazy_attribute"] fake_module = FakeModule("fake_module") - safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set()) + safe_module = get_safe_module(fake_module, authorized_imports=set()) assert not hasattr(safe_module, "lazy_attribute") assert getattr(safe_module, "non_lazy_attribute") == "ok" @@ -1377,23 +1377,4 @@ class TestPrintContainer: ], ) def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool): - dangerous_patterns = ( - "_os", - "os", - "subprocess", - "_subprocess", - "pty", - "system", - "popen", - "spawn", - "shutil", - "sys", - "pathlib", - "io", - "socket", - "compile", - "eval", - "exec", - "multiprocessing", - ) - assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected + assert check_module_authorized(module, authorized_imports) == expected