From 75b0e4f41cf06dcbcd4226d9a743fb631382f5e2 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Fri, 14 Feb 2025 17:05:52 +0100 Subject: [PATCH 1/4] Update README.md --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 89c902b..0191e79 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,13 @@ agent.run("How many seconds would it take for a leopard at full speed to run thr https://github.com/user-attachments/assets/cd0226e2-7479-4102-aea0-57c22ca47884 +You can even share your agent to hub: +```py +agent.push_to_hub("m-ric/my_agent") + +# agent.from_hub("m-ric/my_agent") to load an agent from Hub +``` + Our library is LLM-agnostic: you could switch the example above to any inference provider.
From a940f42c04ea11149b2500b8cc6cadf9174bba77 Mon Sep 17 00:00:00 2001 From: kingdomad <34766852+kingdomad@users.noreply.github.com> Date: Sat, 15 Feb 2025 01:03:22 +0800 Subject: [PATCH 2/4] refactor: simplify file type checking from MIME to extension (#342) --- .github/workflows/tests.yml | 5 ++ src/smolagents/gradio_ui.py | 33 ++-------- tests/test_gradio_ui.py | 125 ++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 27 deletions(-) create mode 100644 tests/test_gradio_ui.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3b7d6d4..f93d4fc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -103,6 +103,11 @@ jobs: uv run pytest ./tests/test_utils.py if: ${{ success() || failure() }} + - name: Gradio UI tests + run: | + uv run pytest ./tests/test_gradio_ui.py + if: ${{ success() || failure() }} + - name: Function type hints utils tests run: | uv run pytest ./tests/test_function_type_hints_utils.py diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index a0f04b9..11094a5 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import mimetypes import os import re import shutil @@ -199,30 +198,20 @@ class GradioUI: yield messages yield messages - def upload_file( - self, - file, - file_uploads_log, - allowed_file_types=[ - "application/pdf", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "text/plain", - ], - ): + def upload_file(self, file, file_uploads_log, allowed_file_types=None): """ Handle file uploads, default allowed types are .pdf, .docx, and .txt """ import gradio as gr if file is None: - return gr.Textbox("No file uploaded", visible=True), file_uploads_log + return gr.Textbox(value="No file uploaded", visible=True), file_uploads_log - try: - mime_type, _ = mimetypes.guess_type(file.name) - except Exception as e: - return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log + if allowed_file_types is None: + allowed_file_types = [".pdf", ".docx", ".txt"] - if mime_type not in allowed_file_types: + file_ext = os.path.splitext(file.name)[1].lower() + if file_ext not in allowed_file_types: return gr.Textbox("File type disallowed", visible=True), file_uploads_log # Sanitize file name @@ -231,16 +220,6 @@ class GradioUI: r"[^\w\-.]", "_", original_name ) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores - type_to_ext = {} - for ext, t in mimetypes.types_map.items(): - if t not in type_to_ext: - type_to_ext[t] = ext - - # Ensure the extension correlates to the mime type - sanitized_name = sanitized_name.split(".")[:-1] - sanitized_name.append("" + type_to_ext[mime_type]) - sanitized_name = "".join(sanitized_name) - # Save the uploaded file to the specified folder file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name)) shutil.copy(file.name, file_path) diff --git a/tests/test_gradio_ui.py b/tests/test_gradio_ui.py new file mode 100644 index 0000000..0b337d2 --- /dev/null +++ b/tests/test_gradio_ui.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +from unittest.mock import Mock, patch + +from smolagents.gradio_ui import GradioUI + + +class GradioUITester(unittest.TestCase): + def setUp(self): + """Initialize test environment""" + self.temp_dir = tempfile.mkdtemp() + self.mock_agent = Mock() + self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir) + self.allowed_types = [".pdf", ".docx", ".txt"] + + def tearDown(self): + """Clean up test environment""" + shutil.rmtree(self.temp_dir) + + def test_upload_file_default_types(self): + """Test default allowed file types""" + default_types = [".pdf", ".docx", ".txt"] + for file_type in default_types: + with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file: + mock_file = Mock() + mock_file.name = temp_file.name + + textbox, uploads_log = self.ui.upload_file(mock_file, []) + + self.assertIn("File uploaded:", textbox.value) + self.assertEqual(len(uploads_log), 1) + self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name)))) + + def test_upload_file_default_types_disallowed(self): + """Test default disallowed file types""" + disallowed_types = [".exe", ".sh", ".py", ".jpg"] + for file_type in disallowed_types: + with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file: + mock_file = Mock() + mock_file.name = temp_file.name + + textbox, uploads_log = self.ui.upload_file(mock_file, []) + + self.assertEqual(textbox.value, "File type disallowed") + self.assertEqual(len(uploads_log), 0) + + def test_upload_file_success(self): + """Test successful file upload scenario""" + with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file: + mock_file = Mock() + mock_file.name = temp_file.name + + textbox, uploads_log = self.ui.upload_file(mock_file, []) + + self.assertIn("File uploaded:", textbox.value) + self.assertEqual(len(uploads_log), 1) + self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name)))) + self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name))) + + def test_upload_file_none(self): + """Test scenario when no file is selected""" + textbox, uploads_log = self.ui.upload_file(None, []) + + self.assertEqual(textbox.value, "No file uploaded") + self.assertEqual(len(uploads_log), 0) + + def test_upload_file_invalid_type(self): + """Test disallowed file type""" + with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file: + mock_file = Mock() + mock_file.name = temp_file.name + + textbox, uploads_log = self.ui.upload_file(mock_file, []) + + self.assertEqual(textbox.value, "File type disallowed") + self.assertEqual(len(uploads_log), 0) + + def test_upload_file_special_chars(self): + """Test scenario with special characters in filename""" + with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file: + # Create a new temporary file with special characters + special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt") + shutil.copy(temp_file.name, special_char_name) + try: + mock_file = Mock() + mock_file.name = special_char_name + + with patch("shutil.copy"): + textbox, uploads_log = self.ui.upload_file(mock_file, []) + + self.assertIn("File uploaded:", textbox.value) + self.assertEqual(len(uploads_log), 1) + self.assertIn("test_____", uploads_log[0]) + finally: + # Clean up the special character file + if os.path.exists(special_char_name): + os.remove(special_char_name) + + def test_upload_file_custom_types(self): + """Test custom allowed file types""" + with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file: + mock_file = Mock() + mock_file.name = temp_file.name + + textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"]) + + self.assertIn("File uploaded:", textbox.value) + self.assertEqual(len(uploads_log), 1) From a6cc506d099e5bf19b2fd9f9899394b211d369f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Galego?= Date: Fri, 14 Feb 2025 18:32:12 +0000 Subject: [PATCH 3/4] =?UTF-8?q?Bugfix:=20Groq=20via=20LiteLLM=20?= =?UTF-8?q?=F0=9F=9A=85=20(#605)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> --- src/smolagents/models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 172757e..7b35cb1 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -863,9 +863,8 @@ class LiteLLMModel(Model): import litellm model_info: dict = litellm.get_model_info(self.model_id) - if model_info["litellm_provider"] == "ollama": + if model_info["litellm_provider"] in ["ollama", "groq"]: return model_info["key"] != "llava" - return False def __call__( From 776523693055f4f4e2d3e81e7d3550245c5b8a4c Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Sat, 15 Feb 2025 10:24:14 +0100 Subject: [PATCH 4/4] Add docstring args for MultiStepAgent.from_folder (#654) --- src/smolagents/agents.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 1fac1d7..521cf71 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -936,7 +936,12 @@ You have been provided with these additional arguments, that you can access usin @classmethod def from_folder(cls, folder: Union[str, Path], **kwargs): - """Loads an agent from a local folder""" + """Loads an agent from a local folder. + + Args: + folder (`str` or `Path`): The folder where the agent is saved. + **kwargs: Additional keyword arguments that will be passed to the agent's init. + """ folder = Path(folder) agent_dict = json.loads((folder / "agent.json").read_text())