Merge branch 'main' of github.com:huggingface/smolagents

This commit is contained in:
Aymeric 2025-02-15 11:17:39 +01:00
commit 161ef452c3
5 changed files with 149 additions and 28 deletions

View File

@ -103,6 +103,11 @@ jobs:
uv run pytest ./tests/test_utils.py uv run pytest ./tests/test_utils.py
if: ${{ success() || failure() }} if: ${{ success() || failure() }}
- name: Gradio UI tests
run: |
uv run pytest ./tests/test_gradio_ui.py
if: ${{ success() || failure() }}
- name: Function type hints utils tests - name: Function type hints utils tests
run: | run: |
uv run pytest ./tests/test_function_type_hints_utils.py uv run pytest ./tests/test_function_type_hints_utils.py

View File

@ -67,6 +67,13 @@ agent.run("How many seconds would it take for a leopard at full speed to run thr
https://github.com/user-attachments/assets/cd0226e2-7479-4102-aea0-57c22ca47884 https://github.com/user-attachments/assets/cd0226e2-7479-4102-aea0-57c22ca47884
You can even share your agent to hub:
```py
agent.push_to_hub("m-ric/my_agent")
# agent.from_hub("m-ric/my_agent") to load an agent from Hub
```
Our library is LLM-agnostic: you could switch the example above to any inference provider. Our library is LLM-agnostic: you could switch the example above to any inference provider.
<details> <details>

View File

@ -936,7 +936,12 @@ You have been provided with these additional arguments, that you can access usin
@classmethod @classmethod
def from_folder(cls, folder: Union[str, Path], **kwargs): def from_folder(cls, folder: Union[str, Path], **kwargs):
"""Loads an agent from a local folder""" """Loads an agent from a local folder.
Args:
folder (`str` or `Path`): The folder where the agent is saved.
**kwargs: Additional keyword arguments that will be passed to the agent's init.
"""
folder = Path(folder) folder = Path(folder)
agent_dict = json.loads((folder / "agent.json").read_text()) agent_dict = json.loads((folder / "agent.json").read_text())

View File

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import mimetypes
import os import os
import re import re
import shutil import shutil
@ -199,30 +198,20 @@ class GradioUI:
yield messages yield messages
yield messages yield messages
def upload_file( def upload_file(self, file, file_uploads_log, allowed_file_types=None):
self,
file,
file_uploads_log,
allowed_file_types=[
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
],
):
""" """
Handle file uploads, default allowed types are .pdf, .docx, and .txt Handle file uploads, default allowed types are .pdf, .docx, and .txt
""" """
import gradio as gr import gradio as gr
if file is None: if file is None:
return gr.Textbox("No file uploaded", visible=True), file_uploads_log return gr.Textbox(value="No file uploaded", visible=True), file_uploads_log
try: if allowed_file_types is None:
mime_type, _ = mimetypes.guess_type(file.name) allowed_file_types = [".pdf", ".docx", ".txt"]
except Exception as e:
return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
if mime_type not in allowed_file_types: file_ext = os.path.splitext(file.name)[1].lower()
if file_ext not in allowed_file_types:
return gr.Textbox("File type disallowed", visible=True), file_uploads_log return gr.Textbox("File type disallowed", visible=True), file_uploads_log
# Sanitize file name # Sanitize file name
@ -231,16 +220,6 @@ class GradioUI:
r"[^\w\-.]", "_", original_name r"[^\w\-.]", "_", original_name
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores ) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
type_to_ext = {}
for ext, t in mimetypes.types_map.items():
if t not in type_to_ext:
type_to_ext[t] = ext
# Ensure the extension correlates to the mime type
sanitized_name = sanitized_name.split(".")[:-1]
sanitized_name.append("" + type_to_ext[mime_type])
sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder # Save the uploaded file to the specified folder
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name)) file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
shutil.copy(file.name, file_path) shutil.copy(file.name, file_path)

125
tests/test_gradio_ui.py Normal file
View File

@ -0,0 +1,125 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest
from unittest.mock import Mock, patch
from smolagents.gradio_ui import GradioUI
class GradioUITester(unittest.TestCase):
def setUp(self):
"""Initialize test environment"""
self.temp_dir = tempfile.mkdtemp()
self.mock_agent = Mock()
self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
self.allowed_types = [".pdf", ".docx", ".txt"]
def tearDown(self):
"""Clean up test environment"""
shutil.rmtree(self.temp_dir)
def test_upload_file_default_types(self):
"""Test default allowed file types"""
default_types = [".pdf", ".docx", ".txt"]
for file_type in default_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
def test_upload_file_default_types_disallowed(self):
"""Test default disallowed file types"""
disallowed_types = [".exe", ".sh", ".py", ".jpg"]
for file_type in disallowed_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_success(self):
"""Test successful file upload scenario"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
def test_upload_file_none(self):
"""Test scenario when no file is selected"""
textbox, uploads_log = self.ui.upload_file(None, [])
self.assertEqual(textbox.value, "No file uploaded")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_invalid_type(self):
"""Test disallowed file type"""
with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_special_chars(self):
"""Test scenario with special characters in filename"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
# Create a new temporary file with special characters
special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt")
shutil.copy(temp_file.name, special_char_name)
try:
mock_file = Mock()
mock_file.name = special_char_name
with patch("shutil.copy"):
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertIn("test_____", uploads_log[0])
finally:
# Clean up the special character file
if os.path.exists(special_char_name):
os.remove(special_char_name)
def test_upload_file_custom_types(self):
"""Test custom allowed file types"""
with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)