Improve GradioUI file upload system

Aymeric 2025-01-13 19:46:36 +01:00
parent 1f96560c92
commit 1d846072eb
6 changed files with 66 additions and 44 deletions

View File

@@ -5,7 +5,7 @@ from smolagents import (
 )
 
 agent = CodeAgent(
-    tools=[], model=HfApiModel(), max_steps=4, verbose=True
+    tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
 )
 
 GradioUI(agent, file_upload_folder='./data').launch()

View File

@@ -396,7 +396,7 @@ class MultiStepAgent:
             }
         ]
         try:
-            return self.model(self.input_messages)
+            return self.model(self.input_messages).content
         except Exception as e:
             return f"Error in generating final LLM output:\n{e}"
@@ -666,7 +666,9 @@ You have been provided with these additional arguments, that you can access using
 Now begin!""",
         }
-        answer_facts = self.model([message_prompt_facts, message_prompt_task])
+        answer_facts = self.model(
+            [message_prompt_facts, message_prompt_task]
+        ).content
 
         message_system_prompt_plan = {
             "role": MessageRole.SYSTEM,
@@ -688,7 +690,7 @@ Now begin!""",
         answer_plan = self.model(
             [message_system_prompt_plan, message_user_prompt_plan],
             stop_sequences=["<end_plan>"],
-        )
+        ).content
 
         final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
 ```
@@ -722,7 +724,7 @@ Now begin!""",
         }
         facts_update = self.model(
             [facts_update_system_prompt] + agent_memory + [facts_update_message]
-        )
+        ).content
 
         # Redact updated plan
         plan_update_message = {
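
All four hunks above make the same call-site change: `self.model(...)` now returns a message object rather than a raw string, so the text is read from its `.content` attribute. A minimal, stdlib-only sketch of the pattern, with a stub standing in for the real model (names here are illustrative, not the library's):

```python
from dataclasses import dataclass

@dataclass
class _StubMessage:
    content: str  # the text the agent code now reads via .content

def stub_model(messages, stop_sequences=None):
    # Stand-in for self.model: returns a message object, not a str.
    return _StubMessage(content="1. Gather facts.\n2. Write the plan.")

answer_plan = stub_model(
    [{"role": "system", "content": "Plan the task."}],
    stop_sequences=["<end_plan>"],
).content  # previously the call itself returned this string
print(answer_plan)
```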
@@ -807,17 +809,26 @@ class ToolCallingAgent(MultiStepAgent):
                 tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
 
             # Extract tool call from model output
-            if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
+            if (
+                type(model_message.tool_calls) is list
+                and len(model_message.tool_calls) > 0
+            ):
                 tool_calls = model_message.tool_calls[0]
                 tool_arguments = tool_calls.function.arguments
                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
             else:
-                start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
+                start, end = (
+                    model_message.content.find("{"),
+                    model_message.content.rfind("}") + 1,
+                )
                 tool_calls = json.loads(model_message.content[start:end])
                 tool_arguments = tool_calls["tool_arguments"]
-                tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
+                tool_name, tool_call_id = (
+                    tool_calls["tool_name"],
+                    f"call_{len(self.logs)}",
+                )
         except Exception as e:
             raise AgentGenerationError(
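
The widened `else` branch is the fallback for models that answer in plain text instead of emitting a structured tool call: it slices from the first `{` to the last `}` and parses that span as JSON, so nested braces inside the arguments survive. A self-contained sketch with an invented reply:

```python
import json

# A plain-text reply that embeds a JSON tool call somewhere inside it.
content = 'I will search. {"tool_name": "web_search", "tool_arguments": {"query": "weather"}}'

# Slice from the first "{" to the last "}" so nested braces stay intact.
start, end = content.find("{"), content.rfind("}") + 1
tool_call = json.loads(content[start:end])
print(tool_call["tool_name"])       # -> web_search
print(tool_call["tool_arguments"])  # -> {'query': 'weather'}
```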

View File

@@ -27,14 +27,15 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
     """Extract ChatMessage objects from agent steps"""
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
-        if step_log.tool_call is not None:
-            used_code = step_log.tool_call.name == "code interpreter"
-            content = step_log.tool_call.arguments
+        if step_log.tool_calls is not None:
+            first_tool_call = step_log.tool_calls[0]
+            used_code = first_tool_call.name == "code interpreter"
+            content = first_tool_call.arguments
             if used_code:
                 content = f"```py\n{content}\n```"
             yield gr.ChatMessage(
                 role="assistant",
-                metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
+                metadata={"title": f"🛠️ Used tool {first_tool_call.name}"},
                 content=str(content),
             )
         if step_log.observations is not None:
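
Step logs now carry a `tool_calls` list instead of a single `tool_call`, and the UI renders only the first entry. A small sketch of the plural access pattern, using a hypothetical dict-based log entry in place of a real `ActionStep`:

```python
# Hypothetical step log entry: tool calls are now a list, not a single object.
step_tool_calls = [{"name": "code interpreter", "arguments": "print(1 + 1)"}]

if step_tool_calls is not None and len(step_tool_calls) > 0:
    first_tool_call = step_tool_calls[0]  # the UI renders only the first call
    print(f"Used tool {first_tool_call['name']}: {first_tool_call['arguments']}")
```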
@@ -103,6 +104,7 @@
     def upload_file(
         self,
         file,
+        file_uploads_log,
         allowed_file_types=[
             "application/pdf",
             "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -110,14 +112,12 @@
         ],
     ):
         """
-        Handle file uploads, default allowed types are pdf, docx, and .txt
+        Handle file uploads, default allowed types are .pdf, .docx, and .txt
         """
 
-        # Check if file is uploaded
         if file is None:
             return "No file uploaded"
 
-        # Check if file is in allowed filetypes
         try:
             mime_type, _ = mimetypes.guess_type(file.name)
         except Exception as e:
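
The surrounding validation (partly outside this hunk) relies on the standard library's `mimetypes.guess_type`. For illustration, how the allowed-types gate behaves; the file name is invented, and `text/plain` is inferred from the docstring's `.txt` rather than shown in the hunk:

```python
import mimetypes

allowed_file_types = [
    "application/pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "text/plain",  # presumably the third entry, per the docstring's .txt
]

# guess_type maps the extension to a MIME type (or None if unknown).
mime_type, _ = mimetypes.guess_type("notes.pdf")
print(mime_type)                        # -> application/pdf
print(mime_type in allowed_file_types)  # -> True: the upload is accepted
```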
@@ -148,11 +148,23 @@
             )
         shutil.copy(file.name, file_path)
-        return f"File uploaded successfully to {self.file_upload_folder}"
+        return gr.Textbox(
+            f"File uploaded: {file_path}", visible=True
+        ), file_uploads_log + [file_path]
+
+    def log_user_message(self, text_input, file_uploads_log):
+        return (
+            text_input
+            + f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
+            if len(file_uploads_log) > 0
+            else "",
+            "",
+        )
 
     def launch(self):
         with gr.Blocks() as demo:
-            stored_message = gr.State([])
+            stored_messages = gr.State([])
+            file_uploads_log = gr.State([])
             chatbot = gr.Chatbot(
                 label="Agent",
                 type="messages",
@@ -163,14 +175,21 @@
             )
 
             # If an upload folder is provided, enable the upload feature
             if self.file_upload_folder is not None:
-                upload_file = gr.File(label="Upload a file")
-                upload_status = gr.Textbox(label="Upload Status", interactive=False)
-                upload_file.change(self.upload_file, [upload_file], [upload_status])
+                upload_file = gr.File(label="Upload a file", height=1)
+                upload_status = gr.Textbox(
+                    label="Upload Status", interactive=False, visible=False
+                )
+                upload_file.change(
+                    self.upload_file,
+                    [upload_file, file_uploads_log],
+                    [upload_status, file_uploads_log],
+                )
 
             text_input = gr.Textbox(lines=1, label="Chat Message")
             text_input.submit(
-                lambda s: (s, ""), [text_input], [stored_message, text_input]
-            ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
+                self.log_user_message,
+                [text_input, file_uploads_log],
+                [stored_messages, text_input],
+            ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])
 
             demo.launch()
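
Submission is now a two-stage chain: the first callback stores the (possibly file-annotated) prompt and clears the textbox, then `.then()` hands off to the agent. A sketch of the chaining pattern with a stubbed reply handler, assuming a Gradio version that supports `type="messages"` chatbots; the handler names are hypothetical:

```python
import gradio as gr

def stage_prompt(text, uploads_log):
    # Mirrors log_user_message: annotate with uploads, then clear the box.
    note = f"\nFiles provided: {uploads_log}" if uploads_log else ""
    return text + note, ""

def reply(stored, history):
    # Stub agent: append one assistant message to the chat history.
    return (history or []) + [{"role": "assistant", "content": f"Echo: {stored}"}]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    stored_messages = gr.State("")
    file_uploads_log = gr.State([])
    box = gr.Textbox(lines=1, label="Chat Message")
    box.submit(
        stage_prompt, [box, file_uploads_log], [stored_messages, box]
    ).then(reply, [stored_messages, chatbot], [chatbot])
```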

View File

@@ -36,6 +36,8 @@ from transformers import (
     StoppingCriteriaList,
     is_torch_available,
 )
+from transformers.utils.import_utils import _is_package_available
+import openai
 
 from .tools import Tool
@@ -52,13 +54,9 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
 }
 
-try:
+if _is_package_available("litellm"):
     import litellm
-
-    is_litellm_available = True
-except ImportError:
-    is_litellm_available = False
 
 
 class MessageRole(str, Enum):
     USER = "user"
@@ -159,7 +157,7 @@ class Model:
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """Process the input messages and return the model's response.
 
         Parameters:
@@ -174,15 +172,7 @@
         Returns:
             `str`: The text content of the model's response.
         """
-        if not isinstance(messages, List):
-            raise ValueError(
-                "Messages should be a list of dictionaries with 'role' and 'content' keys."
-            )
-        if stop_sequences is None:
-            stop_sequences = []
-
-        response = self.generate(messages, stop_sequences, grammar, max_tokens)
-        return remove_stop_sequences(response, stop_sequences)
+        pass  # To be implemented in child classes!
 
 
 class HfApiModel(Model):
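
With the base `__call__` reduced to a stub, input validation and stop-sequence trimming move into the child models, each of which now returns a message object. A stdlib-only sketch of the new contract; the class and helper names are stand-ins, not the library's:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class StubChatMessage:
    role: str
    content: Optional[str] = None

def remove_stop_sequences(text: str, stop_sequences: List[str]) -> str:
    # Trim a trailing stop sequence the backend may have echoed back.
    for stop in stop_sequences:
        if stop and text.endswith(stop):
            return text[: -len(stop)]
    return text

class StubModel:
    def __call__(self, messages, stop_sequences=None) -> StubChatMessage:
        raw = "Thought: the plan is done.<end_plan>"  # pretend generation
        text = remove_stop_sequences(raw, stop_sequences or [])
        return StubChatMessage(role="assistant", content=text)

print(StubModel()([], stop_sequences=["<end_plan>"]).content)  # -> Thought: the plan is done.
```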
@@ -238,7 +228,7 @@ class HfApiModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """
         Gets an LLM output message for the given list of input messages.
         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
@@ -407,7 +397,7 @@ class LiteLLMModel(Model):
         api_key=None,
         **kwargs,
     ):
-        if not is_litellm_available:
+        if not _is_package_available("litellm"):
             raise ImportError(
                 "litellm not found. Install it with `pip install litellm`"
             )
@@ -426,7 +416,7 @@ class LiteLLMModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
@@ -497,7 +487,7 @@ class OpenAIServerModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )

View File

@@ -367,9 +367,10 @@ class AgentTests(unittest.TestCase):
             model=fake_code_model_no_return,  # use this callable because it never ends
             max_steps=5,
         )
-        agent.run("What is 2 multiplied by 3.6452?")
+        answer = agent.run("What is 2 multiplied by 3.6452?")
         assert len(agent.logs) == 8
         assert type(agent.logs[-1].error) is AgentMaxStepsError
+        assert isinstance(answer, str)
 
     def test_tool_descriptions_get_baked_in_system_prompt(self):
         tool = PythonInterpreterTool()

View File

@@ -486,6 +486,7 @@ if char.isalpha():
         code = "import numpy.random as rd"
         evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
         evaluate_python_code(code, authorized_imports=["numpy"], state={})
+        evaluate_python_code(code, authorized_imports=["*"], state={})
         with pytest.raises(InterpreterError):
             evaluate_python_code(code, authorized_imports=["random"], state={})
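
The added line covers the new wildcard: `authorized_imports=["*"]` admits any module, while explicit lists still reject everything outside them. Assuming the test file's import path for the interpreter, a quick check reads:

```python
# Assumes the interpreter lives at this path, as in the test file's imports.
from smolagents.local_python_executor import InterpreterError, evaluate_python_code

code = "import numpy.random as rd"
evaluate_python_code(code, authorized_imports=["*"], state={})  # allowed by the wildcard
try:
    evaluate_python_code(code, authorized_imports=["random"], state={})
except InterpreterError as e:
    print(f"Rejected as expected: {e}")
```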