Add option to upload files to GradioUI (#138)

* Add option to upload files to GradioUI
This commit is contained in:
stackviolator 2025-01-13 09:33:45 -06:00 committed by GitHub
parent 67ee777370
commit 2a51efe11f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 2 deletions

1
.gitignore vendored
View File

@ -6,6 +6,7 @@ wandb
# Data
data
outputs
data/
# Apple
.DS_Store

11
examples/gradio_upload.py Normal file
View File

@ -0,0 +1,11 @@
from smolagents import (
CodeAgent,
HfApiModel,
GradioUI
)
agent = CodeAgent(
tools=[], model=HfApiModel(), max_steps=4, verbose=True
)
GradioUI(agent, file_upload_folder='./data').launch()

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -15,6 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gradio as gr
import shutil
import os
import mimetypes
import re
from .agents import ActionStep, AgentStep, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
@ -82,8 +85,12 @@ def stream_to_gradio(
class GradioUI:
"""A one-line interface to launch your agent in Gradio"""
def __init__(self, agent: MultiStepAgent):
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None):
self.agent = agent
self.file_upload_folder = file_upload_folder
if self.file_upload_folder is not None:
if not os.path.exists(file_upload_folder):
os.mkdir(file_upload_folder)
def interact_with_agent(self, prompt, messages):
messages.append(gr.ChatMessage(role="user", content=prompt))
@ -93,6 +100,45 @@ class GradioUI:
yield messages
yield messages
def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
"""
Handle file uploads, default allowed types are pdf, docx, and .txt
"""
# Check if file is uploaded
if file is None:
return "No file uploaded"
# Check if file is in allowed filetypes
name = os.path.basename(file.name)
try:
mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e:
return f"Error: {e}"
if mime_type not in allowed_file_types:
return "File type disallowed"
# Sanitize file name
original_name = os.path.basename(file.name)
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
type_to_ext = {}
for ext, t in mimetypes.types_map.items():
if t not in type_to_ext:
type_to_ext[t] = ext
# Ensure the extension correlates to the mime type
sanitized_name = sanitized_name.split(".")[:-1]
sanitized_name.append("" + type_to_ext[mime_type])
sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
shutil.copy(file.name, file_path)
return f"File uploaded successfully to {self.file_upload_folder}"
def launch(self):
with gr.Blocks() as demo:
stored_message = gr.State([])
@ -104,6 +150,14 @@ class GradioUI:
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
),
)
# If an upload folder is provided, enable the upload feature
if self.file_upload_folder is not None:
upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox(label="Upload Status", interactive=False)
upload_file.change(
self.upload_file, [upload_file], [upload_status]
)
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(
lambda s: (s, ""), [text_input], [stored_message, text_input]