diff --git a/Makefile b/Makefile index b7036c7..a24c1ae 100644 --- a/Makefile +++ b/Makefile @@ -8,19 +8,19 @@ extra_quality_checks: python utils/check_copies.py python utils/check_dummies.py python utils/check_repo.py - doc-builder style agents docs/source --max_len 119 + doc-builder style smolagents docs/source --max_len 119 # this target runs checks on all files quality: ruff check $(check_dirs) ruff format --check $(check_dirs) - doc-builder style agents docs/source --max_len 119 --check_only + doc-builder style smolagents docs/source --max_len 119 --check_only # Format source code automatically and check is there are any problems left that need manual fixing style: ruff check $(check_dirs) --fix ruff format $(check_dirs) - doc-builder style agents docs/source --max_len 119 + doc-builder style smolagents docs/source --max_len 119 # Run tests for the library test_big_modeling: diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 9dbe039..a311876 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -370,8 +370,7 @@ class MultiStepAgent: try: return self.model(self.input_messages) except Exception as e: - error_msg = f"Error in generating final LLM output:\n{e}" - raise AgentGenerationError(error_msg) + return f"Error in generating final LLM output:\n{e}" def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: """ diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index afa2ae1..d2f05c2 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -153,8 +153,8 @@ class DuckDuckGoSearchTool(Tool): } output_type = "any" - def __init__(self): - super().__init__(self) + def __init__(self, **kwargs): + super().__init__(self, **kwargs) try: from duckduckgo_search import DDGS except ImportError: diff --git a/src/smolagents/models.py b/src/smolagents/models.py index deeb7fc..6fc8dbb 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -410,7 +410,12 @@ class TransformersModel(Model): class LiteLLMModel(Model): - def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None): + def __init__( + self, + model_id="anthropic/claude-3-5-sonnet-20240620", + api_base=None, + api_key=None, + ): super().__init__() self.model_id = model_id # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs diff --git a/src/smolagents/prompts.py b/src/smolagents/prompts.py index 0d0afb8..e85ab19 100644 --- a/src/smolagents/prompts.py +++ b/src/smolagents/prompts.py @@ -517,5 +517,6 @@ __all__ = [ "PLAN_UPDATE_FINAL_PLAN_REDACTION", "SINGLE_STEP_CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT", + "TOOL_CALLING_SYSTEM_PROMPT", "MANAGED_AGENT_PROMPT", ] diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index 027eb02..e410e89 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -7,8 +7,6 @@ from .utils import BASE_BUILTIN_MODULES _BUILTIN_NAMES = set(vars(builtins)) -IMPORTED_PACKAGES = BASE_BUILTIN_MODULES - class MethodChecker(ast.NodeVisitor): """ @@ -91,7 +89,7 @@ class MethodChecker(ast.NodeVisitor): if isinstance(node.ctx, ast.Load): if not ( node.id in _BUILTIN_NAMES - or node.id in IMPORTED_PACKAGES + or node.id in BASE_BUILTIN_MODULES or node.id in self.arg_names or node.id == "self" or node.id in self.class_attributes @@ -105,7 +103,7 @@ class MethodChecker(ast.NodeVisitor): if isinstance(node.func, ast.Name): if not ( node.func.id in _BUILTIN_NAMES - or node.func.id in IMPORTED_PACKAGES + or node.func.id in BASE_BUILTIN_MODULES or node.func.id in self.arg_names or node.func.id == "self" or node.func.id in self.class_attributes diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 1db46c2..010ea11 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -854,7 +854,7 @@ def load_tool( main_module = importlib.import_module("smolagents") tools_module = main_module tool_class = getattr(tools_module, tool_class_name) - return tool_class(model_repo_id, token=token, **kwargs) + return tool_class(token=token, **kwargs) else: return Tool.from_hub( task_or_repo_id, diff --git a/src/smolagents/types.py b/src/smolagents/types.py index b56502f..d817608 100644 --- a/src/smolagents/types.py +++ b/src/smolagents/types.py @@ -104,7 +104,9 @@ class AgentImage(AgentType, ImageType): self._raw = None self._tensor = None - if isinstance(value, ImageType): + if isinstance(value, AgentImage): + self._raw, self._path, self._tensor = value._raw, value._path, value._tensor + elif isinstance(value, ImageType): self._raw = value elif isinstance(value, bytes): self._raw = Image.open(BytesIO(value)) diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index d1467ac..0e195cc 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -123,6 +123,7 @@ class TestDocs: "ToolCollection", "image_generation_tool", "from_langchain", + "while llm_should_continue(memory):", ] code_blocks = [ block.replace("", self.hf_token).replace( diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py index 8e79774..65a9728 100644 --- a/tests/test_final_answer.py +++ b/tests/test_final_answer.py @@ -59,6 +59,6 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): def test_agent_type_output(self): inputs = self.create_inputs() for input_type, input in inputs.items(): - output = self.tool(**input) + output = self.tool(**input, sanitize_inputs_outputs=True) agent_type = AGENT_TYPE_MAPPING[input_type] self.assertTrue(isinstance(output, agent_type)) diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index a0217fc..4c04948 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -55,8 +55,8 @@ final_answer('This is the final answer.') self.last_input_token_count = 10 self.last_output_token_count = 20 - def __call__(self, prompt, **kwargs): - return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' + def get_tool_call(self, prompt, **kwargs): + return "final_answer", {"answer": "image"}, "fake_id" agent = ToolCallingAgent( tools=[], @@ -96,18 +96,21 @@ final_answer('This is the final answer.') self.last_output_token_count = 20 def __call__(self, prompt, **kwargs): - raise AgentError + self.last_input_token_count = 10 + self.last_output_token_count = 0 + raise Exception("Cannot generate") agent = CodeAgent( tools=[], model=FakeLLMModel(), max_iterations=1, ) - agent.run("Fake task") - self.assertEqual(agent.monitor.total_input_token_count, 20) - self.assertEqual(agent.monitor.total_output_token_count, 40) + self.assertEqual( + agent.monitor.total_input_token_count, 20 + ) # Should have done two monitoring callbacks + self.assertEqual(agent.monitor.total_output_token_count, 0) def test_streaming_agent_text_output(self): def dummy_model(prompt, **kwargs): @@ -132,14 +135,16 @@ final_answer('This is the final answer.') self.assertIn("This is the final answer.", final_message.content) def test_streaming_agent_image_output(self): - def dummy_model(prompt, **kwargs): - return ( - 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' - ) + class FakeLLM: + def __init__(self): + pass + + def get_tool_call(self, messages, **kwargs): + return "final_answer", {"answer": "image"}, "fake_id" agent = ToolCallingAgent( tools=[], - model=dummy_model, + model=FakeLLM(), max_iterations=1, ) @@ -148,7 +153,7 @@ final_answer('This is the final answer.') stream_to_gradio( agent, task="Test task", - image=AgentImage(value="path.png"), + additional_args=dict(image=AgentImage(value="path.png")), test_mode=True, ) ) diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 61683dd..03e3075 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -41,17 +41,16 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): def test_exact_match_arg(self): result = self.tool("(2 / 2) * 4") - self.assertEqual(result, "4.0") + self.assertEqual(result, "Stdout:\n\nOutput: 4.0") def test_exact_match_kwarg(self): result = self.tool(code="(2 / 2) * 4") - self.assertEqual(result, "4.0") + self.assertEqual(result, "Stdout:\n\nOutput: 4.0") def test_agent_type_output(self): inputs = ["2 * 2"] - output = self.tool(*inputs) + output = self.tool(*inputs, sanitize_inputs_outputs=True) output_type = AGENT_TYPE_MAPPING[self.tool.output_type] - print("OKK", type(output), output_type, AGENT_TYPE_MAPPING) self.assertTrue(isinstance(output, output_type)) def test_agent_types_inputs(self): @@ -71,7 +70,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) # Should not raise an error - output = self.tool(*inputs) + output = self.tool(*inputs, sanitize_inputs_outputs=True) output_type = AGENT_TYPE_MAPPING[self.tool.output_type] self.assertTrue(isinstance(output, output_type)) diff --git a/tests/test_search.py b/tests/test_search.py index 9660d4f..488b97b 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -27,4 +27,4 @@ class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): def test_exact_match_arg(self): result = self.tool("Agents") - assert isinstance(result, list) and isinstance(result[0], dict) + assert isinstance(result, str) diff --git a/tests/test_tools.py b/tests/test_tools.py index 87f9fda..9e8e3df 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -93,7 +93,7 @@ class ToolTesterMixin: def test_agent_type_output(self): inputs = create_inputs(self.tool.inputs) - output = self.tool(**inputs) + output = self.tool(**inputs, sanitize_inputs_outputs=True) if self.tool.output_type != "any": agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] self.assertTrue(isinstance(output, agent_type)) @@ -164,20 +164,20 @@ class ToolTests(unittest.TestCase): assert coolfunc.output_type == "number" assert "docstring has no description for the argument" in str(e) - def test_tool_definition_raises_error_imports_outside_function(self): + def test_saving_tool_raises_error_imports_outside_function(self): with pytest.raises(Exception) as e: - from datetime import datetime + import numpy as np @tool def get_current_time() -> str: """ Gets the current time. """ - return str(datetime.now()) + return str(np.random.random()) get_current_time.save("output") - assert "datetime" in str(e) + assert "np" in str(e) # Also test with classic definition with pytest.raises(Exception) as e: @@ -189,12 +189,12 @@ class ToolTests(unittest.TestCase): output_type = "string" def forward(self): - return str(datetime.now()) + return str(np.random.random()) get_current_time = GetCurrentTimeTool() get_current_time.save("output") - assert "datetime" in str(e) + assert "np" in str(e) def test_tool_definition_raises_no_error_imports_in_function(self): @tool diff --git a/tests/test_types.py b/tests/test_types.py index ee2ec66..8026a57 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -20,7 +20,6 @@ from pathlib import Path from smolagents.types import AgentAudio, AgentImage, AgentText from transformers.testing_utils import ( - get_tests_dir, require_soundfile, require_torch, require_vision, @@ -91,7 +90,7 @@ class AgentImageTests(unittest.TestCase): self.assertTrue(os.path.exists(path)) def test_from_string(self): - path = Path(get_tests_dir("fixtures/")) / "000000039769.png" + path = Path("tests/fixtures/000000039769.png") image = Image.open(path) agent_type = AgentImage(path) @@ -103,7 +102,7 @@ class AgentImageTests(unittest.TestCase): self.assertTrue(os.path.exists(path)) def test_from_image(self): - path = Path(get_tests_dir("fixtures/")) / "000000039769.png" + path = Path("tests/fixtures/000000039769.png") image = Image.open(path) agent_type = AgentImage(image)