Pass tests
This commit is contained in:
parent
a50f9284b3
commit
54d6857da2
6
Makefile
6
Makefile
|
@ -8,19 +8,19 @@ extra_quality_checks:
|
|||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_repo.py
|
||||
doc-builder style agents docs/source --max_len 119
|
||||
doc-builder style smolagents docs/source --max_len 119
|
||||
|
||||
# this target runs checks on all files
|
||||
quality:
|
||||
ruff check $(check_dirs)
|
||||
ruff format --check $(check_dirs)
|
||||
doc-builder style agents docs/source --max_len 119 --check_only
|
||||
doc-builder style smolagents docs/source --max_len 119 --check_only
|
||||
|
||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||
style:
|
||||
ruff check $(check_dirs) --fix
|
||||
ruff format $(check_dirs)
|
||||
doc-builder style agents docs/source --max_len 119
|
||||
doc-builder style smolagents docs/source --max_len 119
|
||||
|
||||
# Run tests for the library
|
||||
test_big_modeling:
|
||||
|
|
|
@ -370,8 +370,7 @@ class MultiStepAgent:
|
|||
try:
|
||||
return self.model(self.input_messages)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in generating final LLM output:\n{e}"
|
||||
raise AgentGenerationError(error_msg)
|
||||
return f"Error in generating final LLM output:\n{e}"
|
||||
|
||||
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
||||
"""
|
||||
|
|
|
@ -153,8 +153,8 @@ class DuckDuckGoSearchTool(Tool):
|
|||
}
|
||||
output_type = "any"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(self, **kwargs)
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
|
|
|
@ -410,7 +410,12 @@ class TransformersModel(Model):
|
|||
|
||||
|
||||
class LiteLLMModel(Model):
|
||||
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None):
|
||||
def __init__(
|
||||
self,
|
||||
model_id="anthropic/claude-3-5-sonnet-20240620",
|
||||
api_base=None,
|
||||
api_key=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_id = model_id
|
||||
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
|
||||
|
|
|
@ -517,5 +517,6 @@ __all__ = [
|
|||
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
|
||||
"SINGLE_STEP_CODE_SYSTEM_PROMPT",
|
||||
"CODE_SYSTEM_PROMPT",
|
||||
"TOOL_CALLING_SYSTEM_PROMPT",
|
||||
"MANAGED_AGENT_PROMPT",
|
||||
]
|
||||
|
|
|
@ -7,8 +7,6 @@ from .utils import BASE_BUILTIN_MODULES
|
|||
|
||||
_BUILTIN_NAMES = set(vars(builtins))
|
||||
|
||||
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
|
||||
|
||||
|
||||
class MethodChecker(ast.NodeVisitor):
|
||||
"""
|
||||
|
@ -91,7 +89,7 @@ class MethodChecker(ast.NodeVisitor):
|
|||
if isinstance(node.ctx, ast.Load):
|
||||
if not (
|
||||
node.id in _BUILTIN_NAMES
|
||||
or node.id in IMPORTED_PACKAGES
|
||||
or node.id in BASE_BUILTIN_MODULES
|
||||
or node.id in self.arg_names
|
||||
or node.id == "self"
|
||||
or node.id in self.class_attributes
|
||||
|
@ -105,7 +103,7 @@ class MethodChecker(ast.NodeVisitor):
|
|||
if isinstance(node.func, ast.Name):
|
||||
if not (
|
||||
node.func.id in _BUILTIN_NAMES
|
||||
or node.func.id in IMPORTED_PACKAGES
|
||||
or node.func.id in BASE_BUILTIN_MODULES
|
||||
or node.func.id in self.arg_names
|
||||
or node.func.id == "self"
|
||||
or node.func.id in self.class_attributes
|
||||
|
|
|
@ -854,7 +854,7 @@ def load_tool(
|
|||
main_module = importlib.import_module("smolagents")
|
||||
tools_module = main_module
|
||||
tool_class = getattr(tools_module, tool_class_name)
|
||||
return tool_class(model_repo_id, token=token, **kwargs)
|
||||
return tool_class(token=token, **kwargs)
|
||||
else:
|
||||
return Tool.from_hub(
|
||||
task_or_repo_id,
|
||||
|
|
|
@ -104,7 +104,9 @@ class AgentImage(AgentType, ImageType):
|
|||
self._raw = None
|
||||
self._tensor = None
|
||||
|
||||
if isinstance(value, ImageType):
|
||||
if isinstance(value, AgentImage):
|
||||
self._raw, self._path, self._tensor = value._raw, value._path, value._tensor
|
||||
elif isinstance(value, ImageType):
|
||||
self._raw = value
|
||||
elif isinstance(value, bytes):
|
||||
self._raw = Image.open(BytesIO(value))
|
||||
|
|
|
@ -123,6 +123,7 @@ class TestDocs:
|
|||
"ToolCollection",
|
||||
"image_generation_tool",
|
||||
"from_langchain",
|
||||
"while llm_should_continue(memory):",
|
||||
]
|
||||
code_blocks = [
|
||||
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
|
||||
|
|
|
@ -59,6 +59,6 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
|||
def test_agent_type_output(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
output = self.tool(**input, sanitize_inputs_outputs=True)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
|
|
@ -55,8 +55,8 @@ final_answer('This is the final answer.')
|
|||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
||||
def get_tool_call(self, prompt, **kwargs):
|
||||
return "final_answer", {"answer": "image"}, "fake_id"
|
||||
|
||||
agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
|
@ -96,18 +96,21 @@ final_answer('This is the final answer.')
|
|||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
raise AgentError
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 0
|
||||
raise Exception("Cannot generate")
|
||||
|
||||
agent = CodeAgent(
|
||||
tools=[],
|
||||
model=FakeLLMModel(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
self.assertEqual(
|
||||
agent.monitor.total_input_token_count, 20
|
||||
) # Should have done two monitoring callbacks
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 0)
|
||||
|
||||
def test_streaming_agent_text_output(self):
|
||||
def dummy_model(prompt, **kwargs):
|
||||
|
@ -132,14 +135,16 @@ final_answer('This is the final answer.')
|
|||
self.assertIn("This is the final answer.", final_message.content)
|
||||
|
||||
def test_streaming_agent_image_output(self):
|
||||
def dummy_model(prompt, **kwargs):
|
||||
return (
|
||||
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
||||
)
|
||||
class FakeLLM:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_tool_call(self, messages, **kwargs):
|
||||
return "final_answer", {"answer": "image"}, "fake_id"
|
||||
|
||||
agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=dummy_model,
|
||||
model=FakeLLM(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
|
@ -148,7 +153,7 @@ final_answer('This is the final answer.')
|
|||
stream_to_gradio(
|
||||
agent,
|
||||
task="Test task",
|
||||
image=AgentImage(value="path.png"),
|
||||
additional_args=dict(image=AgentImage(value="path.png")),
|
||||
test_mode=True,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -41,17 +41,16 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
|||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(code="(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["2 * 2"]
|
||||
output = self.tool(*inputs)
|
||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
print("OKK", type(output), output_type, AGENT_TYPE_MAPPING)
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
|
@ -71,7 +70,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
|||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(*inputs)
|
||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
|
|
|
@ -27,4 +27,4 @@ class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
|||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Agents")
|
||||
assert isinstance(result, list) and isinstance(result[0], dict)
|
||||
assert isinstance(result, str)
|
||||
|
|
|
@ -93,7 +93,7 @@ class ToolTesterMixin:
|
|||
|
||||
def test_agent_type_output(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
output = self.tool(**inputs)
|
||||
output = self.tool(**inputs, sanitize_inputs_outputs=True)
|
||||
if self.tool.output_type != "any":
|
||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
@ -164,20 +164,20 @@ class ToolTests(unittest.TestCase):
|
|||
assert coolfunc.output_type == "number"
|
||||
assert "docstring has no description for the argument" in str(e)
|
||||
|
||||
def test_tool_definition_raises_error_imports_outside_function(self):
|
||||
def test_saving_tool_raises_error_imports_outside_function(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
@tool
|
||||
def get_current_time() -> str:
|
||||
"""
|
||||
Gets the current time.
|
||||
"""
|
||||
return str(datetime.now())
|
||||
return str(np.random.random())
|
||||
|
||||
get_current_time.save("output")
|
||||
|
||||
assert "datetime" in str(e)
|
||||
assert "np" in str(e)
|
||||
|
||||
# Also test with classic definition
|
||||
with pytest.raises(Exception) as e:
|
||||
|
@ -189,12 +189,12 @@ class ToolTests(unittest.TestCase):
|
|||
output_type = "string"
|
||||
|
||||
def forward(self):
|
||||
return str(datetime.now())
|
||||
return str(np.random.random())
|
||||
|
||||
get_current_time = GetCurrentTimeTool()
|
||||
get_current_time.save("output")
|
||||
|
||||
assert "datetime" in str(e)
|
||||
assert "np" in str(e)
|
||||
|
||||
def test_tool_definition_raises_no_error_imports_in_function(self):
|
||||
@tool
|
||||
|
|
|
@ -20,7 +20,6 @@ from pathlib import Path
|
|||
|
||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
require_soundfile,
|
||||
require_torch,
|
||||
require_vision,
|
||||
|
@ -91,7 +90,7 @@ class AgentImageTests(unittest.TestCase):
|
|||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
def test_from_string(self):
|
||||
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
|
||||
path = Path("tests/fixtures/000000039769.png")
|
||||
image = Image.open(path)
|
||||
agent_type = AgentImage(path)
|
||||
|
||||
|
@ -103,7 +102,7 @@ class AgentImageTests(unittest.TestCase):
|
|||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
def test_from_image(self):
|
||||
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
|
||||
path = Path("tests/fixtures/000000039769.png")
|
||||
image = Image.open(path)
|
||||
agent_type = AgentImage(image)
|
||||
|
||||
|
|
Loading…
Reference in New Issue