Start fixing state transfer between the local and Docker executors

This commit is contained in:
Aymeric 2024-12-13 13:58:26 +01:00
parent 8e758fa130
commit 8ed03634b0
4 changed files with 210 additions and 50 deletions

View File

@ -47,7 +47,7 @@ def visit_webpage(url: str) -> str:
llm_engine = HfApiEngine(model)
web_agent = JsonAgent(
web_agent = CodeAgent(
tools=[DuckDuckGoSearchTool(), visit_webpage],
llm_engine=llm_engine,
max_iterations=10,

View File

@ -38,7 +38,7 @@ from .prompts import (
SYSTEM_PROMPT_PLAN_UPDATE,
SYSTEM_PROMPT_PLAN,
)
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool,

View File

@ -23,7 +23,7 @@ from typing import Dict
from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, Tool

View File

@ -1,13 +1,80 @@
import sys
import json
import traceback
from pathlib import Path
import docker
import time
import uuid
import signal
from typing import Optional, Dict, Tuple, Any
import subprocess
import pickle
import re
from typing import Optional, Dict, Tuple, Set, Any
import types
from .default_tools import BASE_PYTHON_TOOLS
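# StateManager coordinates the state handoff between the local and Docker
# executors: it pickles plain variables (skipping callables, classes, modules
# and names starting with "_") together with the import statements seen so
# far into the shared work directory, so whichever executor runs next can
# restore them.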
class StateManager:
def __init__(self, work_dir: Path):
self.work_dir = work_dir
self.state_file = work_dir / "interpreter_state.pickle"
self.imports_file = work_dir / "imports.txt"
self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$')
self.imports: Set[str] = set()
def is_import_statement(self, code: str) -> bool:
"""Check if a line of code is an import statement."""
return bool(self.import_pattern.match(code.strip()))
def track_imports(self, code: str):
"""Track import statements for later use."""
for line in code.split('\n'):
if self.is_import_statement(line.strip()):
self.imports.add(line.strip())
def save_state(self, locals_dict: Dict[str, Any], executor: str):
"""
Save the current state of variables and imports.
Args:
locals_dict: Dictionary of local variables
executor: 'docker' or 'local' to indicate source
"""
# Filter out modules, functions, and special variables
state_dict = {
'variables': {
k: v for k, v in locals_dict.items()
if not (
k.startswith('_')
or callable(v)
or isinstance(v, type)
or isinstance(v, types.ModuleType)
)
},
'imports': list(self.imports),
'source': executor
}
with open(self.state_file, 'wb') as f:
pickle.dump(state_dict, f)
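# Illustrative shape of the pickled payload (values are hypothetical):
# {'variables': {'df': <DataFrame>, 'x': 3}, 'imports': ['import pandas as pd'], 'source': 'docker'}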
def load_state(self, executor: str) -> Dict[str, Any]:
"""
Load the saved state and handle imports.
Args:
executor: 'docker' or 'local' to indicate destination
Returns:
Dictionary of variables to restore
"""
if not self.state_file.exists():
return {}
with open(self.state_file, 'rb') as f:
state_dict = pickle.load(f)
# First handle imports
for import_stmt in state_dict['imports']:
exec(import_stmt, globals())
return state_dict['variables']
def read_multiplexed_response(socket):
"""Read and demultiplex all responses from Docker exec socket"""
@ -20,6 +87,7 @@ def read_multiplexed_response(socket):
responses = response_data.split(b'\x01\x00\x00\x00\x00\x00')
# The last non-empty chunk should be our JSON response
if len(responses) > 0:
for chunk in reversed(responses):
if chunk and len(chunk.strip()) > 0:
try:
@ -35,7 +103,7 @@ def read_multiplexed_response(socket):
i+=1
class DockerInterpreter:
class DockerPythonInterpreter:
def __init__(self, work_dir: Path = Path(".")):
self.client = docker.from_env()
self.work_dir = work_dir
@ -43,6 +111,8 @@ class DockerInterpreter:
self.container = None
self.exec_id = None
self.socket = None
self.state_manager = StateManager(work_dir)
def create_interpreter_script(self) -> str:
"""Create the interpreter script that will run inside the container"""
@ -52,7 +122,9 @@ import code
import json
import traceback
import signal
import types
from threading import Lock
import pickle
class PersistentInterpreter(code.InteractiveInterpreter):
def __init__(self):
@ -67,6 +139,12 @@ class PersistentInterpreter(code.InteractiveInterpreter):
def run_command(self, source):
with self.lock:
self.output_buffer = []
pickle_path = self.work_dir / "locals.pickle"
if pickle_path.exists():
with open(pickle_path, 'rb') as f:
locals_dict_update = pickle.load(f)['variables']
self.locals_dict.update(locals_dict_update)
try:
more = self.runsource(source)
output = ''.join(self.output_buffer)
@ -78,11 +156,25 @@ class PersistentInterpreter(code.InteractiveInterpreter):
output = repr(result) + '\\n'
except:
pass
return json.dumps({'output': output, 'more': more, 'error': None}) + '\\n'
output = json.dumps({'output': output, 'more': more, 'error': None}) + '\\n'
except KeyboardInterrupt:
return json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n'
output = json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n'
except Exception as e:
return json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n'
output = json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n'
finally:
with open('/workspace/locals.pickle', 'wb') as f:
filtered_locals = {
k: v for k, v in self.locals_dict.items()
if not (
k.startswith('_')
or k in {'pickle', 'f'}
or callable(v)
or isinstance(v, type)
or isinstance(v, types.ModuleType)
)
}
pickle.dump(filtered_locals, f)
return output
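# Note: the finally block above always dumps a filtered snapshot of the
# interpreter locals to /workspace/locals.pickle; /workspace is the bind
# mount of the host work_dir, which is how the host-side code can read the
# container's variables after each command.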
def main():
interpreter = PersistentInterpreter()
@ -135,6 +227,8 @@ if __name__ == '__main__':
if container_name is None:
container_name = f"python-interpreter-{uuid.uuid4().hex[:8]}"
self.create_interpreter_script()
# Setup volume mapping
volumes = {
str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}
@ -160,7 +254,7 @@ if __name__ == '__main__':
)
# Install packages in the new container
print("Installing packages...")
packages = ["pandas", "numpy"] # Add your required packages here
packages = ["pandas", "numpy", "pickle5"] # Add your required packages here
result = self.container.exec_run(
f"pip install {' '.join(packages)}",
@ -192,7 +286,11 @@ if __name__ == '__main__':
demux=True
)._sock
def execute(self, code: str) -> Tuple[str, bool]:
def _raw_execute(self, code: str) -> Tuple[str, bool]:
"""
Execute code directly without state management.
This is the original execute method functionality.
"""
if not self.container:
raise Exception("Container not started")
if not self.socket:
@ -209,6 +307,30 @@ if __name__ == '__main__':
except json.JSONDecodeError:
return f"Error: Invalid response from interpreter: {response}", False
def get_locals_dict(self) -> Dict[str, Any]:
"""Get the current locals dictionary from the interpreter by pickling directly from Docker."""
pickle_path = self.work_dir / "locals.pickle"
if pickle_path.exists():
with open(pickle_path, 'rb') as f:
try:
return pickle.load(f)
except Exception as e:
print(f"Error loading pickled locals: {e}")
return {}
return {}
def execute(self, code: str) -> Tuple[str, bool]:
# Track imports before execution
self.state_manager.track_imports(code)
output, more = self._raw_execute(code)
# Save state after execution
self.state_manager.save_state(
self.get_locals_dict(),
'docker'
)
return output, more
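# execute() is now a thin wrapper around _raw_execute(): it records import
# statements, runs the code in the container, then snapshots the container's
# locals (read back from locals.pickle) into the StateManager pickle so the
# local executor can pick the state up later.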
def stop(self, remove: bool = False):
if self.socket:
@ -227,33 +349,71 @@ if __name__ == '__main__':
print(f"Error stopping container: {e}")
raise
def main():
work_dir = Path("interpreter_workspace")
interpreter = DockerInterpreter(work_dir)
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES
def signal_handler(signum, frame):
print("\nExiting...")
interpreter.stop(remove=True)
sys.exit(0)
"""Execute code locally with state transfer."""
state_manager = StateManager(work_dir)
signal.signal(signal.SIGINT, signal_handler)
# Track imports
state_manager.track_imports(code)
print("Starting Python interpreter in Docker...")
interpreter.start("persistent_python_interpreter2")
# Load state from Docker if available
locals_dict = state_manager.load_state('local')
snippet = "import pandas as pd"
output, more = interpreter.execute(snippet)
print("OUTPUT1")
# Execute in a new namespace with loaded state
namespace = {}
namespace.update(locals_dict)
output = evaluate_python_code(
code,
tools,
{},
namespace,
LIST_SAFE_MODULES,
)
# Save state for Docker
state_manager.save_state(namespace, 'local')
return output
def create_tools_regex(tool_names):
# Escape any special regex characters in tool names
escaped_names = [re.escape(name) for name in tool_names]
# Join with | and add word boundaries
pattern = r'\b(' + '|'.join(escaped_names) + r')\b'
return re.compile(pattern)
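# For example (hypothetical tool names), create_tools_regex(["visit_webpage", "print"])
# compiles r'\b(visit_webpage|print)\b', so "print(visit_webpage(url))" matches
# while "df.describe()" does not.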
def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
"""Execute code with automatic switching between Docker and local."""
lines = code.split('\n')
current_block = []
tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing
tools = {
**BASE_PYTHON_TOOLS.copy(),
**tools,
}
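# Route the code line by line: lines that mention a tool (or print) run locally
# through execute_locally, everything else accumulates and is sent to the
# Docker interpreter in contiguous blocks.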
for line in lines:
if tool_regex.search(line):
# Execute accumulated Docker code if any
if current_block:
output, more = interpreter.execute('\n'.join(current_block))
print(output, end='')
current_block = []
snippet = "pd.DataFrame()"
output, more = interpreter.execute(snippet)
print("OUTPUT2")
output = execute_locally(line, work_dir, tools)
if output:
print(output, end='')
else:
current_block.append(line)
# Execute any remaining Docker code
if current_block:
output, more = interpreter.execute('\n'.join(current_block))
print(output, end='')
print("\nStopping interpreter...")
interpreter.stop(remove=True)
if __name__ == '__main__':
main()
__all__ = ["DockerPythonInterpreter", "execute_code"]
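
A minimal usage sketch of the new entry points, assuming the module is importable; the import path, the visit_webpage stand-in tool, and the snippet are hypothetical and not part of this commit:

from pathlib import Path
# hypothetical import path; the changed file's name is not shown in this view
from agents.docker_python_executor import DockerPythonInterpreter, execute_code

def visit_webpage(url: str) -> str:
    # stand-in tool for the example
    return f"<contents of {url}>"

work_dir = Path("interpreter_workspace")
interpreter = DockerPythonInterpreter(work_dir)
interpreter.start("persistent_python_interpreter")
try:
    # Tool-bearing lines run locally, the rest runs in the container;
    # variables are carried across via the pickled state in work_dir.
    execute_code(
        "import pandas as pd\n"
        "df = pd.DataFrame({'a': [1, 2]})\n"
        "print(visit_webpage('https://example.com'))",
        {"visit_webpage": visit_webpage},
        work_dir,
        interpreter,
    )
finally:
    interpreter.stop(remove=True)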