diff --git a/pyproject.toml b/pyproject.toml index 0e6645b..873d0b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ all = [ "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]", ] test = [ + "ipython>=8.31.0", # for interactive environment tests "pytest>=8.1.0", "python-dotenv>=1.0.1", # For test_all_docs "smolagents[all]", diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index 960331f..9ac157c 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -1,10 +1,9 @@ import ast import builtins import inspect -import textwrap from typing import Set -from .utils import BASE_BUILTIN_MODULES +from .utils import BASE_BUILTIN_MODULES, get_source _BUILTIN_NAMES = set(vars(builtins)) @@ -132,7 +131,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: """ errors = [] - source = textwrap.dedent(inspect.getsource(cls)) + source = get_source(cls) tree = ast.parse(source) diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index e7725e8..b73bc6f 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -46,7 +46,7 @@ from ._function_type_hints_utils import ( ) from .tool_validation import MethodChecker, validate_tool_attributes from .types import handle_agent_input_types, handle_agent_output_types -from .utils import _is_package_available, _is_pillow_available, instance_to_source +from .utils import _is_package_available, _is_pillow_available, get_source, instance_to_source logger = logging.getLogger(__name__) @@ -220,8 +220,8 @@ class Tool: # Save tool file if type(self).__name__ == "SimpleTool": # Check that imports are self-contained - source_code = inspect.getsource(self.forward).replace("@tool", "") - forward_node = ast.parse(textwrap.dedent(source_code)) + source_code = get_source(self.forward).replace("@tool", "") + forward_node = ast.parse(source_code) # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code method_checker = MethodChecker(set()) method_checker.visit(forward_node) @@ -229,7 +229,7 @@ class Tool: if len(method_checker.errors) > 0: raise (ValueError("\n".join(method_checker.errors))) - forward_source_code = inspect.getsource(self.forward) + forward_source_code = get_source(self.forward) tool_code = textwrap.dedent( f""" from smolagents import Tool diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index ad80b9d..f8e34a0 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -20,6 +20,7 @@ import importlib.util import inspect import json import re +import textwrap import types from enum import IntEnum from functools import lru_cache @@ -221,7 +222,7 @@ def get_method_source(method): """Get source code for a method, including bound methods.""" if isinstance(method, types.MethodType): method = method.__func__ - return inspect.getsource(method).strip() + return get_source(method) def is_same_method(method1, method2): @@ -295,7 +296,7 @@ def instance_to_source(instance, base_cls=None): } for name, method in methods.items(): - method_source = inspect.getsource(method) + method_source = get_source(method) # Clean up the indentation method_lines = method_source.split("\n") first_line = method_lines[0] @@ -330,4 +331,56 @@ def instance_to_source(instance, base_cls=None): return "\n".join(final_lines) +def get_source(obj) -> str: + """Get the source code of a class or callable object (e.g.: function, method). + First attempts to get the source code using `inspect.getsource`. + In a dynamic environment (e.g.: Jupyter, IPython), if this fails, + falls back to retrieving the source code from the current interactive shell session. + + Args: + obj: A class or callable object (e.g.: function, method) + + Returns: + str: The source code of the object, dedented and stripped + + Raises: + TypeError: If object is not a class or callable + OSError: If source code cannot be retrieved from any source + ValueError: If source cannot be found in IPython history + + Note: + TODO: handle Python standard REPL + """ + if not (isinstance(obj, type) or callable(obj)): + raise TypeError(f"Expected class or callable, got {type(obj)}") + + inspect_error = None + try: + return textwrap.dedent(inspect.getsource(obj)).strip() + except OSError as e: + # let's keep track of the exception to raise it if all further methods fail + inspect_error = e + try: + import IPython + + shell = IPython.get_ipython() + if not shell: + raise ImportError("No active IPython shell found") + all_cells = "\n".join(shell.user_ns.get("In", [])).strip() + if not all_cells: + raise ValueError("No code cells found in IPython session") + + tree = ast.parse(all_cells) + for node in ast.walk(tree): + if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__: + return textwrap.dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip() + raise ValueError(f"Could not find source code for {obj.__name__} in IPython history") + except ImportError: + # IPython is not available, let's just raise the original inspect error + raise inspect_error + except ValueError as e: + # IPython is available but we couldn't find the source code, let's raise the error + raise e from inspect_error + + __all__ = ["AgentError"] diff --git a/tests/test_utils.py b/tests/test_utils.py index 91064e8..31a8a68 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,11 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect +import os +import pathlib +import tempfile +import textwrap import unittest import pytest +from IPython.core.interactiveshell import InteractiveShell -from smolagents.utils import parse_code_blobs +from smolagents import Tool +from smolagents.tools import tool +from smolagents.utils import get_source, parse_code_blobs class AgentTextTests(unittest.TestCase): @@ -58,3 +66,287 @@ def multiply(a, b): return a * b""" result = parse_code_blobs(test_input) assert result == expected_output + + +@pytest.fixture(scope="function") +def ipython_shell(): + """Reset IPython shell before and after each test.""" + shell = InteractiveShell.instance() + shell.reset() # Clean before test + yield shell + shell.reset() # Clean after test + + +@pytest.mark.parametrize( + "obj_name, code_blob", + [ + ("test_func", "def test_func():\n return 42"), + ("TestClass", "class TestClass:\n ..."), + ], +) +def test_get_source_ipython(ipython_shell, obj_name, code_blob): + ipython_shell.run_cell(code_blob, store_history=True) + obj = ipython_shell.user_ns[obj_name] + assert get_source(obj) == code_blob + + +def test_get_source_standard_class(): + class TestClass: ... + + source = get_source(TestClass) + assert source == "class TestClass: ..." + assert source == textwrap.dedent(inspect.getsource(TestClass)).strip() + + +def test_get_source_standard_function(): + def test_func(): ... + + source = get_source(test_func) + assert source == "def test_func(): ..." + assert source == textwrap.dedent(inspect.getsource(test_func)).strip() + + +def test_get_source_ipython_errors_empty_cells(ipython_shell): + test_code = textwrap.dedent("""class TestClass:\n ...""").strip() + ipython_shell.user_ns["In"] = [""] + exec(test_code) + with pytest.raises(ValueError, match="No code cells found in IPython session"): + get_source(locals()["TestClass"]) + + +def test_get_source_ipython_errors_definition_not_found(ipython_shell): + test_code = textwrap.dedent("""class TestClass:\n ...""").strip() + ipython_shell.user_ns["In"] = ["", "print('No class definition here')"] + exec(test_code) + with pytest.raises(ValueError, match="Could not find source code for TestClass in IPython history"): + get_source(locals()["TestClass"]) + + +def test_get_source_ipython_errors_type_error(): + with pytest.raises(TypeError, match="Expected class or callable"): + get_source(None) + + +def test_e2e_class_tool_save(): + class TestTool(Tool): + name = "test_tool" + description = "Test tool description" + inputs = { + "task": { + "type": "string", + "description": "tool input", + } + } + output_type = "string" + + def forward(self, task: str): + import IPython # noqa: F401 + + return task + + test_tool = TestTool() + with tempfile.TemporaryDirectory() as tmp_dir: + test_tool.save(tmp_dir) + assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} + assert ( + pathlib.Path(tmp_dir, "tool.py").read_text() + == """from smolagents.tools import Tool +import IPython + +class TestTool(Tool): + name = "test_tool" + description = "Test tool description" + inputs = {'task': {'type': 'string', 'description': 'tool input'}} + output_type = "string" + + def forward(self, task: str): + import IPython # noqa: F401 + + return task + + def __init__(self, *args, **kwargs): + self.is_initialized = False +""" + ) + requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) + assert requirements == {"IPython", "smolagents"} + assert ( + pathlib.Path(tmp_dir, "app.py").read_text() + == """from smolagents import launch_gradio_demo +from typing import Optional +from tool import TestTool + +tool = TestTool() + +launch_gradio_demo(tool) +""" + ) + + +def test_e2e_ipython_class_tool_save(): + shell = InteractiveShell.instance() + with tempfile.TemporaryDirectory() as tmp_dir: + code_blob = textwrap.dedent(f""" + from smolagents.tools import Tool + class TestTool(Tool): + name = "test_tool" + description = "Test tool description" + inputs = {{"task": {{"type": "string", + "description": "tool input", + }} + }} + output_type = "string" + + def forward(self, task: str): + import IPython # noqa: F401 + + return task + TestTool().save("{tmp_dir}") + """) + assert shell.run_cell(code_blob, store_history=True).success + assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} + assert ( + pathlib.Path(tmp_dir, "tool.py").read_text() + == """from smolagents.tools import Tool +import IPython + +class TestTool(Tool): + name = "test_tool" + description = "Test tool description" + inputs = {'task': {'type': 'string', 'description': 'tool input'}} + output_type = "string" + + def forward(self, task: str): + import IPython # noqa: F401 + + return task + + def __init__(self, *args, **kwargs): + self.is_initialized = False +""" + ) + requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) + assert requirements == {"IPython", "smolagents"} + assert ( + pathlib.Path(tmp_dir, "app.py").read_text() + == """from smolagents import launch_gradio_demo +from typing import Optional +from tool import TestTool + +tool = TestTool() + +launch_gradio_demo(tool) +""" + ) + + +def test_e2e_function_tool_save(): + @tool + def test_tool(task: str) -> str: + """ + Test tool description + + Args: + task: tool input + """ + import IPython # noqa: F401 + + return task + + with tempfile.TemporaryDirectory() as tmp_dir: + test_tool.save(tmp_dir) + assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} + assert ( + pathlib.Path(tmp_dir, "tool.py").read_text() + == """from smolagents import Tool +from typing import Optional + +class SimpleTool(Tool): + name = "test_tool" + description = "Test tool description" + inputs = {"task":{"type":"string","description":"tool input"}} + output_type = "string" + + def forward(self, task: str) -> str: + \""" + Test tool description + + Args: + task: tool input + \""" + import IPython # noqa: F401 + + return task""" + ) + requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) + assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements + assert ( + pathlib.Path(tmp_dir, "app.py").read_text() + == """from smolagents import launch_gradio_demo +from typing import Optional +from tool import SimpleTool + +tool = SimpleTool() + +launch_gradio_demo(tool) +""" + ) + + +def test_e2e_ipython_function_tool_save(): + shell = InteractiveShell.instance() + with tempfile.TemporaryDirectory() as tmp_dir: + code_blob = textwrap.dedent(f""" + from smolagents import tool + + @tool + def test_tool(task: str) -> str: + \""" + Test tool description + + Args: + task: tool input + \""" + import IPython # noqa: F401 + + return task + + test_tool.save("{tmp_dir}") + """) + assert shell.run_cell(code_blob, store_history=True).success + assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} + assert ( + pathlib.Path(tmp_dir, "tool.py").read_text() + == """from smolagents import Tool +from typing import Optional + +class SimpleTool(Tool): + name = "test_tool" + description = "Test tool description" + inputs = {"task":{"type":"string","description":"tool input"}} + output_type = "string" + + def forward(self, task: str) -> str: + \""" + Test tool description + + Args: + task: tool input + \""" + import IPython # noqa: F401 + + return task""" + ) + requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) + assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements + assert ( + pathlib.Path(tmp_dir, "app.py").read_text() + == """from smolagents import launch_gradio_demo +from typing import Optional +from tool import SimpleTool + +tool = SimpleTool() + +launch_gradio_demo(tool) +""" + )