Sort imports and add test workflows

Aymeric 2025-01-06 21:48:15 +01:00
parent 417c6685b0
commit c22fedaee1
27 changed files with 188 additions and 136 deletions

View File

@@ -14,20 +14,22 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.10
- name: Install Python dependencies
run: pip install -e .[quality]
- name: Run Quality check
run: make quality
- name: Check if failure
if: ${{ failure() }}
# Setup venv
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
- name: Setup venv + uv
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY
pip install --upgrade uv
uv venv
- name: Run Style check
run: make style
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Style check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY
- name: Install dependencies
run: uv pip install "smolagents[test] @ ."
- run: uv run ruff check tests src # linter
- run: uv run ruff format --check tests src # formatter
# Run type checking at least on smolagents root file to check all modules
# that can be lazy-loaded actually exist.
# - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback
# Run mypy on full package
# - run: uv run mypy src
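Taken together, the reworked quality job amounts to roughly the following local commands (a sketch, assuming uv is installed and the commands are run from the repository root):

    pip install --upgrade uv
    uv venv
    uv pip install "smolagents[test] @ ."
    uv run ruff check tests src          # linter
    uv run ruff format --check tests src # formatter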

.github/workflows/tests.yml (new file, 47 lines)
View File

@@ -0,0 +1,47 @@
name: Python tests
on: [pull_request]
jobs:
build-ubuntu:
runs-on: ubuntu-latest
env:
UV_HTTP_TIMEOUT: 600 # max 10min to install deps
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.12"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
# Setup venv
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
- name: Setup venv + uv
run: |
pip install --upgrade uv
uv venv
# Install dependencies
- name: Install dependencies
run: |
uv pip install "smolagents[test] @ ."
- name: Agent tests
run: |
uv run pytest -sv ./tests/test_agents.py
- name: Tool tests
run: |
uv run pytest -sv ./tests/test_tools.py
- name: Python interpreter tests
run: |
uv run pytest -sv ./tests/test_python_interpreter.py
- name: Final answer tests
run: |
uv run pytest -sv ./tests/test_final_answer.py
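For a single local run covering the same files, the four steps collapse to one pytest invocation (a sketch, assuming the uv venv from the setup step above is in place):

    uv run pytest -sv ./tests/test_agents.py ./tests/test_tools.py ./tests/test_python_interpreter.py ./tests/test_final_answer.py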

View File

@@ -346,7 +346,7 @@
"import glob\n",
"\n",
"res = []\n",
"for f in glob.glob(f\"output/*.jsonl\"):\n",
"for f in glob.glob(\"output/*.jsonl\"):\n",
" res.append(pd.read_json(f, lines=True))\n",
"result_df = pd.concat(res)\n",
"\n",

View File

@@ -1,5 +1,5 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel
from smolagents import tool, LiteLLMModel
from typing import Optional
# Choose which LLM engine to use!

View File

@@ -12,25 +12,29 @@ authors = [
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch",
"torchaudio",
"torchvision",
"transformers>=4.0.0",
"requests>=2.32.3",
"rich>=13.9.4",
"pandas>=2.2.3",
"jinja2>=3.1.4",
"pillow>=11.0.0",
"markdownify>=0.14.1",
"gradio>=5.8.0",
"duckduckgo-search>=6.3.7",
"python-dotenv>=1.0.1",
"e2b-code-interpreter>=1.0.3",
"litellm>=1.55.10",
"torch",
"torchaudio",
"torchvision",
"transformers>=4.0.0",
"requests>=2.32.3",
"rich>=13.9.4",
"pandas>=2.2.3",
"jinja2>=3.1.4",
"pillow>=11.0.0",
"markdownify>=0.14.1",
"gradio>=5.8.0",
"duckduckgo-search>=6.3.7",
"python-dotenv>=1.0.1",
"e2b-code-interpreter>=1.0.3",
"litellm>=1.55.10",
]
[tool.ruff]
ignore = ["F403"]
[project.optional-dependencies]
test = [
"pytest>=8.1.0",
"sqlalchemy"
"pytest>=8.1.0",
"sqlalchemy",
"ruff>=0.5.0",
]
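The import reordering in the files below is the kind of change ruff's isort rule family produces; a minimal sketch of opting into it explicitly is shown here as a hypothetical addition — this hunk only shows the existing [tool.ruff] ignore setting, not which rules the project actually selects:

    # hypothetical pyproject.toml addition, not part of this commit
    [tool.ruff.lint]
    extend-select = ["I"]  # "I" = isort-style import sorting

    # imports can then be sorted with: uv run ruff check --fix tests src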

View File

@@ -14,7 +14,7 @@ def execute_code(code):
try:
exec(code, exec_globals, exec_locals)
except Exception as e:
except Exception:
traceback.print_exc(file=stderr)
output = stdout.getvalue()

View File

@@ -21,14 +21,13 @@ from typing import TYPE_CHECKING
from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .agents import *
from .default_tools import *
from .gradio_ui import *
from .models import *
from .local_python_executor import *
from .e2b_executor import *
from .gradio_ui import *
from .local_python_executor import *
from .models import *
from .monitoring import *
from .prompts import *
from .tools import *

View File

@@ -15,49 +15,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from dataclasses import dataclass
from rich.syntax import Syntax
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
from .utils import (
console,
parse_code_blob,
parse_json_tool_call,
truncate_content,
AgentError,
AgentParsingError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
)
from .types import AgentAudio, AgentImage, handle_agent_output_types
from .default_tools import FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
from .models import MessageRole
from .monitoring import Monitor
from .prompts import (
CODE_SYSTEM_PROMPT,
TOOL_CALLING_SYSTEM_PROMPT,
MANAGED_AGENT_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION,
SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_FACTS_UPDATE,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN_UPDATE,
USER_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
SYSTEM_PROMPT_PLAN,
MANAGED_AGENT_PROMPT,
SYSTEM_PROMPT_PLAN_UPDATE,
TOOL_CALLING_SYSTEM_PROMPT,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN,
USER_PROMPT_PLAN_UPDATE,
)
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
from .e2b_executor import E2BExecutor
from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool,
get_tool_description_with_args,
Toolbox,
get_tool_description_with_args,
)
from .types import AgentAudio, AgentImage, handle_agent_output_types
from .utils import (
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
AgentParsingError,
console,
parse_code_blob,
parse_json_tool_call,
truncate_content,
)

View File

@@ -18,20 +18,20 @@ import json
import re
from dataclasses import dataclass
from typing import Dict, Optional
from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode
from huggingface_hub import hf_hub_download, list_spaces
from transformers.models.whisper import (
WhisperProcessor,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from transformers.utils import is_offline_mode
from .local_python_executor import (
BASE_BUILTIN_MODULES,
BASE_PYTHON_TOOLS,
evaluate_python_code,
)
from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio
@@ -271,8 +271,8 @@ class VisitWebpageTool(Tool):
def forward(self, url: str) -> str:
try:
from markdownify import markdownify
import requests
from markdownify import markdownify
from requests.exceptions import RequestException
except ImportError:
raise ImportError(

View File

@@ -14,18 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dotenv import load_dotenv
import textwrap
import base64
import pickle
import textwrap
from io import BytesIO
from typing import Any, List, Tuple
from dotenv import load_dotenv
from e2b_code_interpreter import Sandbox
from PIL import Image
from e2b_code_interpreter import Sandbox
from typing import List, Tuple, Any
from .tool_validation import validate_tool_attributes
from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
load_dotenv()

View File

@@ -14,10 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from .agents import MultiStepAgent, AgentStep, ActionStep
import gradio as gr
from .agents import ActionStep, AgentStep, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps"""

View File

@@ -17,14 +17,15 @@
import ast
import builtins
import difflib
import math
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, Tuple
import math
import numpy as np
import pandas as pd
from .utils import truncate_content, BASE_BUILTIN_MODULES
from .utils import BASE_BUILTIN_MODULES, truncate_content
class InterpreterError(ValueError):

View File

@@ -14,24 +14,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from enum import Enum
import json
from typing import Dict, List, Optional
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
import litellm
import logging
import os
import random
import torch
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
import litellm
import torch
from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from .tools import Tool
from .utils import parse_json_tool_call
@@ -352,16 +351,16 @@ class TransformersModel(Model):
)
# Get LLM output
prompt = self.tokenizer.apply_chat_template(
prompt_tensor = self.tokenizer.apply_chat_template(
messages,
return_tensors="pt",
return_dict=True,
)
prompt = prompt.to(self.model.device)
count_prompt_tokens = prompt["input_ids"].shape[1]
prompt_tensor = prompt_tensor.to(self.model.device)
count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
out = self.model.generate(
**prompt,
**prompt_tensor,
max_new_tokens=max_tokens,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
@@ -383,7 +382,7 @@ class TransformersModel(Model):
available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
max_tokens: int = 500,
) -> str:
) -> Tuple[str, Union[str, None], str]:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)

View File

@@ -14,9 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import console
from rich.text import Text
from .utils import console
class Monitor:
def __init__(self, tracked_model):

View File

@@ -1,8 +1,9 @@
import ast
import inspect
import builtins
from typing import Set
import inspect
import textwrap
from typing import Set
from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins))

View File

@@ -18,14 +18,16 @@ import ast
import importlib
import inspect
import json
import logging
import os
import sys
import tempfile
import torch
import textwrap
from functools import lru_cache, wraps
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union, get_type_hints
import torch
from huggingface_hub import (
create_repo,
get_collection,
@@ -35,7 +37,8 @@ from huggingface_hub import (
)
from huggingface_hub.utils import RepositoryNotFoundError
from packaging import version
import logging
from transformers import AutoProcessor
from transformers.dynamic_module_utils import get_imports
from transformers.utils import (
TypeHintParsingException,
cached_file,
@@ -45,13 +48,9 @@ from transformers.utils import (
)
from transformers.utils.chat_template_utils import _parse_type_hint
from transformers.dynamic_module_utils import get_imports
from transformers import AutoProcessor
from .tool_validation import MethodChecker, validate_tool_attributes
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source
from .tool_validation import validate_tool_attributes, MethodChecker
logger = logging.getLogger(__name__)

View File

@@ -12,21 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pathlib
import tempfile
import uuid
from io import BytesIO
import requests
import numpy as np
import numpy as np
import requests
from transformers.utils import (
is_soundfile_availble,
is_torch_available,
is_vision_available,
)
import logging
logger = logging.getLogger(__name__)

View File

@@ -14,14 +14,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import inspect
import json
import re
from typing import Tuple, Dict, Union
import ast
from rich.console import Console
import inspect
import types
from typing import Dict, Tuple, Union
from rich.console import Console
from transformers.utils.import_utils import _is_package_available
_pygments_available = _is_package_available("pygments")

View File

@@ -16,22 +16,22 @@ import os
import tempfile
import unittest
import uuid
import pytest
from pathlib import Path
from smolagents.types import AgentText, AgentImage
import pytest
from transformers.testing_utils import get_tests_dir
from smolagents.agents import (
AgentMaxStepsError,
ManagedAgent,
CodeAgent,
ToolCallingAgent,
ManagedAgent,
Toolbox,
ToolCall,
ToolCallingAgent,
)
from smolagents.tools import tool
from smolagents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
def get_new_path(suffix="") -> str:

View File

@@ -17,12 +17,13 @@ import ast
import os
import re
import shutil
import tempfile
import subprocess
import tempfile
import traceback
import pytest
from pathlib import Path
from typing import List
import pytest
from dotenv import load_dotenv

View File

@@ -18,16 +18,14 @@ from pathlib import Path
import numpy as np
from PIL import Image
from transformers import is_torch_available
from transformers.testing_utils import get_tests_dir, require_torch
from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.default_tools import FinalAnswerTool
from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin
if is_torch_available():
import torch

View File

@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from smolagents import models, tool
from typing import Optional
from smolagents import models, tool
class ModelTests(unittest.TestCase):
def test_get_json_schema_has_nullable_args(self):

View File

@@ -16,8 +16,8 @@
import unittest
from smolagents import (
AgentImage,
AgentError,
AgentImage,
CodeAgent,
ToolCallingAgent,
stream_to_gradio,

View File

@@ -19,12 +19,12 @@ import numpy as np
import pytest
from smolagents import load_tool
from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import (
InterpreterError,
evaluate_python_code,
)
from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin

View File

@@ -14,21 +14,20 @@
# limitations under the License.
import unittest
from pathlib import Path
from typing import Dict, Union, Optional
from typing import Dict, Optional, Union
import numpy as np
import pytest
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir
from smolagents.tools import AUTHORIZED_TYPES, Tool, tool
from smolagents.types import (
AGENT_TYPE_MAPPING,
AgentAudio,
AgentImage,
AgentText,
)
from smolagents.tools import Tool, tool, AUTHORIZED_TYPES
from transformers.testing_utils import get_tests_dir
if is_torch_available():
import torch

View File

@@ -18,7 +18,8 @@ import unittest
import uuid
from pathlib import Path
from smolagents.types import AgentAudio, AgentImage, AgentText
import torch
from PIL import Image
from transformers.testing_utils import (
require_soundfile,
require_torch,
@@ -28,9 +29,7 @@ from transformers.utils import (
is_soundfile_availble,
)
import torch
from PIL import Image
from smolagents.types import AgentAudio, AgentImage, AgentText
if is_soundfile_availble():
import soundfile as sf

View File

@@ -1,8 +1,7 @@
import os
import unittest
import shutil
import tempfile
import unittest
from pathlib import Path