Clean makefile, pyproject.toml and CI (#229)
* Clean makefile / pyproject.toml / .github * new tests after * add back sqlalchemy * disable docs tests in CI * continue on error * correct continue on error * Remove all_docs test
This commit is contained in:
parent
fabc59aa08
commit
1f8fd72acb
|
@ -16,20 +16,15 @@ jobs:
|
||||||
python-version: "3.12"
|
python-version: "3.12"
|
||||||
|
|
||||||
# Setup venv
|
# Setup venv
|
||||||
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
|
||||||
- name: Setup venv + uv
|
- name: Setup venv + uv
|
||||||
run: |
|
run: |
|
||||||
pip install --upgrade uv
|
pip install --upgrade uv
|
||||||
uv venv
|
uv venv
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv pip install "smolagents[test] @ ."
|
run: uv pip install "smolagents[quality] @ ."
|
||||||
- run: uv run ruff check tests src # linter
|
|
||||||
- run: uv run ruff format --check tests src # formatter
|
|
||||||
|
|
||||||
# Run type checking at least on smolagents root file to check all modules
|
# Equivalent of "make quality" but step by step
|
||||||
# that can be lazy-loaded actually exist.
|
- run: uv run ruff check examples src tests utils # linter
|
||||||
# - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback
|
- run: uv run ruff format --check examples src tests utils # formatter
|
||||||
|
- run: uv run python utils/check_tests_in_ci.py
|
||||||
# Run mypy on full package
|
|
||||||
# - run: uv run mypy src
|
|
|
@ -20,9 +20,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
|
||||||
# Setup venv
|
# Setup venv
|
||||||
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
|
||||||
- name: Setup venv + uv
|
- name: Setup venv + uv
|
||||||
run: |
|
run: |
|
||||||
pip install --upgrade uv
|
pip install --upgrade uv
|
||||||
|
@ -33,33 +31,59 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
uv pip install "smolagents[test] @ ."
|
uv pip install "smolagents[test] @ ."
|
||||||
|
|
||||||
|
# Run all tests separately for individual feedback
|
||||||
|
# Use 'if success() || failure()' so that all tests are run even if one failed
|
||||||
|
# See https://stackoverflow.com/a/62112985
|
||||||
- name: Agent tests
|
- name: Agent tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_agents.py
|
uv run pytest ./tests/test_agents.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Default tools tests
|
- name: Default tools tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_default_tools.py
|
uv run pytest ./tests/test_default_tools.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
|
# - name: Docs tests # Disabled for now (slow test + requires API keys)
|
||||||
|
# run: |
|
||||||
|
# uv run pytest ./tests/test_all_docs.py
|
||||||
|
|
||||||
- name: Final answer tests
|
- name: Final answer tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_final_answer.py
|
uv run pytest ./tests/test_final_answer.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Models tests
|
- name: Models tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_models.py
|
uv run pytest ./tests/test_models.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Monitoring tests
|
- name: Monitoring tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_monitoring.py
|
uv run pytest ./tests/test_monitoring.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Python interpreter tests
|
- name: Python interpreter tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_python_interpreter.py
|
uv run pytest ./tests/test_python_interpreter.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Search tests
|
- name: Search tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_search.py
|
uv run pytest ./tests/test_search.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Tools tests
|
- name: Tools tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_tools.py
|
uv run pytest ./tests/test_tools.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Types tests
|
- name: Types tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_types.py
|
uv run pytest ./tests/test_types.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
- name: Utils tests
|
- name: Utils tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_utils.py
|
uv run pytest ./tests/test_utils.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
|
@ -91,7 +91,7 @@ happy to make the changes or help you make a contribution if you're interested!
|
||||||
|
|
||||||
## I want to become a maintainer of the project. How do I get there?
|
## I want to become a maintainer of the project. How do I get there?
|
||||||
|
|
||||||
smolagents is a project led and managed by Hugging Face. We are more than
|
smolagents is a project led and managed by Hugging Face. We are more than
|
||||||
happy to have motivated individuals from other organizations join us as maintainers with the goal of helping smolagents
|
happy to have motivated individuals from other organizations join us as maintainers with the goal of helping smolagents
|
||||||
make a dent in the world of Agents.
|
make a dent in the world of Agents.
|
||||||
|
|
||||||
|
|
47
Makefile
47
Makefile
|
@ -1,53 +1,18 @@
|
||||||
.PHONY: quality style test docs utils
|
.PHONY: quality style test docs utils
|
||||||
|
|
||||||
check_dirs := .
|
check_dirs := examples src tests utils
|
||||||
|
|
||||||
# Check that source code meets quality standards
|
# Check code quality of the source code
|
||||||
|
|
||||||
extra_quality_checks:
|
|
||||||
python utils/check_copies.py
|
|
||||||
python utils/check_dummies.py
|
|
||||||
python utils/check_repo.py
|
|
||||||
doc-builder style smolagents docs/source --max_len 119
|
|
||||||
|
|
||||||
# this target runs checks on all files
|
|
||||||
quality:
|
quality:
|
||||||
ruff check $(check_dirs)
|
ruff check $(check_dirs)
|
||||||
ruff format --check $(check_dirs)
|
ruff format --check $(check_dirs)
|
||||||
doc-builder style smolagents docs/source --max_len 119 --check_only
|
python utils/check_tests_in_ci.py
|
||||||
|
|
||||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
# Format source code automatically
|
||||||
style:
|
style:
|
||||||
ruff check $(check_dirs) --fix
|
ruff check $(check_dirs) --fix
|
||||||
ruff format $(check_dirs)
|
ruff format $(check_dirs)
|
||||||
doc-builder style smolagents docs/source --max_len 119
|
|
||||||
|
|
||||||
# Run tests for the library
|
# Run smolagents tests
|
||||||
test_big_modeling:
|
|
||||||
python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)
|
|
||||||
|
|
||||||
test_core:
|
|
||||||
python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)
|
|
||||||
|
|
||||||
test_cli:
|
|
||||||
python -m pytest -s -v ./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_cli.log",)
|
|
||||||
|
|
||||||
|
|
||||||
# Since the new version of pytest will *change* how things are collected, we need `deepspeed` to
|
|
||||||
# run after test_core and test_cli
|
|
||||||
test:
|
test:
|
||||||
$(MAKE) test_core
|
pytest ./tests/
|
||||||
$(MAKE) test_cli
|
|
||||||
$(MAKE) test_big_modeling
|
|
||||||
$(MAKE) test_deepspeed
|
|
||||||
$(MAKE) test_fsdp
|
|
||||||
|
|
||||||
test_examples:
|
|
||||||
python -m pytest -s -v ./tests/test_examples.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_examples.log",)
|
|
||||||
|
|
||||||
# Same as test but used to install only the base dependencies
|
|
||||||
test_prod:
|
|
||||||
$(MAKE) test_core
|
|
||||||
|
|
||||||
test_rest:
|
|
||||||
python -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_rest.log",)
|
|
20
README.md
20
README.md
|
@ -98,9 +98,27 @@ To contribute, follow our [contribution guide](https://github.com/huggingface/sm
|
||||||
At any moment, feel welcome to open an issue, citing your exact error traces and package versions if it's a bug.
|
At any moment, feel welcome to open an issue, citing your exact error traces and package versions if it's a bug.
|
||||||
It's often even better to open a PR with your proposed fixes/changes!
|
It's often even better to open a PR with your proposed fixes/changes!
|
||||||
|
|
||||||
|
To install dev dependencies, run:
|
||||||
|
```
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
When making changes to the codebase, please check that it follows the repo's code quality requirements by running:
|
||||||
|
To check code quality of the source code:
|
||||||
|
```
|
||||||
|
make quality
|
||||||
|
```
|
||||||
|
|
||||||
|
If the checks fail, you can run the formatter with:
|
||||||
|
```
|
||||||
|
make style
|
||||||
|
```
|
||||||
|
|
||||||
|
And commit the changes.
|
||||||
|
|
||||||
To run tests locally, run this command:
|
To run tests locally, run this command:
|
||||||
```bash
|
```bash
|
||||||
pytest -sv .
|
pytest .
|
||||||
```
|
```
|
||||||
|
|
||||||
## Citing smolagents
|
## Citing smolagents
|
||||||
|
|
|
@ -254,7 +254,10 @@
|
||||||
" if is_vanilla_llm:\n",
|
" if is_vanilla_llm:\n",
|
||||||
" llm = agent\n",
|
" llm = agent\n",
|
||||||
" answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n",
|
" answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n",
|
||||||
" token_count = {\"input\": llm.last_input_token_count, \"output\": llm.last_output_token_count}\n",
|
" token_count = {\n",
|
||||||
|
" \"input\": llm.last_input_token_count,\n",
|
||||||
|
" \"output\": llm.last_output_token_count,\n",
|
||||||
|
" }\n",
|
||||||
" intermediate_steps = str([])\n",
|
" intermediate_steps = str([])\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" answer = str(agent.run(question))\n",
|
" answer = str(agent.run(question))\n",
|
||||||
|
@ -983,7 +986,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -1043,8 +1046,8 @@
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Usage (after running your previous data processing code):\n",
|
"# Usage (after running your previous data processing code):\n",
|
||||||
"mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n",
|
"# mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n",
|
||||||
"print(mathjax_table)"
|
"# print(mathjax_table)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
@ -4,8 +4,9 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class GetCatImageTool(Tool):
|
class GetCatImageTool(Tool):
|
||||||
name="get_cat_image"
|
name = "get_cat_image"
|
||||||
description = "Get a cat image"
|
description = "Get a cat image"
|
||||||
inputs = {}
|
inputs = {}
|
||||||
output_type = "image"
|
output_type = "image"
|
||||||
|
@ -27,17 +28,22 @@ class GetCatImageTool(Tool):
|
||||||
get_cat_image = GetCatImageTool()
|
get_cat_image = GetCatImageTool()
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools = [get_cat_image, VisitWebpageTool()],
|
tools=[get_cat_image, VisitWebpageTool()],
|
||||||
model=HfApiModel(),
|
model=HfApiModel(),
|
||||||
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
|
additional_authorized_imports=[
|
||||||
use_e2b_executor=True
|
"Pillow",
|
||||||
|
"requests",
|
||||||
|
"markdownify",
|
||||||
|
], # "duckduckgo-search",
|
||||||
|
use_e2b_executor=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent.run(
|
agent.run(
|
||||||
"Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()}
|
"Return me an image of a cat. Directly use the image provided in your state.",
|
||||||
) # Asking to directly return the image from state tests that additional_args are properly sent to server.
|
additional_args={"cat_image": get_cat_image()},
|
||||||
|
) # Asking to directly return the image from state tests that additional_args are properly sent to server.
|
||||||
|
|
||||||
# Try the agent in a Gradio UI
|
# Try the agent in a Gradio UI
|
||||||
from smolagents import GradioUI
|
from smolagents import GradioUI
|
||||||
|
|
||||||
GradioUI(agent).launch()
|
GradioUI(agent).launch()
|
||||||
|
|
|
@ -1,11 +1,5 @@
|
||||||
from smolagents import (
|
from smolagents import CodeAgent, HfApiModel, GradioUI
|
||||||
CodeAgent,
|
|
||||||
HfApiModel,
|
|
||||||
GradioUI
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)
|
||||||
tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1
|
|
||||||
)
|
|
||||||
|
|
||||||
GradioUI(agent, file_upload_folder='./data').launch()
|
GradioUI(agent, file_upload_folder="./data").launch()
|
||||||
|
|
|
@ -16,7 +16,9 @@ from smolagents import (
|
||||||
# Let's setup the instrumentation first
|
# Let's setup the instrumentation first
|
||||||
|
|
||||||
trace_provider = TracerProvider()
|
trace_provider = TracerProvider()
|
||||||
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")))
|
trace_provider.add_span_processor(
|
||||||
|
SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))
|
||||||
|
)
|
||||||
|
|
||||||
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
|
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,9 @@ from langchain_community.retrievers import BM25Retriever
|
||||||
|
|
||||||
|
|
||||||
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
|
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
|
||||||
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
|
knowledge_base = knowledge_base.filter(
|
||||||
|
lambda row: row["source"].startswith("huggingface/transformers")
|
||||||
|
)
|
||||||
|
|
||||||
source_docs = [
|
source_docs = [
|
||||||
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
|
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
|
||||||
|
@ -26,6 +28,7 @@ docs_processed = text_splitter.split_documents(source_docs)
|
||||||
|
|
||||||
from smolagents import Tool
|
from smolagents import Tool
|
||||||
|
|
||||||
|
|
||||||
class RetrieverTool(Tool):
|
class RetrieverTool(Tool):
|
||||||
name = "retriever"
|
name = "retriever"
|
||||||
description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
|
description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
|
||||||
|
@ -39,9 +42,7 @@ class RetrieverTool(Tool):
|
||||||
|
|
||||||
def __init__(self, docs, **kwargs):
|
def __init__(self, docs, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.retriever = BM25Retriever.from_documents(
|
self.retriever = BM25Retriever.from_documents(docs, k=10)
|
||||||
docs, k=10
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, query: str) -> str:
|
def forward(self, query: str) -> str:
|
||||||
assert isinstance(query, str), "Your search query must be a string"
|
assert isinstance(query, str), "Your search query must be a string"
|
||||||
|
@ -56,14 +57,20 @@ class RetrieverTool(Tool):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from smolagents import HfApiModel, CodeAgent
|
from smolagents import HfApiModel, CodeAgent
|
||||||
|
|
||||||
retriever_tool = RetrieverTool(docs_processed)
|
retriever_tool = RetrieverTool(docs_processed)
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbosity_level=2
|
tools=[retriever_tool],
|
||||||
|
model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),
|
||||||
|
max_steps=4,
|
||||||
|
verbosity_level=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
|
agent_output = agent.run(
|
||||||
|
"For a transformers model training, which is slower, the forward or the backward pass?"
|
||||||
|
)
|
||||||
|
|
||||||
print("Final output:")
|
print("Final output:")
|
||||||
print(agent_output)
|
print(agent_output)
|
||||||
|
|
|
@ -40,11 +40,14 @@ for row in rows:
|
||||||
inspector = inspect(engine)
|
inspector = inspect(engine)
|
||||||
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
|
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
|
||||||
|
|
||||||
table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
|
table_description = "Columns:\n" + "\n".join(
|
||||||
|
[f" - {name}: {col_type}" for name, col_type in columns_info]
|
||||||
|
)
|
||||||
print(table_description)
|
print(table_description)
|
||||||
|
|
||||||
from smolagents import tool
|
from smolagents import tool
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def sql_engine(query: str) -> str:
|
def sql_engine(query: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -66,10 +69,11 @@ def sql_engine(query: str) -> str:
|
||||||
output += "\n" + str(row)
|
output += "\n" + str(row)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
from smolagents import CodeAgent, HfApiModel
|
from smolagents import CodeAgent, HfApiModel
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[sql_engine],
|
tools=[sql_engine],
|
||||||
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
|
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
|
||||||
)
|
)
|
||||||
agent.run("Can you give me the name of the client who got the most expensive receipt?")
|
agent.run("Can you give me the name of the client who got the most expensive receipt?")
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import Optional
|
||||||
# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
|
# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
|
||||||
model = LiteLLMModel(model_id="gpt-4o")
|
model = LiteLLMModel(model_id="gpt-4o")
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -21,6 +22,7 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
||||||
"""
|
"""
|
||||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
|
|
||||||
agent = ToolCallingAgent(tools=[get_weather], model=model)
|
agent = ToolCallingAgent(tools=[get_weather], model=model)
|
||||||
|
|
||||||
print(agent.run("What's the weather like in Paris?"))
|
print(agent.run("What's the weather like in Paris?"))
|
||||||
|
|
|
@ -4,10 +4,11 @@ from typing import Optional
|
||||||
|
|
||||||
model = LiteLLMModel(
|
model = LiteLLMModel(
|
||||||
model_id="ollama_chat/llama3.2",
|
model_id="ollama_chat/llama3.2",
|
||||||
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
|
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
|
||||||
api_key="your-api-key" # replace with API key if necessary
|
api_key="your-api-key", # replace with API key if necessary
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -20,6 +21,7 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
||||||
"""
|
"""
|
||||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
|
|
||||||
agent = ToolCallingAgent(tools=[get_weather], model=model)
|
agent = ToolCallingAgent(tools=[get_weather], model=model)
|
||||||
|
|
||||||
print(agent.run("What's the weather like in Paris?"))
|
print(agent.run("What's the weather like in Paris?"))
|
||||||
|
|
|
@ -26,27 +26,37 @@ dependencies = [
|
||||||
"openai>=1.58.1",
|
"openai>=1.58.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
audio = [
|
||||||
|
"soundfile",
|
||||||
|
]
|
||||||
|
torch = [
|
||||||
|
"torch",
|
||||||
|
"accelerate",
|
||||||
|
]
|
||||||
|
litellm = [
|
||||||
|
"litellm>=1.55.10",
|
||||||
|
]
|
||||||
|
quality = [
|
||||||
|
"ruff>=0.9.0",
|
||||||
|
]
|
||||||
|
test = [
|
||||||
|
"pytest>=8.1.0",
|
||||||
|
"smolagents[audio,litellm,torch]",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
|
"smolagents[quality,test]",
|
||||||
|
"sqlalchemy", # for ./examples
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
# Add the specified `OPTS` to the set of command line arguments as if they had been specified by the user.
|
||||||
|
addopts = "-sv --durations=0"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
lint.ignore = ["F403"]
|
lint.ignore = ["F403"]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
dev = [
|
"examples/*" = [
|
||||||
"torch",
|
"E402", # module-import-not-at-top-of-file
|
||||||
"torchaudio",
|
]
|
||||||
"torchvision",
|
|
||||||
"sqlalchemy",
|
|
||||||
"accelerate",
|
|
||||||
"soundfile",
|
|
||||||
"litellm>=1.55.10",
|
|
||||||
]
|
|
||||||
test = [
|
|
||||||
"torch",
|
|
||||||
"torchaudio",
|
|
||||||
"torchvision",
|
|
||||||
"pytest>=8.1.0",
|
|
||||||
"sqlalchemy",
|
|
||||||
"ruff>=0.5.0",
|
|
||||||
"accelerate",
|
|
||||||
"soundfile",
|
|
||||||
"litellm>=1.55.10",
|
|
||||||
]
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025-present, the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Check that all tests are called in CI."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
TESTS_FOLDER = ROOT / "tests"
|
||||||
|
CI_WORKFLOW_FILE = ROOT / ".github" / "workflows" / "tests.yml"
|
||||||
|
|
||||||
|
|
||||||
|
def check_tests_in_ci():
|
||||||
|
"""List all test files in `./tests/` and check if they are listed in the CI workflow.
|
||||||
|
|
||||||
|
Since each test file is triggered separately in the CI workflow, it is easy to forget a new one when adding new
|
||||||
|
tests, hence this check.
|
||||||
|
|
||||||
|
NOTE: current implementation is quite naive but should work for now. Must be updated if one want to ignore some
|
||||||
|
tests or if file naming is updated (currently only files starting by `test_*` are cheked)
|
||||||
|
"""
|
||||||
|
test_files = [
|
||||||
|
path.relative_to(TESTS_FOLDER).as_posix()
|
||||||
|
for path in TESTS_FOLDER.glob("**/*.py")
|
||||||
|
if path.name.startswith("test_")
|
||||||
|
]
|
||||||
|
ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
|
||||||
|
missing_test_files = [
|
||||||
|
test_file
|
||||||
|
for test_file in test_files
|
||||||
|
if test_file not in ci_workflow_file_content
|
||||||
|
]
|
||||||
|
if missing_test_files:
|
||||||
|
print(
|
||||||
|
"❌ Some test files seem to be ignored in the CI:\n"
|
||||||
|
+ "\n".join(f" - {test_file}" for test_file in missing_test_files)
|
||||||
|
+ f"\n Please add them manually in {CI_WORKFLOW_FILE}."
|
||||||
|
)
|
||||||
|
exit(1)
|
||||||
|
else:
|
||||||
|
print("✅ All good!")
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
check_tests_in_ci()
|
Loading…
Reference in New Issue