Fix: source code inspection in interactive shells (#281)
* Support interactive shells for source inspection * Add tool save e2e tests
This commit is contained in:
parent
5d6502ae1d
commit
6196958deb
|
@ -60,6 +60,7 @@ all = [
|
||||||
"smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]",
|
"smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]",
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
|
"ipython>=8.31.0", # for interactive environment tests
|
||||||
"pytest>=8.1.0",
|
"pytest>=8.1.0",
|
||||||
"python-dotenv>=1.0.1", # For test_all_docs
|
"python-dotenv>=1.0.1", # For test_all_docs
|
||||||
"smolagents[all]",
|
"smolagents[all]",
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
import ast
|
import ast
|
||||||
import builtins
|
import builtins
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
from .utils import BASE_BUILTIN_MODULES
|
from .utils import BASE_BUILTIN_MODULES, get_source
|
||||||
|
|
||||||
|
|
||||||
_BUILTIN_NAMES = set(vars(builtins))
|
_BUILTIN_NAMES = set(vars(builtins))
|
||||||
|
@ -132,7 +131,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
source = textwrap.dedent(inspect.getsource(cls))
|
source = get_source(cls)
|
||||||
|
|
||||||
tree = ast.parse(source)
|
tree = ast.parse(source)
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ from ._function_type_hints_utils import (
|
||||||
)
|
)
|
||||||
from .tool_validation import MethodChecker, validate_tool_attributes
|
from .tool_validation import MethodChecker, validate_tool_attributes
|
||||||
from .types import handle_agent_input_types, handle_agent_output_types
|
from .types import handle_agent_input_types, handle_agent_output_types
|
||||||
from .utils import _is_package_available, _is_pillow_available, instance_to_source
|
from .utils import _is_package_available, _is_pillow_available, get_source, instance_to_source
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -220,8 +220,8 @@ class Tool:
|
||||||
# Save tool file
|
# Save tool file
|
||||||
if type(self).__name__ == "SimpleTool":
|
if type(self).__name__ == "SimpleTool":
|
||||||
# Check that imports are self-contained
|
# Check that imports are self-contained
|
||||||
source_code = inspect.getsource(self.forward).replace("@tool", "")
|
source_code = get_source(self.forward).replace("@tool", "")
|
||||||
forward_node = ast.parse(textwrap.dedent(source_code))
|
forward_node = ast.parse(source_code)
|
||||||
# If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
|
# If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
|
||||||
method_checker = MethodChecker(set())
|
method_checker = MethodChecker(set())
|
||||||
method_checker.visit(forward_node)
|
method_checker.visit(forward_node)
|
||||||
|
@ -229,7 +229,7 @@ class Tool:
|
||||||
if len(method_checker.errors) > 0:
|
if len(method_checker.errors) > 0:
|
||||||
raise (ValueError("\n".join(method_checker.errors)))
|
raise (ValueError("\n".join(method_checker.errors)))
|
||||||
|
|
||||||
forward_source_code = inspect.getsource(self.forward)
|
forward_source_code = get_source(self.forward)
|
||||||
tool_code = textwrap.dedent(
|
tool_code = textwrap.dedent(
|
||||||
f"""
|
f"""
|
||||||
from smolagents import Tool
|
from smolagents import Tool
|
||||||
|
|
|
@ -20,6 +20,7 @@ import importlib.util
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import textwrap
|
||||||
import types
|
import types
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
@ -221,7 +222,7 @@ def get_method_source(method):
|
||||||
"""Get source code for a method, including bound methods."""
|
"""Get source code for a method, including bound methods."""
|
||||||
if isinstance(method, types.MethodType):
|
if isinstance(method, types.MethodType):
|
||||||
method = method.__func__
|
method = method.__func__
|
||||||
return inspect.getsource(method).strip()
|
return get_source(method)
|
||||||
|
|
||||||
|
|
||||||
def is_same_method(method1, method2):
|
def is_same_method(method1, method2):
|
||||||
|
@ -295,7 +296,7 @@ def instance_to_source(instance, base_cls=None):
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, method in methods.items():
|
for name, method in methods.items():
|
||||||
method_source = inspect.getsource(method)
|
method_source = get_source(method)
|
||||||
# Clean up the indentation
|
# Clean up the indentation
|
||||||
method_lines = method_source.split("\n")
|
method_lines = method_source.split("\n")
|
||||||
first_line = method_lines[0]
|
first_line = method_lines[0]
|
||||||
|
@ -330,4 +331,56 @@ def instance_to_source(instance, base_cls=None):
|
||||||
return "\n".join(final_lines)
|
return "\n".join(final_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def get_source(obj) -> str:
|
||||||
|
"""Get the source code of a class or callable object (e.g.: function, method).
|
||||||
|
First attempts to get the source code using `inspect.getsource`.
|
||||||
|
In a dynamic environment (e.g.: Jupyter, IPython), if this fails,
|
||||||
|
falls back to retrieving the source code from the current interactive shell session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: A class or callable object (e.g.: function, method)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The source code of the object, dedented and stripped
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If object is not a class or callable
|
||||||
|
OSError: If source code cannot be retrieved from any source
|
||||||
|
ValueError: If source cannot be found in IPython history
|
||||||
|
|
||||||
|
Note:
|
||||||
|
TODO: handle Python standard REPL
|
||||||
|
"""
|
||||||
|
if not (isinstance(obj, type) or callable(obj)):
|
||||||
|
raise TypeError(f"Expected class or callable, got {type(obj)}")
|
||||||
|
|
||||||
|
inspect_error = None
|
||||||
|
try:
|
||||||
|
return textwrap.dedent(inspect.getsource(obj)).strip()
|
||||||
|
except OSError as e:
|
||||||
|
# let's keep track of the exception to raise it if all further methods fail
|
||||||
|
inspect_error = e
|
||||||
|
try:
|
||||||
|
import IPython
|
||||||
|
|
||||||
|
shell = IPython.get_ipython()
|
||||||
|
if not shell:
|
||||||
|
raise ImportError("No active IPython shell found")
|
||||||
|
all_cells = "\n".join(shell.user_ns.get("In", [])).strip()
|
||||||
|
if not all_cells:
|
||||||
|
raise ValueError("No code cells found in IPython session")
|
||||||
|
|
||||||
|
tree = ast.parse(all_cells)
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__:
|
||||||
|
return textwrap.dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip()
|
||||||
|
raise ValueError(f"Could not find source code for {obj.__name__} in IPython history")
|
||||||
|
except ImportError:
|
||||||
|
# IPython is not available, let's just raise the original inspect error
|
||||||
|
raise inspect_error
|
||||||
|
except ValueError as e:
|
||||||
|
# IPython is available but we couldn't find the source code, let's raise the error
|
||||||
|
raise e from inspect_error
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["AgentError"]
|
__all__ = ["AgentError"]
|
||||||
|
|
|
@ -12,11 +12,19 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import tempfile
|
||||||
|
import textwrap
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from IPython.core.interactiveshell import InteractiveShell
|
||||||
|
|
||||||
from smolagents.utils import parse_code_blobs
|
from smolagents import Tool
|
||||||
|
from smolagents.tools import tool
|
||||||
|
from smolagents.utils import get_source, parse_code_blobs
|
||||||
|
|
||||||
|
|
||||||
class AgentTextTests(unittest.TestCase):
|
class AgentTextTests(unittest.TestCase):
|
||||||
|
@ -58,3 +66,287 @@ def multiply(a, b):
|
||||||
return a * b"""
|
return a * b"""
|
||||||
result = parse_code_blobs(test_input)
|
result = parse_code_blobs(test_input)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def ipython_shell():
|
||||||
|
"""Reset IPython shell before and after each test."""
|
||||||
|
shell = InteractiveShell.instance()
|
||||||
|
shell.reset() # Clean before test
|
||||||
|
yield shell
|
||||||
|
shell.reset() # Clean after test
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"obj_name, code_blob",
|
||||||
|
[
|
||||||
|
("test_func", "def test_func():\n return 42"),
|
||||||
|
("TestClass", "class TestClass:\n ..."),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_source_ipython(ipython_shell, obj_name, code_blob):
|
||||||
|
ipython_shell.run_cell(code_blob, store_history=True)
|
||||||
|
obj = ipython_shell.user_ns[obj_name]
|
||||||
|
assert get_source(obj) == code_blob
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_source_standard_class():
|
||||||
|
class TestClass: ...
|
||||||
|
|
||||||
|
source = get_source(TestClass)
|
||||||
|
assert source == "class TestClass: ..."
|
||||||
|
assert source == textwrap.dedent(inspect.getsource(TestClass)).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_source_standard_function():
|
||||||
|
def test_func(): ...
|
||||||
|
|
||||||
|
source = get_source(test_func)
|
||||||
|
assert source == "def test_func(): ..."
|
||||||
|
assert source == textwrap.dedent(inspect.getsource(test_func)).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_source_ipython_errors_empty_cells(ipython_shell):
|
||||||
|
test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
|
||||||
|
ipython_shell.user_ns["In"] = [""]
|
||||||
|
exec(test_code)
|
||||||
|
with pytest.raises(ValueError, match="No code cells found in IPython session"):
|
||||||
|
get_source(locals()["TestClass"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_source_ipython_errors_definition_not_found(ipython_shell):
|
||||||
|
test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
|
||||||
|
ipython_shell.user_ns["In"] = ["", "print('No class definition here')"]
|
||||||
|
exec(test_code)
|
||||||
|
with pytest.raises(ValueError, match="Could not find source code for TestClass in IPython history"):
|
||||||
|
get_source(locals()["TestClass"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_source_ipython_errors_type_error():
|
||||||
|
with pytest.raises(TypeError, match="Expected class or callable"):
|
||||||
|
get_source(None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_class_tool_save():
|
||||||
|
class TestTool(Tool):
|
||||||
|
name = "test_tool"
|
||||||
|
description = "Test tool description"
|
||||||
|
inputs = {
|
||||||
|
"task": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "tool input",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def forward(self, task: str):
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
test_tool = TestTool()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
test_tool.save(tmp_dir)
|
||||||
|
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||||
|
== """from smolagents.tools import Tool
|
||||||
|
import IPython
|
||||||
|
|
||||||
|
class TestTool(Tool):
|
||||||
|
name = "test_tool"
|
||||||
|
description = "Test tool description"
|
||||||
|
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def forward(self, task: str):
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.is_initialized = False
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
|
||||||
|
assert requirements == {"IPython", "smolagents"}
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||||
|
== """from smolagents import launch_gradio_demo
|
||||||
|
from typing import Optional
|
||||||
|
from tool import TestTool
|
||||||
|
|
||||||
|
tool = TestTool()
|
||||||
|
|
||||||
|
launch_gradio_demo(tool)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_ipython_class_tool_save():
|
||||||
|
shell = InteractiveShell.instance()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
code_blob = textwrap.dedent(f"""
|
||||||
|
from smolagents.tools import Tool
|
||||||
|
class TestTool(Tool):
|
||||||
|
name = "test_tool"
|
||||||
|
description = "Test tool description"
|
||||||
|
inputs = {{"task": {{"type": "string",
|
||||||
|
"description": "tool input",
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def forward(self, task: str):
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task
|
||||||
|
TestTool().save("{tmp_dir}")
|
||||||
|
""")
|
||||||
|
assert shell.run_cell(code_blob, store_history=True).success
|
||||||
|
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||||
|
== """from smolagents.tools import Tool
|
||||||
|
import IPython
|
||||||
|
|
||||||
|
class TestTool(Tool):
|
||||||
|
name = "test_tool"
|
||||||
|
description = "Test tool description"
|
||||||
|
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def forward(self, task: str):
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.is_initialized = False
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
|
||||||
|
assert requirements == {"IPython", "smolagents"}
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||||
|
== """from smolagents import launch_gradio_demo
|
||||||
|
from typing import Optional
|
||||||
|
from tool import TestTool
|
||||||
|
|
||||||
|
tool = TestTool()
|
||||||
|
|
||||||
|
launch_gradio_demo(tool)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_function_tool_save():
|
||||||
|
@tool
|
||||||
|
def test_tool(task: str) -> str:
|
||||||
|
"""
|
||||||
|
Test tool description
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: tool input
|
||||||
|
"""
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
test_tool.save(tmp_dir)
|
||||||
|
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||||
|
== """from smolagents import Tool
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class SimpleTool(Tool):
|
||||||
|
name = "test_tool"
|
||||||
|
description = "Test tool description"
|
||||||
|
inputs = {"task":{"type":"string","description":"tool input"}}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def forward(self, task: str) -> str:
|
||||||
|
\"""
|
||||||
|
Test tool description
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: tool input
|
||||||
|
\"""
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task"""
|
||||||
|
)
|
||||||
|
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
|
||||||
|
assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||||
|
== """from smolagents import launch_gradio_demo
|
||||||
|
from typing import Optional
|
||||||
|
from tool import SimpleTool
|
||||||
|
|
||||||
|
tool = SimpleTool()
|
||||||
|
|
||||||
|
launch_gradio_demo(tool)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_ipython_function_tool_save():
|
||||||
|
shell = InteractiveShell.instance()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
code_blob = textwrap.dedent(f"""
|
||||||
|
from smolagents import tool
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def test_tool(task: str) -> str:
|
||||||
|
\"""
|
||||||
|
Test tool description
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: tool input
|
||||||
|
\"""
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
test_tool.save("{tmp_dir}")
|
||||||
|
""")
|
||||||
|
assert shell.run_cell(code_blob, store_history=True).success
|
||||||
|
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||||
|
== """from smolagents import Tool
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class SimpleTool(Tool):
|
||||||
|
name = "test_tool"
|
||||||
|
description = "Test tool description"
|
||||||
|
inputs = {"task":{"type":"string","description":"tool input"}}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def forward(self, task: str) -> str:
|
||||||
|
\"""
|
||||||
|
Test tool description
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: tool input
|
||||||
|
\"""
|
||||||
|
import IPython # noqa: F401
|
||||||
|
|
||||||
|
return task"""
|
||||||
|
)
|
||||||
|
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
|
||||||
|
assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
|
||||||
|
assert (
|
||||||
|
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||||
|
== """from smolagents import launch_gradio_demo
|
||||||
|
from typing import Optional
|
||||||
|
from tool import SimpleTool
|
||||||
|
|
||||||
|
tool = SimpleTool()
|
||||||
|
|
||||||
|
launch_gradio_demo(tool)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue