Sort imports and add test workflows

This commit is contained in:
Aymeric 2025-01-06 21:48:15 +01:00
parent 417c6685b0
commit c22fedaee1
27 changed files with 188 additions and 136 deletions

View File

@ -15,19 +15,21 @@ jobs:
with: with:
python-version: 3.10 python-version: 3.10
- name: Install Python dependencies # Setup venv
run: pip install -e .[quality] # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
- name: Setup venv + uv
- name: Run Quality check
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: | run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY pip install --upgrade uv
uv venv
- name: Run Style check - name: Install dependencies
run: make style run: uv pip install "smolagents[test] @ ."
- name: Check if failure - run: uv run ruff check tests src # linter
if: ${{ failure() }} - run: uv run ruff format --check tests src # formatter
run: |
echo "Style check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY # Run type checking at least on smolagents root file to check all modules
# that can be lazy-loaded actually exist.
# - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback
# Run mypy on full package
# - run: uv run mypy src

47
.github/workflows/tests.yml vendored Normal file
View File

@ -0,0 +1,47 @@
name: Python tests
on: [pull_request]
jobs:
build-ubuntu:
runs-on: ubuntu-latest
env:
UV_HTTP_TIMEOUT: 600 # max 10min to install deps
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.12"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
# Setup venv
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
- name: Setup venv + uv
run: |
pip install --upgrade uv
uv venv
# Install dependencies
- name: Install dependencies
run: |
uv pip install "smolagents[test] @ ."
- name: Agent tests
run: |
uv run pytest -sv ./tests/test_agents.py
- name: Tool tests
run: |
uv run pytest -sv ./tests/test_toolss.py
- name: Python interpreter tests
run: |
uv run pytest -sv ./tests/test_python_interpreter.py
- name: Final answer tests
run: |
uv run pytest -sv ./tests/test_final_answer.py

View File

@ -346,7 +346,7 @@
"import glob\n", "import glob\n",
"\n", "\n",
"res = []\n", "res = []\n",
"for f in glob.glob(f\"output/*.jsonl\"):\n", "for f in glob.glob(\"output/*.jsonl\"):\n",
" res.append(pd.read_json(f, lines=True))\n", " res.append(pd.read_json(f, lines=True))\n",
"result_df = pd.concat(res)\n", "result_df = pd.concat(res)\n",
"\n", "\n",

View File

@ -1,5 +1,5 @@
from smolagents.agents import ToolCallingAgent from smolagents.agents import ToolCallingAgent
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel from smolagents import tool, LiteLLMModel
from typing import Optional from typing import Optional
# Choose which LLM engine to use! # Choose which LLM engine to use!

View File

@ -29,8 +29,12 @@ dependencies = [
"litellm>=1.55.10", "litellm>=1.55.10",
] ]
[tool.ruff]
ignore = ["F403"]
[project.optional-dependencies] [project.optional-dependencies]
test = [ test = [
"pytest>=8.1.0", "pytest>=8.1.0",
"sqlalchemy" "sqlalchemy",
"ruff>=0.5.0",
] ]

View File

@ -14,7 +14,7 @@ def execute_code(code):
try: try:
exec(code, exec_globals, exec_locals) exec(code, exec_globals, exec_locals)
except Exception as e: except Exception:
traceback.print_exc(file=stderr) traceback.print_exc(file=stderr)
output = stdout.getvalue() output = stdout.getvalue()

View File

@ -21,14 +21,13 @@ from typing import TYPE_CHECKING
from transformers.utils import _LazyModule from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import * from .agents import *
from .default_tools import * from .default_tools import *
from .gradio_ui import *
from .models import *
from .local_python_executor import *
from .e2b_executor import * from .e2b_executor import *
from .gradio_ui import *
from .local_python_executor import *
from .models import *
from .monitoring import * from .monitoring import *
from .prompts import * from .prompts import *
from .tools import * from .tools import *

View File

@ -15,49 +15,50 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import time import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from dataclasses import dataclass from dataclasses import dataclass
from rich.syntax import Syntax from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from rich.console import Group from rich.console import Group
from rich.panel import Panel from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text from rich.text import Text
from .utils import (
console,
parse_code_blob,
parse_json_tool_call,
truncate_content,
AgentError,
AgentParsingError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
)
from .types import AgentAudio, AgentImage, handle_agent_output_types
from .default_tools import FinalAnswerTool from .default_tools import FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
from .models import MessageRole from .models import MessageRole
from .monitoring import Monitor from .monitoring import Monitor
from .prompts import ( from .prompts import (
CODE_SYSTEM_PROMPT, CODE_SYSTEM_PROMPT,
TOOL_CALLING_SYSTEM_PROMPT, MANAGED_AGENT_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION, PLAN_UPDATE_FINAL_PLAN_REDACTION,
SYSTEM_PROMPT_FACTS, SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_FACTS_UPDATE, SYSTEM_PROMPT_FACTS_UPDATE,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN_UPDATE,
USER_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
SYSTEM_PROMPT_PLAN, SYSTEM_PROMPT_PLAN,
MANAGED_AGENT_PROMPT, SYSTEM_PROMPT_PLAN_UPDATE,
TOOL_CALLING_SYSTEM_PROMPT,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN,
USER_PROMPT_PLAN_UPDATE,
) )
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
from .e2b_executor import E2BExecutor
from .tools import ( from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE, DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool, Tool,
get_tool_description_with_args,
Toolbox, Toolbox,
get_tool_description_with_args,
)
from .types import AgentAudio, AgentImage, handle_agent_output_types
from .utils import (
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
AgentParsingError,
console,
parse_code_blob,
parse_json_tool_call,
truncate_content,
) )

View File

@ -18,20 +18,20 @@ import json
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional from typing import Dict, Optional
from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode from huggingface_hub import hf_hub_download, list_spaces
from transformers.models.whisper import ( from transformers.models.whisper import (
WhisperProcessor,
WhisperForConditionalGeneration, WhisperForConditionalGeneration,
WhisperProcessor,
) )
from transformers.utils import is_offline_mode
from .local_python_executor import ( from .local_python_executor import (
BASE_BUILTIN_MODULES, BASE_BUILTIN_MODULES,
BASE_PYTHON_TOOLS, BASE_PYTHON_TOOLS,
evaluate_python_code, evaluate_python_code,
) )
from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio from .types import AgentAudio
@ -271,8 +271,8 @@ class VisitWebpageTool(Tool):
def forward(self, url: str) -> str: def forward(self, url: str) -> str:
try: try:
from markdownify import markdownify
import requests import requests
from markdownify import markdownify
from requests.exceptions import RequestException from requests.exceptions import RequestException
except ImportError: except ImportError:
raise ImportError( raise ImportError(

View File

@ -14,18 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dotenv import load_dotenv
import textwrap
import base64 import base64
import pickle import pickle
import textwrap
from io import BytesIO from io import BytesIO
from typing import Any, List, Tuple
from dotenv import load_dotenv
from e2b_code_interpreter import Sandbox
from PIL import Image from PIL import Image
from e2b_code_interpreter import Sandbox
from typing import List, Tuple, Any
from .tool_validation import validate_tool_attributes from .tool_validation import validate_tool_attributes
from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
from .tools import Tool from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
load_dotenv() load_dotenv()

View File

@ -14,10 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from .agents import MultiStepAgent, AgentStep, ActionStep
import gradio as gr import gradio as gr
from .agents import ActionStep, AgentStep, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps""" """Extract ChatMessage objects from agent steps"""

View File

@ -17,14 +17,15 @@
import ast import ast
import builtins import builtins
import difflib import difflib
import math
from collections.abc import Mapping from collections.abc import Mapping
from importlib import import_module from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import math
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from .utils import truncate_content, BASE_BUILTIN_MODULES from .utils import BASE_BUILTIN_MODULES, truncate_content
class InterpreterError(ValueError): class InterpreterError(ValueError):

View File

@ -14,24 +14,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from copy import deepcopy
from enum import Enum
import json import json
from typing import Dict, List, Optional
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
import litellm
import logging import logging
import os import os
import random import random
import torch from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
import litellm
import torch
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from .tools import Tool from .tools import Tool
from .utils import parse_json_tool_call from .utils import parse_json_tool_call
@ -352,16 +351,16 @@ class TransformersModel(Model):
) )
# Get LLM output # Get LLM output
prompt = self.tokenizer.apply_chat_template( prompt_tensor = self.tokenizer.apply_chat_template(
messages, messages,
return_tensors="pt", return_tensors="pt",
return_dict=True, return_dict=True,
) )
prompt = prompt.to(self.model.device) prompt_tensor = prompt_tensor.to(self.model.device)
count_prompt_tokens = prompt["input_ids"].shape[1] count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
out = self.model.generate( out = self.model.generate(
**prompt, **prompt_tensor,
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
stopping_criteria=( stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None self.make_stopping_criteria(stop_sequences) if stop_sequences else None
@ -383,7 +382,7 @@ class TransformersModel(Model):
available_tools: List[Tool], available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
max_tokens: int = 500, max_tokens: int = 500,
) -> str: ) -> Tuple[str, Union[str, None], str]:
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions messages, role_conversions=tool_role_conversions
) )

View File

@ -14,9 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .utils import console
from rich.text import Text from rich.text import Text
from .utils import console
class Monitor: class Monitor:
def __init__(self, tracked_model): def __init__(self, tracked_model):

View File

@ -1,8 +1,9 @@
import ast import ast
import inspect
import builtins import builtins
from typing import Set import inspect
import textwrap import textwrap
from typing import Set
from .utils import BASE_BUILTIN_MODULES from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins)) _BUILTIN_NAMES = set(vars(builtins))

View File

@ -18,14 +18,16 @@ import ast
import importlib import importlib
import inspect import inspect
import json import json
import logging
import os import os
import sys import sys
import tempfile import tempfile
import torch
import textwrap import textwrap
from functools import lru_cache, wraps from functools import lru_cache, wraps
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union, get_type_hints from typing import Callable, Dict, List, Optional, Union, get_type_hints
import torch
from huggingface_hub import ( from huggingface_hub import (
create_repo, create_repo,
get_collection, get_collection,
@ -35,7 +37,8 @@ from huggingface_hub import (
) )
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from packaging import version from packaging import version
import logging from transformers import AutoProcessor
from transformers.dynamic_module_utils import get_imports
from transformers.utils import ( from transformers.utils import (
TypeHintParsingException, TypeHintParsingException,
cached_file, cached_file,
@ -45,13 +48,9 @@ from transformers.utils import (
) )
from transformers.utils.chat_template_utils import _parse_type_hint from transformers.utils.chat_template_utils import _parse_type_hint
from transformers.dynamic_module_utils import get_imports from .tool_validation import MethodChecker, validate_tool_attributes
from transformers import AutoProcessor
from .types import ImageType, handle_agent_input_types, handle_agent_output_types from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source from .utils import instance_to_source
from .tool_validation import validate_tool_attributes, MethodChecker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -12,21 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os import os
import pathlib import pathlib
import tempfile import tempfile
import uuid import uuid
from io import BytesIO from io import BytesIO
import requests
import numpy as np
import numpy as np
import requests
from transformers.utils import ( from transformers.utils import (
is_soundfile_availble, is_soundfile_availble,
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
) )
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -14,14 +14,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import ast
import inspect
import json import json
import re import re
from typing import Tuple, Dict, Union
import ast
from rich.console import Console
import inspect
import types import types
from typing import Dict, Tuple, Union
from rich.console import Console
from transformers.utils.import_utils import _is_package_available from transformers.utils.import_utils import _is_package_available
_pygments_available = _is_package_available("pygments") _pygments_available = _is_package_available("pygments")

View File

@ -16,22 +16,22 @@ import os
import tempfile import tempfile
import unittest import unittest
import uuid import uuid
import pytest
from pathlib import Path from pathlib import Path
from smolagents.types import AgentText, AgentImage import pytest
from transformers.testing_utils import get_tests_dir
from smolagents.agents import ( from smolagents.agents import (
AgentMaxStepsError, AgentMaxStepsError,
ManagedAgent,
CodeAgent, CodeAgent,
ToolCallingAgent, ManagedAgent,
Toolbox, Toolbox,
ToolCall, ToolCall,
ToolCallingAgent,
) )
from smolagents.tools import tool
from smolagents.default_tools import PythonInterpreterTool from smolagents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
def get_new_path(suffix="") -> str: def get_new_path(suffix="") -> str:

View File

@ -17,12 +17,13 @@ import ast
import os import os
import re import re
import shutil import shutil
import tempfile
import subprocess import subprocess
import tempfile
import traceback import traceback
import pytest
from pathlib import Path from pathlib import Path
from typing import List from typing import List
import pytest
from dotenv import load_dotenv from dotenv import load_dotenv

View File

@ -18,16 +18,14 @@ from pathlib import Path
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import get_tests_dir, require_torch from transformers.testing_utils import get_tests_dir, require_torch
from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.default_tools import FinalAnswerTool from smolagents.default_tools import FinalAnswerTool
from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin from .test_tools import ToolTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch

View File

@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import unittest
from smolagents import models, tool
from typing import Optional from typing import Optional
from smolagents import models, tool
class ModelTests(unittest.TestCase): class ModelTests(unittest.TestCase):
def test_get_json_schema_has_nullable_args(self): def test_get_json_schema_has_nullable_args(self):

View File

@ -16,8 +16,8 @@
import unittest import unittest
from smolagents import ( from smolagents import (
AgentImage,
AgentError, AgentError,
AgentImage,
CodeAgent, CodeAgent,
ToolCallingAgent, ToolCallingAgent,
stream_to_gradio, stream_to_gradio,

View File

@ -19,12 +19,12 @@ import numpy as np
import pytest import pytest
from smolagents import load_tool from smolagents import load_tool
from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.default_tools import BASE_PYTHON_TOOLS from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import ( from smolagents.local_python_executor import (
InterpreterError, InterpreterError,
evaluate_python_code, evaluate_python_code,
) )
from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin from .test_tools import ToolTesterMixin

View File

@ -14,21 +14,20 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from pathlib import Path from pathlib import Path
from typing import Dict, Union, Optional from typing import Dict, Optional, Union
import numpy as np import numpy as np
import pytest import pytest
from transformers import is_torch_available, is_vision_available from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir
from smolagents.tools import AUTHORIZED_TYPES, Tool, tool
from smolagents.types import ( from smolagents.types import (
AGENT_TYPE_MAPPING, AGENT_TYPE_MAPPING,
AgentAudio, AgentAudio,
AgentImage, AgentImage,
AgentText, AgentText,
) )
from smolagents.tools import Tool, tool, AUTHORIZED_TYPES
from transformers.testing_utils import get_tests_dir
if is_torch_available(): if is_torch_available():
import torch import torch

View File

@ -18,7 +18,8 @@ import unittest
import uuid import uuid
from pathlib import Path from pathlib import Path
from smolagents.types import AgentAudio, AgentImage, AgentText import torch
from PIL import Image
from transformers.testing_utils import ( from transformers.testing_utils import (
require_soundfile, require_soundfile,
require_torch, require_torch,
@ -28,9 +29,7 @@ from transformers.utils import (
is_soundfile_availble, is_soundfile_availble,
) )
import torch from smolagents.types import AgentAudio, AgentImage, AgentText
from PIL import Image
if is_soundfile_availble(): if is_soundfile_availble():
import soundfile as sf import soundfile as sf

View File

@ -1,8 +1,7 @@
import os import os
import unittest
import shutil import shutil
import tempfile import tempfile
import unittest
from pathlib import Path from pathlib import Path