Add option to upload files to GradioUI (#138)
* Add option to upload files to GradioUI
This commit is contained in:
parent
67ee777370
commit
2a51efe11f
|
@ -6,6 +6,7 @@ wandb
|
||||||
# Data
|
# Data
|
||||||
data
|
data
|
||||||
outputs
|
outputs
|
||||||
|
data/
|
||||||
|
|
||||||
# Apple
|
# Apple
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
from smolagents import (
|
||||||
|
CodeAgent,
|
||||||
|
HfApiModel,
|
||||||
|
GradioUI
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = CodeAgent(
|
||||||
|
tools=[], model=HfApiModel(), max_steps=4, verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
GradioUI(agent, file_upload_folder='./data').launch()
|
|
@ -1,6 +1,5 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -15,6 +14,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import shutil
|
||||||
|
import os
|
||||||
|
import mimetypes
|
||||||
|
import re
|
||||||
|
|
||||||
from .agents import ActionStep, AgentStep, MultiStepAgent
|
from .agents import ActionStep, AgentStep, MultiStepAgent
|
||||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||||
|
@ -82,8 +85,12 @@ def stream_to_gradio(
|
||||||
class GradioUI:
|
class GradioUI:
|
||||||
"""A one-line interface to launch your agent in Gradio"""
|
"""A one-line interface to launch your agent in Gradio"""
|
||||||
|
|
||||||
def __init__(self, agent: MultiStepAgent):
|
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None):
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
self.file_upload_folder = file_upload_folder
|
||||||
|
if self.file_upload_folder is not None:
|
||||||
|
if not os.path.exists(file_upload_folder):
|
||||||
|
os.mkdir(file_upload_folder)
|
||||||
|
|
||||||
def interact_with_agent(self, prompt, messages):
|
def interact_with_agent(self, prompt, messages):
|
||||||
messages.append(gr.ChatMessage(role="user", content=prompt))
|
messages.append(gr.ChatMessage(role="user", content=prompt))
|
||||||
|
@ -93,6 +100,45 @@ class GradioUI:
|
||||||
yield messages
|
yield messages
|
||||||
yield messages
|
yield messages
|
||||||
|
|
||||||
|
def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
|
||||||
|
"""
|
||||||
|
Handle file uploads, default allowed types are pdf, docx, and .txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Check if file is uploaded
|
||||||
|
if file is None:
|
||||||
|
return "No file uploaded"
|
||||||
|
|
||||||
|
# Check if file is in allowed filetypes
|
||||||
|
name = os.path.basename(file.name)
|
||||||
|
try:
|
||||||
|
mime_type, _ = mimetypes.guess_type(file.name)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
if mime_type not in allowed_file_types:
|
||||||
|
return "File type disallowed"
|
||||||
|
|
||||||
|
# Sanitize file name
|
||||||
|
original_name = os.path.basename(file.name)
|
||||||
|
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
|
||||||
|
|
||||||
|
type_to_ext = {}
|
||||||
|
for ext, t in mimetypes.types_map.items():
|
||||||
|
if t not in type_to_ext:
|
||||||
|
type_to_ext[t] = ext
|
||||||
|
|
||||||
|
# Ensure the extension correlates to the mime type
|
||||||
|
sanitized_name = sanitized_name.split(".")[:-1]
|
||||||
|
sanitized_name.append("" + type_to_ext[mime_type])
|
||||||
|
sanitized_name = "".join(sanitized_name)
|
||||||
|
|
||||||
|
# Save the uploaded file to the specified folder
|
||||||
|
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
|
||||||
|
shutil.copy(file.name, file_path)
|
||||||
|
|
||||||
|
return f"File uploaded successfully to {self.file_upload_folder}"
|
||||||
|
|
||||||
def launch(self):
|
def launch(self):
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
stored_message = gr.State([])
|
stored_message = gr.State([])
|
||||||
|
@ -104,6 +150,14 @@ class GradioUI:
|
||||||
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
# If an upload folder is provided, enable the upload feature
|
||||||
|
if self.file_upload_folder is not None:
|
||||||
|
upload_file = gr.File(label="Upload a file")
|
||||||
|
upload_status = gr.Textbox(label="Upload Status", interactive=False)
|
||||||
|
|
||||||
|
upload_file.change(
|
||||||
|
self.upload_file, [upload_file], [upload_status]
|
||||||
|
)
|
||||||
text_input = gr.Textbox(lines=1, label="Chat Message")
|
text_input = gr.Textbox(lines=1, label="Chat Message")
|
||||||
text_input.submit(
|
text_input.submit(
|
||||||
lambda s: (s, ""), [text_input], [stored_message, text_input]
|
lambda s: (s, ""), [text_input], [stored_message, text_input]
|
||||||
|
|
Loading…
Reference in New Issue