Start fixing state transfer between local and docker executor
This commit is contained in:
parent
8e758fa130
commit
8ed03634b0
|
@ -47,7 +47,7 @@ def visit_webpage(url: str) -> str:
|
|||
|
||||
llm_engine = HfApiEngine(model)
|
||||
|
||||
web_agent = JsonAgent(
|
||||
web_agent = CodeAgent(
|
||||
tools=[DuckDuckGoSearchTool(), visit_webpage],
|
||||
llm_engine=llm_engine,
|
||||
max_iterations=10,
|
||||
|
|
|
@ -38,7 +38,7 @@ from .prompts import (
|
|||
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||
SYSTEM_PROMPT_PLAN,
|
||||
)
|
||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
Tool,
|
||||
|
|
|
@ -23,7 +23,7 @@ from typing import Dict
|
|||
from huggingface_hub import hf_hub_download, list_spaces
|
||||
|
||||
from transformers.utils import is_offline_mode
|
||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import TOOL_CONFIG_FILE, Tool
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,80 @@
|
|||
import sys
|
||||
import json
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
import docker
|
||||
import time
|
||||
import uuid
|
||||
import signal
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
import subprocess
|
||||
import pickle
|
||||
import re
|
||||
from typing import Optional, Dict, Tuple, Set, Any
|
||||
import types
|
||||
from .default_tools import BASE_PYTHON_TOOLS
|
||||
|
||||
class StateManager:
|
||||
def __init__(self, work_dir: Path):
|
||||
self.work_dir = work_dir
|
||||
self.state_file = work_dir / "interpreter_state.pickle"
|
||||
self.imports_file = work_dir / "imports.txt"
|
||||
self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$')
|
||||
self.imports: Set[str] = set()
|
||||
|
||||
def is_import_statement(self, code: str) -> bool:
|
||||
"""Check if a line of code is an import statement."""
|
||||
return bool(self.import_pattern.match(code.strip()))
|
||||
|
||||
def track_imports(self, code: str):
|
||||
"""Track import statements for later use."""
|
||||
for line in code.split('\n'):
|
||||
if self.is_import_statement(line.strip()):
|
||||
self.imports.add(line.strip())
|
||||
|
||||
def save_state(self, locals_dict: Dict[str, Any], executor: str):
|
||||
"""
|
||||
Save the current state of variables and imports.
|
||||
|
||||
Args:
|
||||
locals_dict: Dictionary of local variables
|
||||
executor: 'docker' or 'local' to indicate source
|
||||
"""
|
||||
# Filter out modules, functions, and special variables
|
||||
state_dict = {
|
||||
'variables': {
|
||||
k: v for k, v in locals_dict.items()
|
||||
if not (
|
||||
k.startswith('_')
|
||||
or callable(v)
|
||||
or isinstance(v, type)
|
||||
or isinstance(v, types.ModuleType)
|
||||
)
|
||||
},
|
||||
'imports': list(self.imports),
|
||||
'source': executor
|
||||
}
|
||||
|
||||
with open(self.state_file, 'wb') as f:
|
||||
pickle.dump(state_dict, f)
|
||||
|
||||
def load_state(self, executor: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Load the saved state and handle imports.
|
||||
|
||||
Args:
|
||||
executor: 'docker' or 'local' to indicate destination
|
||||
|
||||
Returns:
|
||||
Dictionary of variables to restore
|
||||
"""
|
||||
if not self.state_file.exists():
|
||||
return {}
|
||||
|
||||
with open(self.state_file, 'rb') as f:
|
||||
state_dict = pickle.load(f)
|
||||
|
||||
# First handle imports
|
||||
for import_stmt in state_dict['imports']:
|
||||
exec(import_stmt, globals())
|
||||
|
||||
return state_dict['variables']
|
||||
|
||||
|
||||
def read_multiplexed_response(socket):
|
||||
"""Read and demultiplex all responses from Docker exec socket"""
|
||||
|
@ -20,6 +87,7 @@ def read_multiplexed_response(socket):
|
|||
responses = response_data.split(b'\x01\x00\x00\x00\x00\x00')
|
||||
|
||||
# The last non-empty chunk should be our JSON response
|
||||
if len(responses) > 0:
|
||||
for chunk in reversed(responses):
|
||||
if chunk and len(chunk.strip()) > 0:
|
||||
try:
|
||||
|
@ -35,7 +103,7 @@ def read_multiplexed_response(socket):
|
|||
i+=1
|
||||
|
||||
|
||||
class DockerInterpreter:
|
||||
class DockerPythonInterpreter:
|
||||
def __init__(self, work_dir: Path = Path(".")):
|
||||
self.client = docker.from_env()
|
||||
self.work_dir = work_dir
|
||||
|
@ -43,6 +111,8 @@ class DockerInterpreter:
|
|||
self.container = None
|
||||
self.exec_id = None
|
||||
self.socket = None
|
||||
self.state_manager = StateManager(work_dir)
|
||||
|
||||
|
||||
def create_interpreter_script(self) -> str:
|
||||
"""Create the interpreter script that will run inside the container"""
|
||||
|
@ -52,7 +122,9 @@ import code
|
|||
import json
|
||||
import traceback
|
||||
import signal
|
||||
import types
|
||||
from threading import Lock
|
||||
import pickle
|
||||
|
||||
class PersistentInterpreter(code.InteractiveInterpreter):
|
||||
def __init__(self):
|
||||
|
@ -67,6 +139,12 @@ class PersistentInterpreter(code.InteractiveInterpreter):
|
|||
def run_command(self, source):
|
||||
with self.lock:
|
||||
self.output_buffer = []
|
||||
pickle_path = self.work_dir / "locals.pickle"
|
||||
if pickle_path.exists():
|
||||
with open(pickle_path, 'rb') as f:
|
||||
locals_dict_update = pickle.load(f)['variables']
|
||||
self.locals_dict.update(locals_dict_update)
|
||||
|
||||
try:
|
||||
more = self.runsource(source)
|
||||
output = ''.join(self.output_buffer)
|
||||
|
@ -78,11 +156,25 @@ class PersistentInterpreter(code.InteractiveInterpreter):
|
|||
output = repr(result) + '\\n'
|
||||
except:
|
||||
pass
|
||||
return json.dumps({'output': output, 'more': more, 'error': None}) + '\\n'
|
||||
output = json.dumps({'output': output, 'more': more, 'error': None}) + '\\n'
|
||||
except KeyboardInterrupt:
|
||||
return json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n'
|
||||
output = json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n'
|
||||
except Exception as e:
|
||||
return json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n'
|
||||
output = json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n'
|
||||
finally:
|
||||
with open('/workspace/locals.pickle', 'wb') as f:
|
||||
filtered_locals = {
|
||||
k: v for k, v in self.locals_dict.items()
|
||||
if not (
|
||||
k.startswith('_')
|
||||
or k in {'pickle', 'f'}
|
||||
or callable(v)
|
||||
or isinstance(v, type)
|
||||
or isinstance(v, types.ModuleType)
|
||||
)
|
||||
}
|
||||
pickle.dump(filtered_locals, f)
|
||||
return output
|
||||
|
||||
def main():
|
||||
interpreter = PersistentInterpreter()
|
||||
|
@ -135,6 +227,8 @@ if __name__ == '__main__':
|
|||
if container_name is None:
|
||||
container_name = f"python-interpreter-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
self.create_interpreter_script()
|
||||
|
||||
# Setup volume mapping
|
||||
volumes = {
|
||||
str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}
|
||||
|
@ -160,7 +254,7 @@ if __name__ == '__main__':
|
|||
)
|
||||
# Install packages in the new container
|
||||
print("Installing packages...")
|
||||
packages = ["pandas", "numpy"] # Add your required packages here
|
||||
packages = ["pandas", "numpy", "pickle5"] # Add your required packages here
|
||||
|
||||
result = self.container.exec_run(
|
||||
f"pip install {' '.join(packages)}",
|
||||
|
@ -192,7 +286,11 @@ if __name__ == '__main__':
|
|||
demux=True
|
||||
)._sock
|
||||
|
||||
def execute(self, code: str) -> Tuple[str, bool]:
|
||||
def _raw_execute(self, code: str) -> Tuple[str, bool]:
|
||||
"""
|
||||
Execute code directly without state management.
|
||||
This is the original execute method functionality.
|
||||
"""
|
||||
if not self.container:
|
||||
raise Exception("Container not started")
|
||||
if not self.socket:
|
||||
|
@ -209,6 +307,30 @@ if __name__ == '__main__':
|
|||
except json.JSONDecodeError:
|
||||
return f"Error: Invalid response from interpreter: {response}", False
|
||||
|
||||
def get_locals_dict(self) -> Dict[str, Any]:
|
||||
"""Get the current locals dictionary from the interpreter by pickling directly from Docker."""
|
||||
pickle_path = self.work_dir / "locals.pickle"
|
||||
if pickle_path.exists():
|
||||
with open(pickle_path, 'rb') as f:
|
||||
try:
|
||||
return pickle.load(f)
|
||||
except Exception as e:
|
||||
print(f"Error loading pickled locals: {e}")
|
||||
return {}
|
||||
return {}
|
||||
|
||||
def execute(self, code: str) -> Tuple[str, bool]:
|
||||
# Track imports before execution
|
||||
self.state_manager.track_imports(code)
|
||||
|
||||
output, more = self._raw_execute(code)
|
||||
|
||||
# Save state after execution
|
||||
self.state_manager.save_state(
|
||||
self.get_locals_dict(),
|
||||
'docker'
|
||||
)
|
||||
return output, more
|
||||
|
||||
def stop(self, remove: bool = False):
|
||||
if self.socket:
|
||||
|
@ -227,33 +349,71 @@ if __name__ == '__main__':
|
|||
print(f"Error stopping container: {e}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
work_dir = Path("interpreter_workspace")
|
||||
interpreter = DockerInterpreter(work_dir)
|
||||
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
|
||||
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
print("\nExiting...")
|
||||
interpreter.stop(remove=True)
|
||||
sys.exit(0)
|
||||
"""Execute code locally with state transfer."""
|
||||
state_manager = StateManager(work_dir)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
# Track imports
|
||||
state_manager.track_imports(code)
|
||||
|
||||
print("Starting Python interpreter in Docker...")
|
||||
interpreter.start("persistent_python_interpreter2")
|
||||
# Load state from Docker if available
|
||||
locals_dict = state_manager.load_state('local')
|
||||
|
||||
snippet = "import pandas as pd"
|
||||
output, more = interpreter.execute(snippet)
|
||||
print("OUTPUT1")
|
||||
# Execute in a new namespace with loaded state
|
||||
namespace = {}
|
||||
namespace.update(locals_dict)
|
||||
|
||||
output = evaluate_python_code(
|
||||
code,
|
||||
tools,
|
||||
{},
|
||||
namespace,
|
||||
LIST_SAFE_MODULES,
|
||||
)
|
||||
|
||||
# Save state for Docker
|
||||
state_manager.save_state(namespace, 'local')
|
||||
return output
|
||||
|
||||
|
||||
def create_tools_regex(tool_names):
|
||||
# Escape any special regex characters in tool names
|
||||
escaped_names = [re.escape(name) for name in tool_names]
|
||||
# Join with | and add word boundaries
|
||||
pattern = r'\b(' + '|'.join(escaped_names) + r')\b'
|
||||
return re.compile(pattern)
|
||||
|
||||
def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
|
||||
"""Execute code with automatic switching between Docker and local."""
|
||||
lines = code.split('\n')
|
||||
current_block = []
|
||||
tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing
|
||||
|
||||
tools = {
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
**tools,
|
||||
}
|
||||
|
||||
for line in lines:
|
||||
if tool_regex.search(line):
|
||||
# Execute accumulated Docker code if any
|
||||
if current_block:
|
||||
output, more = interpreter.execute('\n'.join(current_block))
|
||||
print(output, end='')
|
||||
current_block = []
|
||||
|
||||
snippet = "pd.DataFrame()"
|
||||
output, more = interpreter.execute(snippet)
|
||||
print("OUTPUT2")
|
||||
output = execute_locally(line, work_dir, tools)
|
||||
if output:
|
||||
print(output, end='')
|
||||
else:
|
||||
current_block.append(line)
|
||||
|
||||
# Execute any remaining Docker code
|
||||
if current_block:
|
||||
output, more = interpreter.execute('\n'.join(current_block))
|
||||
print(output, end='')
|
||||
|
||||
|
||||
print("\nStopping interpreter...")
|
||||
interpreter.stop(remove=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
__all__ = ["DockerPythonInterpreter", "execute_code"]
|
Loading…
Reference in New Issue