refactor: simplify file type checking from MIME to extension (#342)

This commit is contained in:
kingdomad 2025-02-15 01:03:22 +08:00 committed by GitHub
parent 75b0e4f41c
commit a940f42c04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 136 additions and 27 deletions

View File

@ -103,6 +103,11 @@ jobs:
uv run pytest ./tests/test_utils.py uv run pytest ./tests/test_utils.py
if: ${{ success() || failure() }} if: ${{ success() || failure() }}
- name: Gradio UI tests
run: |
uv run pytest ./tests/test_gradio_ui.py
if: ${{ success() || failure() }}
- name: Function type hints utils tests - name: Function type hints utils tests
run: | run: |
uv run pytest ./tests/test_function_type_hints_utils.py uv run pytest ./tests/test_function_type_hints_utils.py

View File

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import mimetypes
import os import os
import re import re
import shutil import shutil
@ -199,30 +198,20 @@ class GradioUI:
yield messages yield messages
yield messages yield messages
def upload_file( def upload_file(self, file, file_uploads_log, allowed_file_types=None):
self,
file,
file_uploads_log,
allowed_file_types=[
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
],
):
""" """
Handle file uploads, default allowed types are .pdf, .docx, and .txt Handle file uploads, default allowed types are .pdf, .docx, and .txt
""" """
import gradio as gr import gradio as gr
if file is None: if file is None:
return gr.Textbox("No file uploaded", visible=True), file_uploads_log return gr.Textbox(value="No file uploaded", visible=True), file_uploads_log
try: if allowed_file_types is None:
mime_type, _ = mimetypes.guess_type(file.name) allowed_file_types = [".pdf", ".docx", ".txt"]
except Exception as e:
return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
if mime_type not in allowed_file_types: file_ext = os.path.splitext(file.name)[1].lower()
if file_ext not in allowed_file_types:
return gr.Textbox("File type disallowed", visible=True), file_uploads_log return gr.Textbox("File type disallowed", visible=True), file_uploads_log
# Sanitize file name # Sanitize file name
@ -231,16 +220,6 @@ class GradioUI:
r"[^\w\-.]", "_", original_name r"[^\w\-.]", "_", original_name
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores ) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
type_to_ext = {}
for ext, t in mimetypes.types_map.items():
if t not in type_to_ext:
type_to_ext[t] = ext
# Ensure the extension correlates to the mime type
sanitized_name = sanitized_name.split(".")[:-1]
sanitized_name.append("" + type_to_ext[mime_type])
sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder # Save the uploaded file to the specified folder
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name)) file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
shutil.copy(file.name, file_path) shutil.copy(file.name, file_path)

125
tests/test_gradio_ui.py Normal file
View File

@ -0,0 +1,125 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest
from unittest.mock import Mock, patch
from smolagents.gradio_ui import GradioUI
class GradioUITester(unittest.TestCase):
def setUp(self):
"""Initialize test environment"""
self.temp_dir = tempfile.mkdtemp()
self.mock_agent = Mock()
self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
self.allowed_types = [".pdf", ".docx", ".txt"]
def tearDown(self):
"""Clean up test environment"""
shutil.rmtree(self.temp_dir)
def test_upload_file_default_types(self):
"""Test default allowed file types"""
default_types = [".pdf", ".docx", ".txt"]
for file_type in default_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
def test_upload_file_default_types_disallowed(self):
"""Test default disallowed file types"""
disallowed_types = [".exe", ".sh", ".py", ".jpg"]
for file_type in disallowed_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_success(self):
"""Test successful file upload scenario"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
def test_upload_file_none(self):
"""Test scenario when no file is selected"""
textbox, uploads_log = self.ui.upload_file(None, [])
self.assertEqual(textbox.value, "No file uploaded")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_invalid_type(self):
"""Test disallowed file type"""
with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_special_chars(self):
"""Test scenario with special characters in filename"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
# Create a new temporary file with special characters
special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt")
shutil.copy(temp_file.name, special_char_name)
try:
mock_file = Mock()
mock_file.name = special_char_name
with patch("shutil.copy"):
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertIn("test_____", uploads_log[0])
finally:
# Clean up the special character file
if os.path.exists(special_char_name):
os.remove(special_char_name)
def test_upload_file_custom_types(self):
"""Test custom allowed file types"""
with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)