Pass tests

This commit is contained in:
Aymeric 2024-12-30 18:03:53 +01:00
parent a50f9284b3
commit 54d6857da2
15 changed files with 52 additions and 43 deletions

View File

@ -8,19 +8,19 @@ extra_quality_checks:
python utils/check_copies.py python utils/check_copies.py
python utils/check_dummies.py python utils/check_dummies.py
python utils/check_repo.py python utils/check_repo.py
doc-builder style agents docs/source --max_len 119 doc-builder style smolagents docs/source --max_len 119
# this target runs checks on all files # this target runs checks on all files
quality: quality:
ruff check $(check_dirs) ruff check $(check_dirs)
ruff format --check $(check_dirs) ruff format --check $(check_dirs)
doc-builder style agents docs/source --max_len 119 --check_only doc-builder style smolagents docs/source --max_len 119 --check_only
# Format source code automatically and check is there are any problems left that need manual fixing # Format source code automatically and check is there are any problems left that need manual fixing
style: style:
ruff check $(check_dirs) --fix ruff check $(check_dirs) --fix
ruff format $(check_dirs) ruff format $(check_dirs)
doc-builder style agents docs/source --max_len 119 doc-builder style smolagents docs/source --max_len 119
# Run tests for the library # Run tests for the library
test_big_modeling: test_big_modeling:

View File

@ -370,8 +370,7 @@ class MultiStepAgent:
try: try:
return self.model(self.input_messages) return self.model(self.input_messages)
except Exception as e: except Exception as e:
error_msg = f"Error in generating final LLM output:\n{e}" return f"Error in generating final LLM output:\n{e}"
raise AgentGenerationError(error_msg)
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
""" """

View File

@ -153,8 +153,8 @@ class DuckDuckGoSearchTool(Tool):
} }
output_type = "any" output_type = "any"
def __init__(self): def __init__(self, **kwargs):
super().__init__(self) super().__init__(self, **kwargs)
try: try:
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
except ImportError: except ImportError:

View File

@ -410,7 +410,12 @@ class TransformersModel(Model):
class LiteLLMModel(Model): class LiteLLMModel(Model):
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None): def __init__(
self,
model_id="anthropic/claude-3-5-sonnet-20240620",
api_base=None,
api_key=None,
):
super().__init__() super().__init__()
self.model_id = model_id self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs

View File

@ -517,5 +517,6 @@ __all__ = [
"PLAN_UPDATE_FINAL_PLAN_REDACTION", "PLAN_UPDATE_FINAL_PLAN_REDACTION",
"SINGLE_STEP_CODE_SYSTEM_PROMPT", "SINGLE_STEP_CODE_SYSTEM_PROMPT",
"CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT",
"TOOL_CALLING_SYSTEM_PROMPT",
"MANAGED_AGENT_PROMPT", "MANAGED_AGENT_PROMPT",
] ]

View File

@ -7,8 +7,6 @@ from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins)) _BUILTIN_NAMES = set(vars(builtins))
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
class MethodChecker(ast.NodeVisitor): class MethodChecker(ast.NodeVisitor):
""" """
@ -91,7 +89,7 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.ctx, ast.Load): if isinstance(node.ctx, ast.Load):
if not ( if not (
node.id in _BUILTIN_NAMES node.id in _BUILTIN_NAMES
or node.id in IMPORTED_PACKAGES or node.id in BASE_BUILTIN_MODULES
or node.id in self.arg_names or node.id in self.arg_names
or node.id == "self" or node.id == "self"
or node.id in self.class_attributes or node.id in self.class_attributes
@ -105,7 +103,7 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.func, ast.Name): if isinstance(node.func, ast.Name):
if not ( if not (
node.func.id in _BUILTIN_NAMES node.func.id in _BUILTIN_NAMES
or node.func.id in IMPORTED_PACKAGES or node.func.id in BASE_BUILTIN_MODULES
or node.func.id in self.arg_names or node.func.id in self.arg_names
or node.func.id == "self" or node.func.id == "self"
or node.func.id in self.class_attributes or node.func.id in self.class_attributes

View File

@ -854,7 +854,7 @@ def load_tool(
main_module = importlib.import_module("smolagents") main_module = importlib.import_module("smolagents")
tools_module = main_module tools_module = main_module
tool_class = getattr(tools_module, tool_class_name) tool_class = getattr(tools_module, tool_class_name)
return tool_class(model_repo_id, token=token, **kwargs) return tool_class(token=token, **kwargs)
else: else:
return Tool.from_hub( return Tool.from_hub(
task_or_repo_id, task_or_repo_id,

View File

@ -104,7 +104,9 @@ class AgentImage(AgentType, ImageType):
self._raw = None self._raw = None
self._tensor = None self._tensor = None
if isinstance(value, ImageType): if isinstance(value, AgentImage):
self._raw, self._path, self._tensor = value._raw, value._path, value._tensor
elif isinstance(value, ImageType):
self._raw = value self._raw = value
elif isinstance(value, bytes): elif isinstance(value, bytes):
self._raw = Image.open(BytesIO(value)) self._raw = Image.open(BytesIO(value))

View File

@ -123,6 +123,7 @@ class TestDocs:
"ToolCollection", "ToolCollection",
"image_generation_tool", "image_generation_tool",
"from_langchain", "from_langchain",
"while llm_should_continue(memory):",
] ]
code_blocks = [ code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace( block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(

View File

@ -59,6 +59,6 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
def test_agent_type_output(self): def test_agent_type_output(self):
inputs = self.create_inputs() inputs = self.create_inputs()
for input_type, input in inputs.items(): for input_type, input in inputs.items():
output = self.tool(**input) output = self.tool(**input, sanitize_inputs_outputs=True)
agent_type = AGENT_TYPE_MAPPING[input_type] agent_type = AGENT_TYPE_MAPPING[input_type]
self.assertTrue(isinstance(output, agent_type)) self.assertTrue(isinstance(output, agent_type))

View File

@ -55,8 +55,8 @@ final_answer('This is the final answer.')
self.last_input_token_count = 10 self.last_input_token_count = 10
self.last_output_token_count = 20 self.last_output_token_count = 20
def __call__(self, prompt, **kwargs): def get_tool_call(self, prompt, **kwargs):
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' return "final_answer", {"answer": "image"}, "fake_id"
agent = ToolCallingAgent( agent = ToolCallingAgent(
tools=[], tools=[],
@ -96,18 +96,21 @@ final_answer('This is the final answer.')
self.last_output_token_count = 20 self.last_output_token_count = 20
def __call__(self, prompt, **kwargs): def __call__(self, prompt, **kwargs):
raise AgentError self.last_input_token_count = 10
self.last_output_token_count = 0
raise Exception("Cannot generate")
agent = CodeAgent( agent = CodeAgent(
tools=[], tools=[],
model=FakeLLMModel(), model=FakeLLMModel(),
max_iterations=1, max_iterations=1,
) )
agent.run("Fake task") agent.run("Fake task")
self.assertEqual(agent.monitor.total_input_token_count, 20) self.assertEqual(
self.assertEqual(agent.monitor.total_output_token_count, 40) agent.monitor.total_input_token_count, 20
) # Should have done two monitoring callbacks
self.assertEqual(agent.monitor.total_output_token_count, 0)
def test_streaming_agent_text_output(self): def test_streaming_agent_text_output(self):
def dummy_model(prompt, **kwargs): def dummy_model(prompt, **kwargs):
@ -132,14 +135,16 @@ final_answer('This is the final answer.')
self.assertIn("This is the final answer.", final_message.content) self.assertIn("This is the final answer.", final_message.content)
def test_streaming_agent_image_output(self): def test_streaming_agent_image_output(self):
def dummy_model(prompt, **kwargs): class FakeLLM:
return ( def __init__(self):
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' pass
)
def get_tool_call(self, messages, **kwargs):
return "final_answer", {"answer": "image"}, "fake_id"
agent = ToolCallingAgent( agent = ToolCallingAgent(
tools=[], tools=[],
model=dummy_model, model=FakeLLM(),
max_iterations=1, max_iterations=1,
) )
@ -148,7 +153,7 @@ final_answer('This is the final answer.')
stream_to_gradio( stream_to_gradio(
agent, agent,
task="Test task", task="Test task",
image=AgentImage(value="path.png"), additional_args=dict(image=AgentImage(value="path.png")),
test_mode=True, test_mode=True,
) )
) )

View File

@ -41,17 +41,16 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def test_exact_match_arg(self): def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4") result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "4.0") self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_exact_match_kwarg(self): def test_exact_match_kwarg(self):
result = self.tool(code="(2 / 2) * 4") result = self.tool(code="(2 / 2) * 4")
self.assertEqual(result, "4.0") self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_agent_type_output(self): def test_agent_type_output(self):
inputs = ["2 * 2"] inputs = ["2 * 2"]
output = self.tool(*inputs) output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type] output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
print("OKK", type(output), output_type, AGENT_TYPE_MAPPING)
self.assertTrue(isinstance(output, output_type)) self.assertTrue(isinstance(output, output_type))
def test_agent_types_inputs(self): def test_agent_types_inputs(self):
@ -71,7 +70,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error # Should not raise an error
output = self.tool(*inputs) output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type] output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type)) self.assertTrue(isinstance(output, output_type))

View File

@ -27,4 +27,4 @@ class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
def test_exact_match_arg(self): def test_exact_match_arg(self):
result = self.tool("Agents") result = self.tool("Agents")
assert isinstance(result, list) and isinstance(result[0], dict) assert isinstance(result, str)

View File

@ -93,7 +93,7 @@ class ToolTesterMixin:
def test_agent_type_output(self): def test_agent_type_output(self):
inputs = create_inputs(self.tool.inputs) inputs = create_inputs(self.tool.inputs)
output = self.tool(**inputs) output = self.tool(**inputs, sanitize_inputs_outputs=True)
if self.tool.output_type != "any": if self.tool.output_type != "any":
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type)) self.assertTrue(isinstance(output, agent_type))
@ -164,20 +164,20 @@ class ToolTests(unittest.TestCase):
assert coolfunc.output_type == "number" assert coolfunc.output_type == "number"
assert "docstring has no description for the argument" in str(e) assert "docstring has no description for the argument" in str(e)
def test_tool_definition_raises_error_imports_outside_function(self): def test_saving_tool_raises_error_imports_outside_function(self):
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
from datetime import datetime import numpy as np
@tool @tool
def get_current_time() -> str: def get_current_time() -> str:
""" """
Gets the current time. Gets the current time.
""" """
return str(datetime.now()) return str(np.random.random())
get_current_time.save("output") get_current_time.save("output")
assert "datetime" in str(e) assert "np" in str(e)
# Also test with classic definition # Also test with classic definition
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
@ -189,12 +189,12 @@ class ToolTests(unittest.TestCase):
output_type = "string" output_type = "string"
def forward(self): def forward(self):
return str(datetime.now()) return str(np.random.random())
get_current_time = GetCurrentTimeTool() get_current_time = GetCurrentTimeTool()
get_current_time.save("output") get_current_time.save("output")
assert "datetime" in str(e) assert "np" in str(e)
def test_tool_definition_raises_no_error_imports_in_function(self): def test_tool_definition_raises_no_error_imports_in_function(self):
@tool @tool

View File

@ -20,7 +20,6 @@ from pathlib import Path
from smolagents.types import AgentAudio, AgentImage, AgentText from smolagents.types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import ( from transformers.testing_utils import (
get_tests_dir,
require_soundfile, require_soundfile,
require_torch, require_torch,
require_vision, require_vision,
@ -91,7 +90,7 @@ class AgentImageTests(unittest.TestCase):
self.assertTrue(os.path.exists(path)) self.assertTrue(os.path.exists(path))
def test_from_string(self): def test_from_string(self):
path = Path(get_tests_dir("fixtures/")) / "000000039769.png" path = Path("tests/fixtures/000000039769.png")
image = Image.open(path) image = Image.open(path)
agent_type = AgentImage(path) agent_type = AgentImage(path)
@ -103,7 +102,7 @@ class AgentImageTests(unittest.TestCase):
self.assertTrue(os.path.exists(path)) self.assertTrue(os.path.exists(path))
def test_from_image(self): def test_from_image(self):
path = Path(get_tests_dir("fixtures/")) / "000000039769.png" path = Path("tests/fixtures/000000039769.png")
image = Image.open(path) image = Image.open(path)
agent_type = AgentImage(image) agent_type = AgentImage(image)