Fix: source code inspection in interactive shells (#281)

* Support interactive shells for source inspection
* Add tool save e2e tests
This commit is contained in:
Antoine Jeannot 2025-01-22 13:43:16 +01:00 committed by GitHub
parent 5d6502ae1d
commit 6196958deb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 355 additions and 10 deletions

View File

@ -60,6 +60,7 @@ all = [
"smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]", "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]",
] ]
test = [ test = [
"ipython>=8.31.0", # for interactive environment tests
"pytest>=8.1.0", "pytest>=8.1.0",
"python-dotenv>=1.0.1", # For test_all_docs "python-dotenv>=1.0.1", # For test_all_docs
"smolagents[all]", "smolagents[all]",

View File

@ -1,10 +1,9 @@
import ast import ast
import builtins import builtins
import inspect import inspect
import textwrap
from typing import Set from typing import Set
from .utils import BASE_BUILTIN_MODULES from .utils import BASE_BUILTIN_MODULES, get_source
_BUILTIN_NAMES = set(vars(builtins)) _BUILTIN_NAMES = set(vars(builtins))
@ -132,7 +131,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
""" """
errors = [] errors = []
source = textwrap.dedent(inspect.getsource(cls)) source = get_source(cls)
tree = ast.parse(source) tree = ast.parse(source)

View File

@ -46,7 +46,7 @@ from ._function_type_hints_utils import (
) )
from .tool_validation import MethodChecker, validate_tool_attributes from .tool_validation import MethodChecker, validate_tool_attributes
from .types import handle_agent_input_types, handle_agent_output_types from .types import handle_agent_input_types, handle_agent_output_types
from .utils import _is_package_available, _is_pillow_available, instance_to_source from .utils import _is_package_available, _is_pillow_available, get_source, instance_to_source
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -220,8 +220,8 @@ class Tool:
# Save tool file # Save tool file
if type(self).__name__ == "SimpleTool": if type(self).__name__ == "SimpleTool":
# Check that imports are self-contained # Check that imports are self-contained
source_code = inspect.getsource(self.forward).replace("@tool", "") source_code = get_source(self.forward).replace("@tool", "")
forward_node = ast.parse(textwrap.dedent(source_code)) forward_node = ast.parse(source_code)
# If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
method_checker = MethodChecker(set()) method_checker = MethodChecker(set())
method_checker.visit(forward_node) method_checker.visit(forward_node)
@ -229,7 +229,7 @@ class Tool:
if len(method_checker.errors) > 0: if len(method_checker.errors) > 0:
raise (ValueError("\n".join(method_checker.errors))) raise (ValueError("\n".join(method_checker.errors)))
forward_source_code = inspect.getsource(self.forward) forward_source_code = get_source(self.forward)
tool_code = textwrap.dedent( tool_code = textwrap.dedent(
f""" f"""
from smolagents import Tool from smolagents import Tool

View File

@ -20,6 +20,7 @@ import importlib.util
import inspect import inspect
import json import json
import re import re
import textwrap
import types import types
from enum import IntEnum from enum import IntEnum
from functools import lru_cache from functools import lru_cache
@ -221,7 +222,7 @@ def get_method_source(method):
"""Get source code for a method, including bound methods.""" """Get source code for a method, including bound methods."""
if isinstance(method, types.MethodType): if isinstance(method, types.MethodType):
method = method.__func__ method = method.__func__
return inspect.getsource(method).strip() return get_source(method)
def is_same_method(method1, method2): def is_same_method(method1, method2):
@ -295,7 +296,7 @@ def instance_to_source(instance, base_cls=None):
} }
for name, method in methods.items(): for name, method in methods.items():
method_source = inspect.getsource(method) method_source = get_source(method)
# Clean up the indentation # Clean up the indentation
method_lines = method_source.split("\n") method_lines = method_source.split("\n")
first_line = method_lines[0] first_line = method_lines[0]
@ -330,4 +331,56 @@ def instance_to_source(instance, base_cls=None):
return "\n".join(final_lines) return "\n".join(final_lines)
def get_source(obj) -> str:
"""Get the source code of a class or callable object (e.g.: function, method).
First attempts to get the source code using `inspect.getsource`.
In a dynamic environment (e.g.: Jupyter, IPython), if this fails,
falls back to retrieving the source code from the current interactive shell session.
Args:
obj: A class or callable object (e.g.: function, method)
Returns:
str: The source code of the object, dedented and stripped
Raises:
TypeError: If object is not a class or callable
OSError: If source code cannot be retrieved from any source
ValueError: If source cannot be found in IPython history
Note:
TODO: handle Python standard REPL
"""
if not (isinstance(obj, type) or callable(obj)):
raise TypeError(f"Expected class or callable, got {type(obj)}")
inspect_error = None
try:
return textwrap.dedent(inspect.getsource(obj)).strip()
except OSError as e:
# let's keep track of the exception to raise it if all further methods fail
inspect_error = e
try:
import IPython
shell = IPython.get_ipython()
if not shell:
raise ImportError("No active IPython shell found")
all_cells = "\n".join(shell.user_ns.get("In", [])).strip()
if not all_cells:
raise ValueError("No code cells found in IPython session")
tree = ast.parse(all_cells)
for node in ast.walk(tree):
if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__:
return textwrap.dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip()
raise ValueError(f"Could not find source code for {obj.__name__} in IPython history")
except ImportError:
# IPython is not available, let's just raise the original inspect error
raise inspect_error
except ValueError as e:
# IPython is available but we couldn't find the source code, let's raise the error
raise e from inspect_error
__all__ = ["AgentError"] __all__ = ["AgentError"]

View File

@ -12,11 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import os
import pathlib
import tempfile
import textwrap
import unittest import unittest
import pytest import pytest
from IPython.core.interactiveshell import InteractiveShell
from smolagents.utils import parse_code_blobs from smolagents import Tool
from smolagents.tools import tool
from smolagents.utils import get_source, parse_code_blobs
class AgentTextTests(unittest.TestCase): class AgentTextTests(unittest.TestCase):
@ -58,3 +66,287 @@ def multiply(a, b):
return a * b""" return a * b"""
result = parse_code_blobs(test_input) result = parse_code_blobs(test_input)
assert result == expected_output assert result == expected_output
@pytest.fixture(scope="function")
def ipython_shell():
"""Reset IPython shell before and after each test."""
shell = InteractiveShell.instance()
shell.reset() # Clean before test
yield shell
shell.reset() # Clean after test
@pytest.mark.parametrize(
"obj_name, code_blob",
[
("test_func", "def test_func():\n return 42"),
("TestClass", "class TestClass:\n ..."),
],
)
def test_get_source_ipython(ipython_shell, obj_name, code_blob):
ipython_shell.run_cell(code_blob, store_history=True)
obj = ipython_shell.user_ns[obj_name]
assert get_source(obj) == code_blob
def test_get_source_standard_class():
class TestClass: ...
source = get_source(TestClass)
assert source == "class TestClass: ..."
assert source == textwrap.dedent(inspect.getsource(TestClass)).strip()
def test_get_source_standard_function():
def test_func(): ...
source = get_source(test_func)
assert source == "def test_func(): ..."
assert source == textwrap.dedent(inspect.getsource(test_func)).strip()
def test_get_source_ipython_errors_empty_cells(ipython_shell):
test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
ipython_shell.user_ns["In"] = [""]
exec(test_code)
with pytest.raises(ValueError, match="No code cells found in IPython session"):
get_source(locals()["TestClass"])
def test_get_source_ipython_errors_definition_not_found(ipython_shell):
test_code = textwrap.dedent("""class TestClass:\n ...""").strip()
ipython_shell.user_ns["In"] = ["", "print('No class definition here')"]
exec(test_code)
with pytest.raises(ValueError, match="Could not find source code for TestClass in IPython history"):
get_source(locals()["TestClass"])
def test_get_source_ipython_errors_type_error():
with pytest.raises(TypeError, match="Expected class or callable"):
get_source(None)
def test_e2e_class_tool_save():
class TestTool(Tool):
name = "test_tool"
description = "Test tool description"
inputs = {
"task": {
"type": "string",
"description": "tool input",
}
}
output_type = "string"
def forward(self, task: str):
import IPython # noqa: F401
return task
test_tool = TestTool()
with tempfile.TemporaryDirectory() as tmp_dir:
test_tool.save(tmp_dir)
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents.tools import Tool
import IPython
class TestTool(Tool):
name = "test_tool"
description = "Test tool description"
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
output_type = "string"
def forward(self, task: str):
import IPython # noqa: F401
return task
def __init__(self, *args, **kwargs):
self.is_initialized = False
"""
)
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
assert requirements == {"IPython", "smolagents"}
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
from typing import Optional
from tool import TestTool
tool = TestTool()
launch_gradio_demo(tool)
"""
)
def test_e2e_ipython_class_tool_save():
shell = InteractiveShell.instance()
with tempfile.TemporaryDirectory() as tmp_dir:
code_blob = textwrap.dedent(f"""
from smolagents.tools import Tool
class TestTool(Tool):
name = "test_tool"
description = "Test tool description"
inputs = {{"task": {{"type": "string",
"description": "tool input",
}}
}}
output_type = "string"
def forward(self, task: str):
import IPython # noqa: F401
return task
TestTool().save("{tmp_dir}")
""")
assert shell.run_cell(code_blob, store_history=True).success
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents.tools import Tool
import IPython
class TestTool(Tool):
name = "test_tool"
description = "Test tool description"
inputs = {'task': {'type': 'string', 'description': 'tool input'}}
output_type = "string"
def forward(self, task: str):
import IPython # noqa: F401
return task
def __init__(self, *args, **kwargs):
self.is_initialized = False
"""
)
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
assert requirements == {"IPython", "smolagents"}
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
from typing import Optional
from tool import TestTool
tool = TestTool()
launch_gradio_demo(tool)
"""
)
def test_e2e_function_tool_save():
@tool
def test_tool(task: str) -> str:
"""
Test tool description
Args:
task: tool input
"""
import IPython # noqa: F401
return task
with tempfile.TemporaryDirectory() as tmp_dir:
test_tool.save(tmp_dir)
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents import Tool
from typing import Optional
class SimpleTool(Tool):
name = "test_tool"
description = "Test tool description"
inputs = {"task":{"type":"string","description":"tool input"}}
output_type = "string"
def forward(self, task: str) -> str:
\"""
Test tool description
Args:
task: tool input
\"""
import IPython # noqa: F401
return task"""
)
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
from typing import Optional
from tool import SimpleTool
tool = SimpleTool()
launch_gradio_demo(tool)
"""
)
def test_e2e_ipython_function_tool_save():
shell = InteractiveShell.instance()
with tempfile.TemporaryDirectory() as tmp_dir:
code_blob = textwrap.dedent(f"""
from smolagents import tool
@tool
def test_tool(task: str) -> str:
\"""
Test tool description
Args:
task: tool input
\"""
import IPython # noqa: F401
return task
test_tool.save("{tmp_dir}")
""")
assert shell.run_cell(code_blob, store_history=True).success
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents import Tool
from typing import Optional
class SimpleTool(Tool):
name = "test_tool"
description = "Test tool description"
inputs = {"task":{"type":"string","description":"tool input"}}
output_type = "string"
def forward(self, task: str) -> str:
\"""
Test tool description
Args:
task: tool input
\"""
import IPython # noqa: F401
return task"""
)
requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split())
assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
from typing import Optional
from tool import SimpleTool
tool = SimpleTool()
launch_gradio_demo(tool)
"""
)