Pass tests
This commit is contained in:
parent
a50f9284b3
commit
54d6857da2
6
Makefile
6
Makefile
|
@ -8,19 +8,19 @@ extra_quality_checks:
|
||||||
python utils/check_copies.py
|
python utils/check_copies.py
|
||||||
python utils/check_dummies.py
|
python utils/check_dummies.py
|
||||||
python utils/check_repo.py
|
python utils/check_repo.py
|
||||||
doc-builder style agents docs/source --max_len 119
|
doc-builder style smolagents docs/source --max_len 119
|
||||||
|
|
||||||
# this target runs checks on all files
|
# this target runs checks on all files
|
||||||
quality:
|
quality:
|
||||||
ruff check $(check_dirs)
|
ruff check $(check_dirs)
|
||||||
ruff format --check $(check_dirs)
|
ruff format --check $(check_dirs)
|
||||||
doc-builder style agents docs/source --max_len 119 --check_only
|
doc-builder style smolagents docs/source --max_len 119 --check_only
|
||||||
|
|
||||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||||
style:
|
style:
|
||||||
ruff check $(check_dirs) --fix
|
ruff check $(check_dirs) --fix
|
||||||
ruff format $(check_dirs)
|
ruff format $(check_dirs)
|
||||||
doc-builder style agents docs/source --max_len 119
|
doc-builder style smolagents docs/source --max_len 119
|
||||||
|
|
||||||
# Run tests for the library
|
# Run tests for the library
|
||||||
test_big_modeling:
|
test_big_modeling:
|
||||||
|
|
|
@ -370,8 +370,7 @@ class MultiStepAgent:
|
||||||
try:
|
try:
|
||||||
return self.model(self.input_messages)
|
return self.model(self.input_messages)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error in generating final LLM output:\n{e}"
|
return f"Error in generating final LLM output:\n{e}"
|
||||||
raise AgentGenerationError(error_msg)
|
|
||||||
|
|
||||||
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -153,8 +153,8 @@ class DuckDuckGoSearchTool(Tool):
|
||||||
}
|
}
|
||||||
output_type = "any"
|
output_type = "any"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(self)
|
super().__init__(self, **kwargs)
|
||||||
try:
|
try:
|
||||||
from duckduckgo_search import DDGS
|
from duckduckgo_search import DDGS
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
@ -410,7 +410,12 @@ class TransformersModel(Model):
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMModel(Model):
|
class LiteLLMModel(Model):
|
||||||
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id="anthropic/claude-3-5-sonnet-20240620",
|
||||||
|
api_base=None,
|
||||||
|
api_key=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
|
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
|
||||||
|
|
|
@ -517,5 +517,6 @@ __all__ = [
|
||||||
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
|
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
|
||||||
"SINGLE_STEP_CODE_SYSTEM_PROMPT",
|
"SINGLE_STEP_CODE_SYSTEM_PROMPT",
|
||||||
"CODE_SYSTEM_PROMPT",
|
"CODE_SYSTEM_PROMPT",
|
||||||
|
"TOOL_CALLING_SYSTEM_PROMPT",
|
||||||
"MANAGED_AGENT_PROMPT",
|
"MANAGED_AGENT_PROMPT",
|
||||||
]
|
]
|
||||||
|
|
|
@ -7,8 +7,6 @@ from .utils import BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
_BUILTIN_NAMES = set(vars(builtins))
|
_BUILTIN_NAMES = set(vars(builtins))
|
||||||
|
|
||||||
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
|
|
||||||
|
|
||||||
|
|
||||||
class MethodChecker(ast.NodeVisitor):
|
class MethodChecker(ast.NodeVisitor):
|
||||||
"""
|
"""
|
||||||
|
@ -91,7 +89,7 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
if isinstance(node.ctx, ast.Load):
|
if isinstance(node.ctx, ast.Load):
|
||||||
if not (
|
if not (
|
||||||
node.id in _BUILTIN_NAMES
|
node.id in _BUILTIN_NAMES
|
||||||
or node.id in IMPORTED_PACKAGES
|
or node.id in BASE_BUILTIN_MODULES
|
||||||
or node.id in self.arg_names
|
or node.id in self.arg_names
|
||||||
or node.id == "self"
|
or node.id == "self"
|
||||||
or node.id in self.class_attributes
|
or node.id in self.class_attributes
|
||||||
|
@ -105,7 +103,7 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
if isinstance(node.func, ast.Name):
|
if isinstance(node.func, ast.Name):
|
||||||
if not (
|
if not (
|
||||||
node.func.id in _BUILTIN_NAMES
|
node.func.id in _BUILTIN_NAMES
|
||||||
or node.func.id in IMPORTED_PACKAGES
|
or node.func.id in BASE_BUILTIN_MODULES
|
||||||
or node.func.id in self.arg_names
|
or node.func.id in self.arg_names
|
||||||
or node.func.id == "self"
|
or node.func.id == "self"
|
||||||
or node.func.id in self.class_attributes
|
or node.func.id in self.class_attributes
|
||||||
|
|
|
@ -854,7 +854,7 @@ def load_tool(
|
||||||
main_module = importlib.import_module("smolagents")
|
main_module = importlib.import_module("smolagents")
|
||||||
tools_module = main_module
|
tools_module = main_module
|
||||||
tool_class = getattr(tools_module, tool_class_name)
|
tool_class = getattr(tools_module, tool_class_name)
|
||||||
return tool_class(model_repo_id, token=token, **kwargs)
|
return tool_class(token=token, **kwargs)
|
||||||
else:
|
else:
|
||||||
return Tool.from_hub(
|
return Tool.from_hub(
|
||||||
task_or_repo_id,
|
task_or_repo_id,
|
||||||
|
|
|
@ -104,7 +104,9 @@ class AgentImage(AgentType, ImageType):
|
||||||
self._raw = None
|
self._raw = None
|
||||||
self._tensor = None
|
self._tensor = None
|
||||||
|
|
||||||
if isinstance(value, ImageType):
|
if isinstance(value, AgentImage):
|
||||||
|
self._raw, self._path, self._tensor = value._raw, value._path, value._tensor
|
||||||
|
elif isinstance(value, ImageType):
|
||||||
self._raw = value
|
self._raw = value
|
||||||
elif isinstance(value, bytes):
|
elif isinstance(value, bytes):
|
||||||
self._raw = Image.open(BytesIO(value))
|
self._raw = Image.open(BytesIO(value))
|
||||||
|
|
|
@ -123,6 +123,7 @@ class TestDocs:
|
||||||
"ToolCollection",
|
"ToolCollection",
|
||||||
"image_generation_tool",
|
"image_generation_tool",
|
||||||
"from_langchain",
|
"from_langchain",
|
||||||
|
"while llm_should_continue(memory):",
|
||||||
]
|
]
|
||||||
code_blocks = [
|
code_blocks = [
|
||||||
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
|
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
|
||||||
|
|
|
@ -59,6 +59,6 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def test_agent_type_output(self):
|
def test_agent_type_output(self):
|
||||||
inputs = self.create_inputs()
|
inputs = self.create_inputs()
|
||||||
for input_type, input in inputs.items():
|
for input_type, input in inputs.items():
|
||||||
output = self.tool(**input)
|
output = self.tool(**input, sanitize_inputs_outputs=True)
|
||||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||||
self.assertTrue(isinstance(output, agent_type))
|
self.assertTrue(isinstance(output, agent_type))
|
||||||
|
|
|
@ -55,8 +55,8 @@ final_answer('This is the final answer.')
|
||||||
self.last_input_token_count = 10
|
self.last_input_token_count = 10
|
||||||
self.last_output_token_count = 20
|
self.last_output_token_count = 20
|
||||||
|
|
||||||
def __call__(self, prompt, **kwargs):
|
def get_tool_call(self, prompt, **kwargs):
|
||||||
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
return "final_answer", {"answer": "image"}, "fake_id"
|
||||||
|
|
||||||
agent = ToolCallingAgent(
|
agent = ToolCallingAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
|
@ -96,18 +96,21 @@ final_answer('This is the final answer.')
|
||||||
self.last_output_token_count = 20
|
self.last_output_token_count = 20
|
||||||
|
|
||||||
def __call__(self, prompt, **kwargs):
|
def __call__(self, prompt, **kwargs):
|
||||||
raise AgentError
|
self.last_input_token_count = 10
|
||||||
|
self.last_output_token_count = 0
|
||||||
|
raise Exception("Cannot generate")
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
model=FakeLLMModel(),
|
model=FakeLLMModel(),
|
||||||
max_iterations=1,
|
max_iterations=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent.run("Fake task")
|
agent.run("Fake task")
|
||||||
|
|
||||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
self.assertEqual(
|
||||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
agent.monitor.total_input_token_count, 20
|
||||||
|
) # Should have done two monitoring callbacks
|
||||||
|
self.assertEqual(agent.monitor.total_output_token_count, 0)
|
||||||
|
|
||||||
def test_streaming_agent_text_output(self):
|
def test_streaming_agent_text_output(self):
|
||||||
def dummy_model(prompt, **kwargs):
|
def dummy_model(prompt, **kwargs):
|
||||||
|
@ -132,14 +135,16 @@ final_answer('This is the final answer.')
|
||||||
self.assertIn("This is the final answer.", final_message.content)
|
self.assertIn("This is the final answer.", final_message.content)
|
||||||
|
|
||||||
def test_streaming_agent_image_output(self):
|
def test_streaming_agent_image_output(self):
|
||||||
def dummy_model(prompt, **kwargs):
|
class FakeLLM:
|
||||||
return (
|
def __init__(self):
|
||||||
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
pass
|
||||||
)
|
|
||||||
|
def get_tool_call(self, messages, **kwargs):
|
||||||
|
return "final_answer", {"answer": "image"}, "fake_id"
|
||||||
|
|
||||||
agent = ToolCallingAgent(
|
agent = ToolCallingAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
model=dummy_model,
|
model=FakeLLM(),
|
||||||
max_iterations=1,
|
max_iterations=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -148,7 +153,7 @@ final_answer('This is the final answer.')
|
||||||
stream_to_gradio(
|
stream_to_gradio(
|
||||||
agent,
|
agent,
|
||||||
task="Test task",
|
task="Test task",
|
||||||
image=AgentImage(value="path.png"),
|
additional_args=dict(image=AgentImage(value="path.png")),
|
||||||
test_mode=True,
|
test_mode=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -41,17 +41,16 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
result = self.tool("(2 / 2) * 4")
|
result = self.tool("(2 / 2) * 4")
|
||||||
self.assertEqual(result, "4.0")
|
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
def test_exact_match_kwarg(self):
|
||||||
result = self.tool(code="(2 / 2) * 4")
|
result = self.tool(code="(2 / 2) * 4")
|
||||||
self.assertEqual(result, "4.0")
|
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||||
|
|
||||||
def test_agent_type_output(self):
|
def test_agent_type_output(self):
|
||||||
inputs = ["2 * 2"]
|
inputs = ["2 * 2"]
|
||||||
output = self.tool(*inputs)
|
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||||
print("OKK", type(output), output_type, AGENT_TYPE_MAPPING)
|
|
||||||
self.assertTrue(isinstance(output, output_type))
|
self.assertTrue(isinstance(output, output_type))
|
||||||
|
|
||||||
def test_agent_types_inputs(self):
|
def test_agent_types_inputs(self):
|
||||||
|
@ -71,7 +70,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||||
|
|
||||||
# Should not raise an error
|
# Should not raise an error
|
||||||
output = self.tool(*inputs)
|
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||||
self.assertTrue(isinstance(output, output_type))
|
self.assertTrue(isinstance(output, output_type))
|
||||||
|
|
||||||
|
|
|
@ -27,4 +27,4 @@ class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
result = self.tool("Agents")
|
result = self.tool("Agents")
|
||||||
assert isinstance(result, list) and isinstance(result[0], dict)
|
assert isinstance(result, str)
|
||||||
|
|
|
@ -93,7 +93,7 @@ class ToolTesterMixin:
|
||||||
|
|
||||||
def test_agent_type_output(self):
|
def test_agent_type_output(self):
|
||||||
inputs = create_inputs(self.tool.inputs)
|
inputs = create_inputs(self.tool.inputs)
|
||||||
output = self.tool(**inputs)
|
output = self.tool(**inputs, sanitize_inputs_outputs=True)
|
||||||
if self.tool.output_type != "any":
|
if self.tool.output_type != "any":
|
||||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||||
self.assertTrue(isinstance(output, agent_type))
|
self.assertTrue(isinstance(output, agent_type))
|
||||||
|
@ -164,20 +164,20 @@ class ToolTests(unittest.TestCase):
|
||||||
assert coolfunc.output_type == "number"
|
assert coolfunc.output_type == "number"
|
||||||
assert "docstring has no description for the argument" in str(e)
|
assert "docstring has no description for the argument" in str(e)
|
||||||
|
|
||||||
def test_tool_definition_raises_error_imports_outside_function(self):
|
def test_saving_tool_raises_error_imports_outside_function(self):
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
from datetime import datetime
|
import numpy as np
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_current_time() -> str:
|
def get_current_time() -> str:
|
||||||
"""
|
"""
|
||||||
Gets the current time.
|
Gets the current time.
|
||||||
"""
|
"""
|
||||||
return str(datetime.now())
|
return str(np.random.random())
|
||||||
|
|
||||||
get_current_time.save("output")
|
get_current_time.save("output")
|
||||||
|
|
||||||
assert "datetime" in str(e)
|
assert "np" in str(e)
|
||||||
|
|
||||||
# Also test with classic definition
|
# Also test with classic definition
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
|
@ -189,12 +189,12 @@ class ToolTests(unittest.TestCase):
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return str(datetime.now())
|
return str(np.random.random())
|
||||||
|
|
||||||
get_current_time = GetCurrentTimeTool()
|
get_current_time = GetCurrentTimeTool()
|
||||||
get_current_time.save("output")
|
get_current_time.save("output")
|
||||||
|
|
||||||
assert "datetime" in str(e)
|
assert "np" in str(e)
|
||||||
|
|
||||||
def test_tool_definition_raises_no_error_imports_in_function(self):
|
def test_tool_definition_raises_no_error_imports_in_function(self):
|
||||||
@tool
|
@tool
|
||||||
|
|
|
@ -20,7 +20,6 @@ from pathlib import Path
|
||||||
|
|
||||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
get_tests_dir,
|
|
||||||
require_soundfile,
|
require_soundfile,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_vision,
|
require_vision,
|
||||||
|
@ -91,7 +90,7 @@ class AgentImageTests(unittest.TestCase):
|
||||||
self.assertTrue(os.path.exists(path))
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
def test_from_string(self):
|
def test_from_string(self):
|
||||||
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
|
path = Path("tests/fixtures/000000039769.png")
|
||||||
image = Image.open(path)
|
image = Image.open(path)
|
||||||
agent_type = AgentImage(path)
|
agent_type = AgentImage(path)
|
||||||
|
|
||||||
|
@ -103,7 +102,7 @@ class AgentImageTests(unittest.TestCase):
|
||||||
self.assertTrue(os.path.exists(path))
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
def test_from_image(self):
|
def test_from_image(self):
|
||||||
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
|
path = Path("tests/fixtures/000000039769.png")
|
||||||
image = Image.open(path)
|
image = Image.open(path)
|
||||||
agent_type = AgentImage(image)
|
agent_type = AgentImage(image)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue