Sort imports and add test workflows
This commit is contained in:
parent
417c6685b0
commit
c22fedaee1
|
@ -15,19 +15,21 @@ jobs:
|
||||||
with:
|
with:
|
||||||
python-version: 3.10
|
python-version: 3.10
|
||||||
|
|
||||||
- name: Install Python dependencies
|
# Setup venv
|
||||||
run: pip install -e .[quality]
|
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
||||||
|
- name: Setup venv + uv
|
||||||
- name: Run Quality check
|
|
||||||
run: make quality
|
|
||||||
- name: Check if failure
|
|
||||||
if: ${{ failure() }}
|
|
||||||
run: |
|
run: |
|
||||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY
|
pip install --upgrade uv
|
||||||
|
uv venv
|
||||||
|
|
||||||
- name: Run Style check
|
- name: Install dependencies
|
||||||
run: make style
|
run: uv pip install "smolagents[test] @ ."
|
||||||
- name: Check if failure
|
- run: uv run ruff check tests src # linter
|
||||||
if: ${{ failure() }}
|
- run: uv run ruff format --check tests src # formatter
|
||||||
run: |
|
|
||||||
echo "Style check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY
|
# Run type checking at least on smolagents root file to check all modules
|
||||||
|
# that can be lazy-loaded actually exist.
|
||||||
|
# - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback
|
||||||
|
|
||||||
|
# Run mypy on full package
|
||||||
|
# - run: uv run mypy src
|
|
@ -0,0 +1,47 @@
|
||||||
|
name: Python tests
|
||||||
|
|
||||||
|
on: [pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-ubuntu:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
UV_HTTP_TIMEOUT: 600 # max 10min to install deps
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.10", "3.12"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
|
||||||
|
# Setup venv
|
||||||
|
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
||||||
|
- name: Setup venv + uv
|
||||||
|
run: |
|
||||||
|
pip install --upgrade uv
|
||||||
|
uv venv
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
uv pip install "smolagents[test] @ ."
|
||||||
|
|
||||||
|
- name: Agent tests
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv ./tests/test_agents.py
|
||||||
|
- name: Tool tests
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv ./tests/test_toolss.py
|
||||||
|
- name: Python interpreter tests
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv ./tests/test_python_interpreter.py
|
||||||
|
- name: Final answer tests
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv ./tests/test_final_answer.py
|
|
@ -346,7 +346,7 @@
|
||||||
"import glob\n",
|
"import glob\n",
|
||||||
"\n",
|
"\n",
|
||||||
"res = []\n",
|
"res = []\n",
|
||||||
"for f in glob.glob(f\"output/*.jsonl\"):\n",
|
"for f in glob.glob(\"output/*.jsonl\"):\n",
|
||||||
" res.append(pd.read_json(f, lines=True))\n",
|
" res.append(pd.read_json(f, lines=True))\n",
|
||||||
"result_df = pd.concat(res)\n",
|
"result_df = pd.concat(res)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from smolagents.agents import ToolCallingAgent
|
from smolagents.agents import ToolCallingAgent
|
||||||
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel
|
from smolagents import tool, LiteLLMModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
# Choose which LLM engine to use!
|
# Choose which LLM engine to use!
|
||||||
|
|
|
@ -12,25 +12,29 @@ authors = [
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch",
|
"torch",
|
||||||
"torchaudio",
|
"torchaudio",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"transformers>=4.0.0",
|
"transformers>=4.0.0",
|
||||||
"requests>=2.32.3",
|
"requests>=2.32.3",
|
||||||
"rich>=13.9.4",
|
"rich>=13.9.4",
|
||||||
"pandas>=2.2.3",
|
"pandas>=2.2.3",
|
||||||
"jinja2>=3.1.4",
|
"jinja2>=3.1.4",
|
||||||
"pillow>=11.0.0",
|
"pillow>=11.0.0",
|
||||||
"markdownify>=0.14.1",
|
"markdownify>=0.14.1",
|
||||||
"gradio>=5.8.0",
|
"gradio>=5.8.0",
|
||||||
"duckduckgo-search>=6.3.7",
|
"duckduckgo-search>=6.3.7",
|
||||||
"python-dotenv>=1.0.1",
|
"python-dotenv>=1.0.1",
|
||||||
"e2b-code-interpreter>=1.0.3",
|
"e2b-code-interpreter>=1.0.3",
|
||||||
"litellm>=1.55.10",
|
"litellm>=1.55.10",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
ignore = ["F403"]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = [
|
test = [
|
||||||
"pytest>=8.1.0",
|
"pytest>=8.1.0",
|
||||||
"sqlalchemy"
|
"sqlalchemy",
|
||||||
|
"ruff>=0.5.0",
|
||||||
]
|
]
|
||||||
|
|
|
@ -14,7 +14,7 @@ def execute_code(code):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
exec(code, exec_globals, exec_locals)
|
exec(code, exec_globals, exec_locals)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
traceback.print_exc(file=stderr)
|
traceback.print_exc(file=stderr)
|
||||||
|
|
||||||
output = stdout.getvalue()
|
output = stdout.getvalue()
|
||||||
|
|
|
@ -21,14 +21,13 @@ from typing import TYPE_CHECKING
|
||||||
from transformers.utils import _LazyModule
|
from transformers.utils import _LazyModule
|
||||||
from transformers.utils.import_utils import define_import_structure
|
from transformers.utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import *
|
from .agents import *
|
||||||
from .default_tools import *
|
from .default_tools import *
|
||||||
from .gradio_ui import *
|
|
||||||
from .models import *
|
|
||||||
from .local_python_executor import *
|
|
||||||
from .e2b_executor import *
|
from .e2b_executor import *
|
||||||
|
from .gradio_ui import *
|
||||||
|
from .local_python_executor import *
|
||||||
|
from .models import *
|
||||||
from .monitoring import *
|
from .monitoring import *
|
||||||
from .prompts import *
|
from .prompts import *
|
||||||
from .tools import *
|
from .tools import *
|
||||||
|
|
|
@ -15,49 +15,50 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from rich.syntax import Syntax
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from rich.console import Group
|
from rich.console import Group
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
|
from rich.syntax import Syntax
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
console,
|
|
||||||
parse_code_blob,
|
|
||||||
parse_json_tool_call,
|
|
||||||
truncate_content,
|
|
||||||
AgentError,
|
|
||||||
AgentParsingError,
|
|
||||||
AgentExecutionError,
|
|
||||||
AgentGenerationError,
|
|
||||||
AgentMaxStepsError,
|
|
||||||
)
|
|
||||||
from .types import AgentAudio, AgentImage, handle_agent_output_types
|
|
||||||
from .default_tools import FinalAnswerTool
|
from .default_tools import FinalAnswerTool
|
||||||
|
from .e2b_executor import E2BExecutor
|
||||||
|
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
|
||||||
from .models import MessageRole
|
from .models import MessageRole
|
||||||
from .monitoring import Monitor
|
from .monitoring import Monitor
|
||||||
from .prompts import (
|
from .prompts import (
|
||||||
CODE_SYSTEM_PROMPT,
|
CODE_SYSTEM_PROMPT,
|
||||||
TOOL_CALLING_SYSTEM_PROMPT,
|
MANAGED_AGENT_PROMPT,
|
||||||
PLAN_UPDATE_FINAL_PLAN_REDACTION,
|
PLAN_UPDATE_FINAL_PLAN_REDACTION,
|
||||||
SYSTEM_PROMPT_FACTS,
|
SYSTEM_PROMPT_FACTS,
|
||||||
SYSTEM_PROMPT_FACTS_UPDATE,
|
SYSTEM_PROMPT_FACTS_UPDATE,
|
||||||
USER_PROMPT_FACTS_UPDATE,
|
|
||||||
USER_PROMPT_PLAN_UPDATE,
|
|
||||||
USER_PROMPT_PLAN,
|
|
||||||
SYSTEM_PROMPT_PLAN_UPDATE,
|
|
||||||
SYSTEM_PROMPT_PLAN,
|
SYSTEM_PROMPT_PLAN,
|
||||||
MANAGED_AGENT_PROMPT,
|
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||||
|
TOOL_CALLING_SYSTEM_PROMPT,
|
||||||
|
USER_PROMPT_FACTS_UPDATE,
|
||||||
|
USER_PROMPT_PLAN,
|
||||||
|
USER_PROMPT_PLAN_UPDATE,
|
||||||
)
|
)
|
||||||
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
|
|
||||||
from .e2b_executor import E2BExecutor
|
|
||||||
from .tools import (
|
from .tools import (
|
||||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
Tool,
|
Tool,
|
||||||
get_tool_description_with_args,
|
|
||||||
Toolbox,
|
Toolbox,
|
||||||
|
get_tool_description_with_args,
|
||||||
|
)
|
||||||
|
from .types import AgentAudio, AgentImage, handle_agent_output_types
|
||||||
|
from .utils import (
|
||||||
|
AgentError,
|
||||||
|
AgentExecutionError,
|
||||||
|
AgentGenerationError,
|
||||||
|
AgentMaxStepsError,
|
||||||
|
AgentParsingError,
|
||||||
|
console,
|
||||||
|
parse_code_blob,
|
||||||
|
parse_json_tool_call,
|
||||||
|
truncate_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,20 +18,20 @@ import json
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
from huggingface_hub import hf_hub_download, list_spaces
|
|
||||||
|
|
||||||
from transformers.utils import is_offline_mode
|
from huggingface_hub import hf_hub_download, list_spaces
|
||||||
from transformers.models.whisper import (
|
from transformers.models.whisper import (
|
||||||
WhisperProcessor,
|
|
||||||
WhisperForConditionalGeneration,
|
WhisperForConditionalGeneration,
|
||||||
|
WhisperProcessor,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import is_offline_mode
|
||||||
|
|
||||||
from .local_python_executor import (
|
from .local_python_executor import (
|
||||||
BASE_BUILTIN_MODULES,
|
BASE_BUILTIN_MODULES,
|
||||||
BASE_PYTHON_TOOLS,
|
BASE_PYTHON_TOOLS,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
)
|
)
|
||||||
from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool
|
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
||||||
from .types import AgentAudio
|
from .types import AgentAudio
|
||||||
|
|
||||||
|
|
||||||
|
@ -271,8 +271,8 @@ class VisitWebpageTool(Tool):
|
||||||
|
|
||||||
def forward(self, url: str) -> str:
|
def forward(self, url: str) -> str:
|
||||||
try:
|
try:
|
||||||
from markdownify import markdownify
|
|
||||||
import requests
|
import requests
|
||||||
|
from markdownify import markdownify
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
|
|
@ -14,18 +14,19 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from dotenv import load_dotenv
|
|
||||||
import textwrap
|
|
||||||
import base64
|
import base64
|
||||||
import pickle
|
import pickle
|
||||||
|
import textwrap
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from e2b_code_interpreter import Sandbox
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from e2b_code_interpreter import Sandbox
|
|
||||||
from typing import List, Tuple, Any
|
|
||||||
from .tool_validation import validate_tool_attributes
|
from .tool_validation import validate_tool_attributes
|
||||||
from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
|
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
|
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
|
@ -14,10 +14,11 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
|
||||||
from .agents import MultiStepAgent, AgentStep, ActionStep
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
from .agents import ActionStep, AgentStep, MultiStepAgent
|
||||||
|
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||||
|
|
||||||
|
|
||||||
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||||
"""Extract ChatMessage objects from agent steps"""
|
"""Extract ChatMessage objects from agent steps"""
|
||||||
|
|
|
@ -17,14 +17,15 @@
|
||||||
import ast
|
import ast
|
||||||
import builtins
|
import builtins
|
||||||
import difflib
|
import difflib
|
||||||
|
import math
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
import math
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from .utils import truncate_content, BASE_BUILTIN_MODULES
|
from .utils import BASE_BUILTIN_MODULES, truncate_content
|
||||||
|
|
||||||
|
|
||||||
class InterpreterError(ValueError):
|
class InterpreterError(ValueError):
|
||||||
|
|
|
@ -14,24 +14,23 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from copy import deepcopy
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from transformers import (
|
|
||||||
AutoTokenizer,
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
StoppingCriteria,
|
|
||||||
StoppingCriteriaList,
|
|
||||||
)
|
|
||||||
|
|
||||||
import litellm
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import torch
|
from copy import deepcopy
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
import torch
|
||||||
from huggingface_hub import InferenceClient
|
from huggingface_hub import InferenceClient
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
StoppingCriteria,
|
||||||
|
StoppingCriteriaList,
|
||||||
|
)
|
||||||
|
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
from .utils import parse_json_tool_call
|
from .utils import parse_json_tool_call
|
||||||
|
@ -352,16 +351,16 @@ class TransformersModel(Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get LLM output
|
# Get LLM output
|
||||||
prompt = self.tokenizer.apply_chat_template(
|
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
prompt = prompt.to(self.model.device)
|
prompt_tensor = prompt_tensor.to(self.model.device)
|
||||||
count_prompt_tokens = prompt["input_ids"].shape[1]
|
count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
|
||||||
|
|
||||||
out = self.model.generate(
|
out = self.model.generate(
|
||||||
**prompt,
|
**prompt_tensor,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
stopping_criteria=(
|
stopping_criteria=(
|
||||||
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
|
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
|
||||||
|
@ -383,7 +382,7 @@ class TransformersModel(Model):
|
||||||
available_tools: List[Tool],
|
available_tools: List[Tool],
|
||||||
stop_sequences: Optional[List[str]] = None,
|
stop_sequences: Optional[List[str]] = None,
|
||||||
max_tokens: int = 500,
|
max_tokens: int = 500,
|
||||||
) -> str:
|
) -> Tuple[str, Union[str, None], str]:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from .utils import console
|
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
|
from .utils import console
|
||||||
|
|
||||||
|
|
||||||
class Monitor:
|
class Monitor:
|
||||||
def __init__(self, tracked_model):
|
def __init__(self, tracked_model):
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
|
||||||
import builtins
|
import builtins
|
||||||
from typing import Set
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
from .utils import BASE_BUILTIN_MODULES
|
from .utils import BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
_BUILTIN_NAMES = set(vars(builtins))
|
_BUILTIN_NAMES = set(vars(builtins))
|
||||||
|
|
|
@ -18,14 +18,16 @@ import ast
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import torch
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
||||||
|
|
||||||
|
import torch
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
create_repo,
|
create_repo,
|
||||||
get_collection,
|
get_collection,
|
||||||
|
@ -35,7 +37,8 @@ from huggingface_hub import (
|
||||||
)
|
)
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
from packaging import version
|
from packaging import version
|
||||||
import logging
|
from transformers import AutoProcessor
|
||||||
|
from transformers.dynamic_module_utils import get_imports
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
TypeHintParsingException,
|
TypeHintParsingException,
|
||||||
cached_file,
|
cached_file,
|
||||||
|
@ -45,13 +48,9 @@ from transformers.utils import (
|
||||||
)
|
)
|
||||||
from transformers.utils.chat_template_utils import _parse_type_hint
|
from transformers.utils.chat_template_utils import _parse_type_hint
|
||||||
|
|
||||||
from transformers.dynamic_module_utils import get_imports
|
from .tool_validation import MethodChecker, validate_tool_attributes
|
||||||
from transformers import AutoProcessor
|
|
||||||
|
|
||||||
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
|
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
|
||||||
from .utils import instance_to_source
|
from .utils import instance_to_source
|
||||||
from .tool_validation import validate_tool_attributes, MethodChecker
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -12,21 +12,20 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import requests
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
is_soundfile_availble,
|
is_soundfile_availble,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -14,14 +14,14 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import ast
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Tuple, Dict, Union
|
|
||||||
import ast
|
|
||||||
from rich.console import Console
|
|
||||||
import inspect
|
|
||||||
import types
|
import types
|
||||||
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
from transformers.utils.import_utils import _is_package_available
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
_pygments_available = _is_package_available("pygments")
|
_pygments_available = _is_package_available("pygments")
|
||||||
|
|
|
@ -16,22 +16,22 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
import pytest
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from smolagents.types import AgentText, AgentImage
|
import pytest
|
||||||
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
from smolagents.agents import (
|
from smolagents.agents import (
|
||||||
AgentMaxStepsError,
|
AgentMaxStepsError,
|
||||||
ManagedAgent,
|
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
ToolCallingAgent,
|
ManagedAgent,
|
||||||
Toolbox,
|
Toolbox,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
ToolCallingAgent,
|
||||||
)
|
)
|
||||||
from smolagents.tools import tool
|
|
||||||
from smolagents.default_tools import PythonInterpreterTool
|
from smolagents.default_tools import PythonInterpreterTool
|
||||||
from transformers.testing_utils import get_tests_dir
|
from smolagents.tools import tool
|
||||||
|
from smolagents.types import AgentImage, AgentText
|
||||||
|
|
||||||
|
|
||||||
def get_new_path(suffix="") -> str:
|
def get_new_path(suffix="") -> str:
|
||||||
|
|
|
@ -17,12 +17,13 @@ import ast
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,16 +18,14 @@ from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import get_tests_dir, require_torch
|
from transformers.testing_utils import get_tests_dir, require_torch
|
||||||
from smolagents.types import AGENT_TYPE_MAPPING
|
|
||||||
|
|
||||||
from smolagents.default_tools import FinalAnswerTool
|
from smolagents.default_tools import FinalAnswerTool
|
||||||
|
from smolagents.types import AGENT_TYPE_MAPPING
|
||||||
|
|
||||||
from .test_tools import ToolTesterMixin
|
from .test_tools import ToolTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
|
@ -13,9 +13,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
from smolagents import models, tool
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from smolagents import models, tool
|
||||||
|
|
||||||
|
|
||||||
class ModelTests(unittest.TestCase):
|
class ModelTests(unittest.TestCase):
|
||||||
def test_get_json_schema_has_nullable_args(self):
|
def test_get_json_schema_has_nullable_args(self):
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from smolagents import (
|
from smolagents import (
|
||||||
AgentImage,
|
|
||||||
AgentError,
|
AgentError,
|
||||||
|
AgentImage,
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
ToolCallingAgent,
|
ToolCallingAgent,
|
||||||
stream_to_gradio,
|
stream_to_gradio,
|
||||||
|
|
|
@ -19,12 +19,12 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents import load_tool
|
from smolagents import load_tool
|
||||||
from smolagents.types import AGENT_TYPE_MAPPING
|
|
||||||
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||||
from smolagents.local_python_executor import (
|
from smolagents.local_python_executor import (
|
||||||
InterpreterError,
|
InterpreterError,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
)
|
)
|
||||||
|
from smolagents.types import AGENT_TYPE_MAPPING
|
||||||
|
|
||||||
from .test_tools import ToolTesterMixin
|
from .test_tools import ToolTesterMixin
|
||||||
|
|
||||||
|
|
|
@ -14,21 +14,20 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Union, Optional
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import is_torch_available, is_vision_available
|
from transformers import is_torch_available, is_vision_available
|
||||||
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
|
from smolagents.tools import AUTHORIZED_TYPES, Tool, tool
|
||||||
from smolagents.types import (
|
from smolagents.types import (
|
||||||
AGENT_TYPE_MAPPING,
|
AGENT_TYPE_MAPPING,
|
||||||
AgentAudio,
|
AgentAudio,
|
||||||
AgentImage,
|
AgentImage,
|
||||||
AgentText,
|
AgentText,
|
||||||
)
|
)
|
||||||
from smolagents.tools import Tool, tool, AUTHORIZED_TYPES
|
|
||||||
from transformers.testing_utils import get_tests_dir
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -18,7 +18,8 @@ import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_soundfile,
|
require_soundfile,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
@ -28,9 +29,7 @@ from transformers.utils import (
|
||||||
is_soundfile_availble,
|
is_soundfile_availble,
|
||||||
)
|
)
|
||||||
|
|
||||||
import torch
|
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
if is_soundfile_availble():
|
if is_soundfile_availble():
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue