Improve GradioUI file upload system
parent 1f96560c92
commit 1d846072eb
@@ -5,7 +5,7 @@ from smolagents import (
 )

 agent = CodeAgent(
-    tools=[], model=HfApiModel(), max_steps=4, verbose=True
+    tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
 )

 GradioUI(agent, file_upload_folder='./data').launch()
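For context, the whole example script after this change plausibly reads as below; the exact names in the import list are hidden by the hunk, so they are inferred from the identifiers used in the file:

    from smolagents import CodeAgent, GradioUI, HfApiModel

    agent = CodeAgent(
        tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
    )

    # file_upload_folder switches on the new upload widget in the UI.
    GradioUI(agent, file_upload_folder='./data').launch()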
@@ -396,7 +396,7 @@ class MultiStepAgent:
             }
         ]
         try:
-            return self.model(self.input_messages)
+            return self.model(self.input_messages).content
         except Exception as e:
             return f"Error in generating final LLM output:\n{e}"

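This hunk and several below share one pattern: model calls now return a chat-message object (see the `ChatCompletionOutputMessage` return types later in the diff) rather than a raw string, so call sites unwrap `.content` explicitly. A runnable stand-in sketch of the new convention (the stub types are illustrative, not the library's):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeChatMessage:
        # Stand-in for ChatCompletionOutputMessage as used in this diff.
        content: str
        tool_calls: Optional[list] = None

    def fake_model(messages) -> FakeChatMessage:
        # Stand-in for Model.__call__ after this change: returns a
        # message object instead of a plain string.
        return FakeChatMessage(content="The final answer is 7.2904.")

    # Call sites now read the text off .content explicitly:
    text = fake_model([{"role": "user", "content": "2 * 3.6452?"}]).content
    print(text)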
@@ -666,7 +666,9 @@ You have been provided with these additional arguments, that you can access using
 Now begin!""",
         }

-        answer_facts = self.model([message_prompt_facts, message_prompt_task])
+        answer_facts = self.model(
+            [message_prompt_facts, message_prompt_task]
+        ).content

         message_system_prompt_plan = {
             "role": MessageRole.SYSTEM,
@@ -688,7 +690,7 @@ Now begin!""",
         answer_plan = self.model(
             [message_system_prompt_plan, message_user_prompt_plan],
             stop_sequences=["<end_plan>"],
-        )
+        ).content

         final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
 ```
@@ -722,7 +724,7 @@ Now begin!""",
         }
         facts_update = self.model(
             [facts_update_system_prompt] + agent_memory + [facts_update_message]
-        )
+        ).content

         # Redact updated plan
         plan_update_message = {
@@ -807,17 +809,26 @@ class ToolCallingAgent(MultiStepAgent):
                 tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )

             # Extract tool call from model output
-            if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
+            if (
+                type(model_message.tool_calls) is list
+                and len(model_message.tool_calls) > 0
+            ):
                 tool_calls = model_message.tool_calls[0]
                 tool_arguments = tool_calls.function.arguments
                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
             else:
-                start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
+                start, end = (
+                    model_message.content.find("{"),
+                    model_message.content.rfind("}") + 1,
+                )
                 tool_calls = json.loads(model_message.content[start:end])
                 tool_arguments = tool_calls["tool_arguments"]
-                tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
+                tool_name, tool_call_id = (
+                    tool_calls["tool_name"],
+                    f"call_{len(self.logs)}",
+                )

         except Exception as e:
             raise AgentGenerationError(
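For context, the `else` branch is the fallback for models that emit the tool call as raw JSON text instead of a structured `tool_calls` list. A self-contained sketch of that parsing, with a made-up model output:

    import json

    # Made-up raw output from a model that cannot emit structured tool calls.
    content = 'I will search. {"tool_name": "web_search", "tool_arguments": {"query": "gradio file upload"}}'

    # Slice out the outermost {...} span, then parse it, as in the else branch.
    start, end = content.find("{"), content.rfind("}") + 1
    tool_call = json.loads(content[start:end])
    print(tool_call["tool_name"])       # web_search
    print(tool_call["tool_arguments"])  # {'query': 'gradio file upload'}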
@@ -27,14 +27,15 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
     """Extract ChatMessage objects from agent steps"""
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
-        if step_log.tool_call is not None:
-            used_code = step_log.tool_call.name == "code interpreter"
-            content = step_log.tool_call.arguments
+        if step_log.tool_calls is not None:
+            first_tool_call = step_log.tool_calls[0]
+            used_code = first_tool_call.name == "code interpreter"
+            content = first_tool_call.arguments
             if used_code:
                 content = f"```py\n{content}\n```"
             yield gr.ChatMessage(
                 role="assistant",
-                metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
+                metadata={"title": f"🛠️ Used tool {first_tool_call.name}"},
                 content=str(content),
             )
         if step_log.observations is not None:
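For context, a small sketch of the message this branch produces, assuming a Gradio version that exposes the `gr.ChatMessage` dataclass; the tool name and arguments are made up:

    import gradio as gr

    # The tool name and arguments here are made up for illustration.
    first_tool_call_name = "code interpreter"
    arguments = "print(2 * 3.6452)"

    message = gr.ChatMessage(
        role="assistant",
        metadata={"title": f"🛠️ Used tool {first_tool_call_name}"},
        content=f"```py\n{arguments}\n```",
    )
    print(message.content)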
@@ -103,6 +104,7 @@ class GradioUI:
     def upload_file(
         self,
         file,
+        file_uploads_log,
         allowed_file_types=[
             "application/pdf",
             "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -110,14 +112,12 @@ class GradioUI:
         ],
     ):
         """
-        Handle file uploads, default allowed types are pdf, docx, and .txt
+        Handle file uploads, default allowed types are .pdf, .docx, and .txt
         """

-        # Check if file is uploaded
         if file is None:
             return "No file uploaded"

-        # Check if file is in allowed filetypes
         try:
             mime_type, _ = mimetypes.guess_type(file.name)
         except Exception as e:
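For context, `mimetypes.guess_type` is a standard-library lookup keyed on the file extension, so the type check reduces to a string comparison against the allow-list. A small sketch (the `text/plain` entry is inferred from the `.txt` mention in the docstring):

    import mimetypes

    allowed_file_types = [
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/plain",  # inferred from the ".txt" mention in the docstring
    ]

    # guess_type works from the extension alone and returns (type, encoding);
    # type is None when the extension is unknown.
    mime_type, _ = mimetypes.guess_type("notes.txt")
    print(mime_type)                        # text/plain
    print(mime_type in allowed_file_types)  # True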
@@ -148,11 +148,23 @@ class GradioUI:
             )
         shutil.copy(file.name, file_path)

-        return f"File uploaded successfully to {self.file_upload_folder}"
+        return gr.Textbox(
+            f"File uploaded: {file_path}", visible=True
+        ), file_uploads_log + [file_path]
+
+    def log_user_message(self, text_input, file_uploads_log):
+        return (
+            text_input
+            + f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
+            if len(file_uploads_log) > 0
+            else "",
+            "",
+        )

     def launch(self):
         with gr.Blocks() as demo:
-            stored_message = gr.State([])
+            stored_messages = gr.State([])
+            file_uploads_log = gr.State([])
             chatbot = gr.Chatbot(
                 label="Agent",
                 type="messages",
@@ -163,14 +175,21 @@
             )
             # If an upload folder is provided, enable the upload feature
             if self.file_upload_folder is not None:
-                upload_file = gr.File(label="Upload a file")
-                upload_status = gr.Textbox(label="Upload Status", interactive=False)
-
-                upload_file.change(self.upload_file, [upload_file], [upload_status])
+                upload_file = gr.File(label="Upload a file", height=1)
+                upload_status = gr.Textbox(
+                    label="Upload Status", interactive=False, visible=False
+                )
+                upload_file.change(
+                    self.upload_file,
+                    [upload_file, file_uploads_log],
+                    [upload_status, file_uploads_log],
+                )
             text_input = gr.Textbox(lines=1, label="Chat Message")
             text_input.submit(
-                lambda s: (s, ""), [text_input], [stored_message, text_input]
-            ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
+                self.log_user_message,
+                [text_input, file_uploads_log],
+                [stored_messages, text_input],
+            ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])

         demo.launch()

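The wiring above follows Gradio's event-chaining idiom: the file widget's `.change` event feeds `upload_file`, and `.submit(...).then(...)` makes the agent run only after the user message is logged and the textbox cleared. A stripped-down, runnable sketch of the same pattern with stand-in handlers:

    import gradio as gr

    # Stand-in handlers mirroring the event chain above.
    def log_user_message(text_input, file_uploads_log):
        note = (
            f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
            if len(file_uploads_log) > 0
            else ""
        )
        return text_input + note, ""  # (stored message, cleared textbox)

    def interact_with_agent(prompt, messages):
        # A real agent would stream its steps here; this stub just echoes.
        messages = list(messages or [])
        messages.append(gr.ChatMessage(role="user", content=prompt))
        messages.append(gr.ChatMessage(role="assistant", content=f"(echo) {prompt}"))
        return messages

    with gr.Blocks() as demo:
        stored_messages = gr.State("")
        file_uploads_log = gr.State([])
        chatbot = gr.Chatbot(label="Agent", type="messages")
        text_input = gr.Textbox(lines=1, label="Chat Message")
        # .submit stores the message and clears the box; .then runs the agent.
        text_input.submit(
            log_user_message,
            [text_input, file_uploads_log],
            [stored_messages, text_input],
        ).then(interact_with_agent, [stored_messages, chatbot], [chatbot])

    demo.launch()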
@@ -36,6 +36,8 @@ from transformers import (
     StoppingCriteriaList,
     is_torch_available,
 )
+from transformers.utils.import_utils import _is_package_available

+import openai

 from .tools import Tool
@@ -52,13 +54,9 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
 }

-try:
+if _is_package_available("litellm"):
     import litellm

-    is_litellm_available = True
-except ImportError:
-    is_litellm_available = False
-

 class MessageRole(str, Enum):
     USER = "user"
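For context, `_is_package_available` is a transformers-internal helper; an equivalent availability check using only the standard library would be `importlib.util.find_spec`, as in this sketch:

    import importlib.util

    def is_package_available(name: str) -> bool:
        # True when the package is importable, without importing it.
        return importlib.util.find_spec(name) is not None

    if is_package_available("litellm"):
        import litellm  # noqa: F401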
@@ -159,7 +157,7 @@ class Model:
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """Process the input messages and return the model's response.

         Parameters:
@@ -174,15 +172,7 @@
         Returns:
             `str`: The text content of the model's response.
         """
-        if not isinstance(messages, List):
-            raise ValueError(
-                "Messages should be a list of dictionaries with 'role' and 'content' keys."
-            )
-        if stop_sequences is None:
-            stop_sequences = []
-        response = self.generate(messages, stop_sequences, grammar, max_tokens)
-
-        return remove_stop_sequences(response, stop_sequences)
+        pass  # To be implemented in child classes!


 class HfApiModel(Model):
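With the base `__call__` reduced to a stub, the removed logic (input validation, default stop sequences, stop-sequence trimming) now has to live in each subclass. A sketch of what the dropped `remove_stop_sequences` helper plausibly did, inferred from its name and call site:

    def remove_stop_sequences(text: str, stop_sequences: list) -> str:
        # Trim one trailing stop sequence, since backends may echo it.
        for stop in stop_sequences:
            if text.endswith(stop):
                return text[: -len(stop)]
        return text

    print(remove_stop_sequences("plan...<end_plan>", ["<end_plan>"]))  # plan...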
@@ -238,7 +228,7 @@ class HfApiModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """
         Gets an LLM output message for the given list of input messages.
         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
@@ -407,7 +397,7 @@ class LiteLLMModel(Model):
         api_key=None,
         **kwargs,
     ):
-        if not is_litellm_available:
+        if not _is_package_available("litellm"):
             raise ImportError(
                 "litellm not found. Install it with `pip install litellm`"
             )
@@ -426,7 +416,7 @@ class LiteLLMModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
@@ -497,7 +487,7 @@ class OpenAIServerModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
@@ -367,9 +367,10 @@ class AgentTests(unittest.TestCase):
             model=fake_code_model_no_return,  # use this callable because it never ends
             max_steps=5,
         )
-        agent.run("What is 2 multiplied by 3.6452?")
+        answer = agent.run("What is 2 multiplied by 3.6452?")
         assert len(agent.logs) == 8
         assert type(agent.logs[-1].error) is AgentMaxStepsError
+        assert isinstance(answer, str)

     def test_tool_descriptions_get_baked_in_system_prompt(self):
         tool = PythonInterpreterTool()
@@ -486,6 +486,7 @@ if char.isalpha():
     code = "import numpy.random as rd"
     evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
     evaluate_python_code(code, authorized_imports=["numpy"], state={})
+    evaluate_python_code(code, authorized_imports=["*"], state={})
     with pytest.raises(InterpreterError):
         evaluate_python_code(code, authorized_imports=["random"], state={})
