Improve GradioUI file upload system
This commit is contained in:
parent
1f96560c92
commit
1d846072eb
|
@ -5,7 +5,7 @@ from smolagents import (
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[], model=HfApiModel(), max_steps=4, verbose=True
|
tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
|
||||||
)
|
)
|
||||||
|
|
||||||
GradioUI(agent, file_upload_folder='./data').launch()
|
GradioUI(agent, file_upload_folder='./data').launch()
|
||||||
|
|
|
@ -396,7 +396,7 @@ class MultiStepAgent:
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
return self.model(self.input_messages)
|
return self.model(self.input_messages).content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error in generating final LLM output:\n{e}"
|
return f"Error in generating final LLM output:\n{e}"
|
||||||
|
|
||||||
|
@ -666,7 +666,9 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
Now begin!""",
|
Now begin!""",
|
||||||
}
|
}
|
||||||
|
|
||||||
answer_facts = self.model([message_prompt_facts, message_prompt_task])
|
answer_facts = self.model(
|
||||||
|
[message_prompt_facts, message_prompt_task]
|
||||||
|
).content
|
||||||
|
|
||||||
message_system_prompt_plan = {
|
message_system_prompt_plan = {
|
||||||
"role": MessageRole.SYSTEM,
|
"role": MessageRole.SYSTEM,
|
||||||
|
@ -688,7 +690,7 @@ Now begin!""",
|
||||||
answer_plan = self.model(
|
answer_plan = self.model(
|
||||||
[message_system_prompt_plan, message_user_prompt_plan],
|
[message_system_prompt_plan, message_user_prompt_plan],
|
||||||
stop_sequences=["<end_plan>"],
|
stop_sequences=["<end_plan>"],
|
||||||
)
|
).content
|
||||||
|
|
||||||
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
|
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
|
||||||
```
|
```
|
||||||
|
@ -722,7 +724,7 @@ Now begin!""",
|
||||||
}
|
}
|
||||||
facts_update = self.model(
|
facts_update = self.model(
|
||||||
[facts_update_system_prompt] + agent_memory + [facts_update_message]
|
[facts_update_system_prompt] + agent_memory + [facts_update_message]
|
||||||
)
|
).content
|
||||||
|
|
||||||
# Redact updated plan
|
# Redact updated plan
|
||||||
plan_update_message = {
|
plan_update_message = {
|
||||||
|
@ -807,17 +809,26 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
tools_to_call_from=list(self.tools.values()),
|
tools_to_call_from=list(self.tools.values()),
|
||||||
stop_sequences=["Observation:"],
|
stop_sequences=["Observation:"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract tool call from model output
|
# Extract tool call from model output
|
||||||
if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
|
if (
|
||||||
|
type(model_message.tool_calls) is list
|
||||||
|
and len(model_message.tool_calls) > 0
|
||||||
|
):
|
||||||
tool_calls = model_message.tool_calls[0]
|
tool_calls = model_message.tool_calls[0]
|
||||||
tool_arguments = tool_calls.function.arguments
|
tool_arguments = tool_calls.function.arguments
|
||||||
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
||||||
else:
|
else:
|
||||||
start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
|
start, end = (
|
||||||
|
model_message.content.find("{"),
|
||||||
|
model_message.content.rfind("}") + 1,
|
||||||
|
)
|
||||||
tool_calls = json.loads(model_message.content[start:end])
|
tool_calls = json.loads(model_message.content[start:end])
|
||||||
tool_arguments = tool_calls["tool_arguments"]
|
tool_arguments = tool_calls["tool_arguments"]
|
||||||
tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
|
tool_name, tool_call_id = (
|
||||||
|
tool_calls["tool_name"],
|
||||||
|
f"call_{len(self.logs)}",
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(
|
raise AgentGenerationError(
|
||||||
|
|
|
@ -27,14 +27,15 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||||
"""Extract ChatMessage objects from agent steps"""
|
"""Extract ChatMessage objects from agent steps"""
|
||||||
if isinstance(step_log, ActionStep):
|
if isinstance(step_log, ActionStep):
|
||||||
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
|
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
|
||||||
if step_log.tool_call is not None:
|
if step_log.tool_calls is not None:
|
||||||
used_code = step_log.tool_call.name == "code interpreter"
|
first_tool_call = step_log.tool_calls[0]
|
||||||
content = step_log.tool_call.arguments
|
used_code = first_tool_call.name == "code interpreter"
|
||||||
|
content = first_tool_call.arguments
|
||||||
if used_code:
|
if used_code:
|
||||||
content = f"```py\n{content}\n```"
|
content = f"```py\n{content}\n```"
|
||||||
yield gr.ChatMessage(
|
yield gr.ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
|
metadata={"title": f"🛠️ Used tool {first_tool_call.name}"},
|
||||||
content=str(content),
|
content=str(content),
|
||||||
)
|
)
|
||||||
if step_log.observations is not None:
|
if step_log.observations is not None:
|
||||||
|
@ -103,6 +104,7 @@ class GradioUI:
|
||||||
def upload_file(
|
def upload_file(
|
||||||
self,
|
self,
|
||||||
file,
|
file,
|
||||||
|
file_uploads_log,
|
||||||
allowed_file_types=[
|
allowed_file_types=[
|
||||||
"application/pdf",
|
"application/pdf",
|
||||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
|
@ -110,14 +112,12 @@ class GradioUI:
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handle file uploads, default allowed types are pdf, docx, and .txt
|
Handle file uploads, default allowed types are .pdf, .docx, and .txt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Check if file is uploaded
|
|
||||||
if file is None:
|
if file is None:
|
||||||
return "No file uploaded"
|
return "No file uploaded"
|
||||||
|
|
||||||
# Check if file is in allowed filetypes
|
|
||||||
try:
|
try:
|
||||||
mime_type, _ = mimetypes.guess_type(file.name)
|
mime_type, _ = mimetypes.guess_type(file.name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -148,11 +148,23 @@ class GradioUI:
|
||||||
)
|
)
|
||||||
shutil.copy(file.name, file_path)
|
shutil.copy(file.name, file_path)
|
||||||
|
|
||||||
return f"File uploaded successfully to {self.file_upload_folder}"
|
return gr.Textbox(
|
||||||
|
f"File uploaded: {file_path}", visible=True
|
||||||
|
), file_uploads_log + [file_path]
|
||||||
|
|
||||||
|
def log_user_message(self, text_input, file_uploads_log):
|
||||||
|
return (
|
||||||
|
text_input
|
||||||
|
+ f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
|
||||||
|
if len(file_uploads_log) > 0
|
||||||
|
else "",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
def launch(self):
|
def launch(self):
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
stored_message = gr.State([])
|
stored_messages = gr.State([])
|
||||||
|
file_uploads_log = gr.State([])
|
||||||
chatbot = gr.Chatbot(
|
chatbot = gr.Chatbot(
|
||||||
label="Agent",
|
label="Agent",
|
||||||
type="messages",
|
type="messages",
|
||||||
|
@ -163,14 +175,21 @@ class GradioUI:
|
||||||
)
|
)
|
||||||
# If an upload folder is provided, enable the upload feature
|
# If an upload folder is provided, enable the upload feature
|
||||||
if self.file_upload_folder is not None:
|
if self.file_upload_folder is not None:
|
||||||
upload_file = gr.File(label="Upload a file")
|
upload_file = gr.File(label="Upload a file", height=1)
|
||||||
upload_status = gr.Textbox(label="Upload Status", interactive=False)
|
upload_status = gr.Textbox(
|
||||||
|
label="Upload Status", interactive=False, visible=False
|
||||||
upload_file.change(self.upload_file, [upload_file], [upload_status])
|
)
|
||||||
|
upload_file.change(
|
||||||
|
self.upload_file,
|
||||||
|
[upload_file, file_uploads_log],
|
||||||
|
[upload_status, file_uploads_log],
|
||||||
|
)
|
||||||
text_input = gr.Textbox(lines=1, label="Chat Message")
|
text_input = gr.Textbox(lines=1, label="Chat Message")
|
||||||
text_input.submit(
|
text_input.submit(
|
||||||
lambda s: (s, ""), [text_input], [stored_message, text_input]
|
self.log_user_message,
|
||||||
).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
|
[text_input, file_uploads_log],
|
||||||
|
[stored_messages, text_input],
|
||||||
|
).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])
|
||||||
|
|
||||||
demo.launch()
|
demo.launch()
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,8 @@ from transformers import (
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
|
@ -52,13 +54,9 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
|
||||||
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
|
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
if _is_package_available("litellm"):
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
is_litellm_available = True
|
|
||||||
except ImportError:
|
|
||||||
is_litellm_available = False
|
|
||||||
|
|
||||||
|
|
||||||
class MessageRole(str, Enum):
|
class MessageRole(str, Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
@ -159,7 +157,7 @@ class Model:
|
||||||
stop_sequences: Optional[List[str]] = None,
|
stop_sequences: Optional[List[str]] = None,
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
) -> str:
|
) -> ChatCompletionOutputMessage:
|
||||||
"""Process the input messages and return the model's response.
|
"""Process the input messages and return the model's response.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
@ -174,15 +172,7 @@ class Model:
|
||||||
Returns:
|
Returns:
|
||||||
`str`: The text content of the model's response.
|
`str`: The text content of the model's response.
|
||||||
"""
|
"""
|
||||||
if not isinstance(messages, List):
|
pass # To be implemented in child classes!
|
||||||
raise ValueError(
|
|
||||||
"Messages should be a list of dictionaries with 'role' and 'content' keys."
|
|
||||||
)
|
|
||||||
if stop_sequences is None:
|
|
||||||
stop_sequences = []
|
|
||||||
response = self.generate(messages, stop_sequences, grammar, max_tokens)
|
|
||||||
|
|
||||||
return remove_stop_sequences(response, stop_sequences)
|
|
||||||
|
|
||||||
|
|
||||||
class HfApiModel(Model):
|
class HfApiModel(Model):
|
||||||
|
@ -238,7 +228,7 @@ class HfApiModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> str:
|
) -> ChatCompletionOutputMessage:
|
||||||
"""
|
"""
|
||||||
Gets an LLM output message for the given list of input messages.
|
Gets an LLM output message for the given list of input messages.
|
||||||
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
||||||
|
@ -407,7 +397,7 @@ class LiteLLMModel(Model):
|
||||||
api_key=None,
|
api_key=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not is_litellm_available:
|
if not _is_package_available("litellm"):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"litellm not found. Install it with `pip install litellm`"
|
"litellm not found. Install it with `pip install litellm`"
|
||||||
)
|
)
|
||||||
|
@ -426,7 +416,7 @@ class LiteLLMModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> str:
|
) -> ChatCompletionOutputMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
@ -497,7 +487,7 @@ class OpenAIServerModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> str:
|
) -> ChatCompletionOutputMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
|
|
@ -367,9 +367,10 @@ class AgentTests(unittest.TestCase):
|
||||||
model=fake_code_model_no_return, # use this callable because it never ends
|
model=fake_code_model_no_return, # use this callable because it never ends
|
||||||
max_steps=5,
|
max_steps=5,
|
||||||
)
|
)
|
||||||
agent.run("What is 2 multiplied by 3.6452?")
|
answer = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert len(agent.logs) == 8
|
assert len(agent.logs) == 8
|
||||||
assert type(agent.logs[-1].error) is AgentMaxStepsError
|
assert type(agent.logs[-1].error) is AgentMaxStepsError
|
||||||
|
assert isinstance(answer, str)
|
||||||
|
|
||||||
def test_tool_descriptions_get_baked_in_system_prompt(self):
|
def test_tool_descriptions_get_baked_in_system_prompt(self):
|
||||||
tool = PythonInterpreterTool()
|
tool = PythonInterpreterTool()
|
||||||
|
|
|
@ -486,6 +486,7 @@ if char.isalpha():
|
||||||
code = "import numpy.random as rd"
|
code = "import numpy.random as rd"
|
||||||
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
|
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
|
||||||
evaluate_python_code(code, authorized_imports=["numpy"], state={})
|
evaluate_python_code(code, authorized_imports=["numpy"], state={})
|
||||||
|
evaluate_python_code(code, authorized_imports=["*"], state={})
|
||||||
with pytest.raises(InterpreterError):
|
with pytest.raises(InterpreterError):
|
||||||
evaluate_python_code(code, authorized_imports=["random"], state={})
|
evaluate_python_code(code, authorized_imports=["random"], state={})
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue