Support gradio chatbot with continued discussion

parent 23ab4a9df3
commit 0ada2ebc27
@@ -144,7 +144,7 @@ class AgentImage(AgentType, ImageType):
         if self._raw is not None:
             directory = tempfile.mkdtemp()
             self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
-            self._raw.save(self._path)
+            self._raw.save(self._path, format="png")
             return self._path

         if self._tensor is not None:
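The explicit `format="png"` matters because Pillow normally infers the output format from the file extension and requires an explicit format when writing to a file object with no name. A minimal standalone sketch of the distinction (illustrative, not from this diff):

```python
import io

from PIL import Image

img = Image.new("RGB", (4, 4), "white")
img.save("/tmp/example.png")    # format inferred from the ".png" extension
buffer = io.BytesIO()
img.save(buffer, format="png")  # no filename, so the format must be given explicitly
```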
@@ -155,12 +155,11 @@ class AgentImage(AgentType, ImageType):

            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")

            img.save(self._path, format="png")

            return self._path

-    def save(self, output_bytes, format = None, **params):
+    def save(self, output_bytes, format: str = None, **params):
        """
        Saves the image to a file.
        Args:
@@ -22,7 +22,7 @@ from rich.syntax import Syntax
 from transformers.utils import is_torch_available
 from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
 from .agent_types import AgentAudio, AgentImage
-from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
+from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
 from .llm_engine import HfApiEngine, MessageRole
 from .monitoring import Monitor
 from .prompts import (
@@ -42,13 +42,11 @@ from .tools import (
     DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
     Tool,
     get_tool_description_with_args,
-    load_tool,
     Toolbox,
 )


 HUGGINGFACE_DEFAULT_TOOLS = {}
-_tools_are_initialized = False


 class AgentError(Exception):
     """Base class for other agent-related exceptions"""
@@ -101,9 +99,12 @@ class PlanningStep:

 @dataclass
 class TaskStep:
-    system_prompt: str
     task: str

+@dataclass
+class SystemPromptStep:
+    system_prompt: str
+
 def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
     tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
     prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
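This split is the core of the continued-discussion feature: the system prompt and the task used to travel together in a single `TaskStep`, so the log had to be rebuilt from scratch on every run. With a dedicated `SystemPromptStep` kept at index 0 of the logs, the prompt can be refreshed in place while the task history keeps growing. A standalone sketch of the resulting log shape across two chat turns (illustrative only, using the dataclasses from the hunk above):

```python
from dataclasses import dataclass


@dataclass
class SystemPromptStep:
    system_prompt: str


@dataclass
class TaskStep:
    task: str


logs = []
logs.append(SystemPromptStep(system_prompt="You are an expert assistant..."))
logs.append(TaskStep(task="Draw me a picture of rivers and lakes"))
# ...action steps from the first run would be appended here...

# Second run with reset=False: slot 0 is overwritten, not re-appended,
# so the conversation keeps a single up-to-date system prompt.
logs[0] = SystemPromptStep(system_prompt="You are an expert assistant...")
logs.append(TaskStep(task="Now make it a watercolor"))
```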
@@ -189,7 +190,7 @@ class BaseAgent:
             self._toolbox, self.system_prompt_template, self.tool_description_template
         )
         self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
-        self.prompt = None
+        self.prompt_messages = None
         self.logs = []
         self.task = None
         self.verbose = verbose
@@ -208,8 +209,7 @@ class BaseAgent:
         """Get the toolbox currently available to the agent"""
         return self._toolbox

-    def initialize_for_run(self):
-        self.token_count = 0
+    def initialize_system_prompt(self):
         self.system_prompt = format_prompt_with_tools(
             self._toolbox,
             self.system_prompt_template,
@@ -220,27 +220,25 @@ class BaseAgent:
         self.system_prompt = format_prompt_with_imports(
             self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
         )
-        self.logs = [TaskStep(system_prompt=self.system_prompt, task=self.task)]
-        console.rule("[bold]New task", characters='=')
-        console.print(self.task)
+        return self.system_prompt

     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
         """
         Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
         that can be used as input to the LLM.
         """
-        prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0].system_prompt}
-        task_message = {
-            "role": MessageRole.USER,
-            "content": "Task: " + self.logs[0].task,
-        }
-        if summary_mode:
-            memory = [task_message]
-        else:
-            memory = [prompt_message, task_message]
-        for i, step_log in enumerate(self.logs[1:]):
+        memory = []
+        for i, step_log in enumerate(self.logs):
+            if isinstance(step_log, SystemPromptStep):
+                if not summary_mode:
+                    thought_message = {
+                        "role": MessageRole.SYSTEM,
+                        "content": step_log.system_prompt.strip(),
+                    }
+                    memory.append(thought_message)

-            if isinstance(step_log, PlanningStep):
+            elif isinstance(step_log, PlanningStep):
                 thought_message = {
                     "role": MessageRole.ASSISTANT,
                     "content": "[FACTS LIST]:\n" + step_log.facts.strip(),
                 }
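The rewrite replaces positional assumptions (`self.logs[0]` holding both prompt and task) with type-based dispatch over the whole log. A condensed standalone sketch of the new control flow; the `TaskStep` branch is an assumption extrapolated from the old code, since this hunk only shows the `SystemPromptStep` and `PlanningStep` branches:

```python
from dataclasses import dataclass


@dataclass
class SystemPromptStep:
    system_prompt: str


@dataclass
class TaskStep:
    task: str


def write_inner_memory(logs, summary_mode=False):
    """Type-dispatched replay of the log into chat messages (condensed sketch)."""
    memory = []
    for step_log in logs:
        if isinstance(step_log, SystemPromptStep):
            if not summary_mode:  # summary mode drops the system prompt
                memory.append({"role": "system", "content": step_log.system_prompt.strip()})
        elif isinstance(step_log, TaskStep):  # assumed branch, not shown in this hunk
            memory.append({"role": "user", "content": "Task: " + step_log.task})
    return memory


# The system prompt is emitted once, then tasks in order:
logs = [SystemPromptStep("You are an agent."), TaskStep("Say hi"), TaskStep("Now say bye")]
print(write_inner_memory(logs))
```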
@@ -398,21 +396,21 @@ class ReactAgent(BaseAgent):
         """
         This method provides a final answer to the task, based on the logs of the agent's interactions.
         """
-        self.prompt = [
+        self.prompt_messages = [
             {
                 "role": MessageRole.SYSTEM,
                 "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
             }
         ]
-        self.prompt += self.write_inner_memory_from_logs()[1:]
-        self.prompt += [
+        self.prompt_messages += self.write_inner_memory_from_logs()[1:]
+        self.prompt_messages += [
             {
                 "role": MessageRole.USER,
                 "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
             }
         ]
         try:
-            return self.llm_engine(self.prompt)
+            return self.llm_engine(self.prompt_messages)
         except Exception as e:
             error_msg = f"Error in generating final LLM output: {e}."
             console.print(f"[bold red]{error_msg}[/bold red]")
@@ -423,7 +421,10 @@ class ReactAgent(BaseAgent):
         Runs the agent for the given task.

         Args:
-            task (`str`): The task to perform
+            task (`str`): The task to perform.
+            stream (`bool`): Whether to run in a streaming way.
+            reset (`bool`): Whether to reset the conversation or keep it going from a previous run.
+            oneshot (`bool`): Whether the agent should run in one shot or in a multi-step fashion.

         Example:
         ```py
@@ -436,10 +437,23 @@ class ReactAgent(BaseAgent):
         if len(kwargs) > 0:
             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
         self.state = kwargs.copy()
+
+        self.initialize_system_prompt()
+        system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
+
         if reset:
-            self.initialize_for_run()
+            self.token_count = 0
+            self.logs = []
+            self.logs.append(system_prompt_step)
         else:
-            self.logs.append(TaskStep(task=task))
+            if len(self.logs) > 0:
+                self.logs[0] = system_prompt_step
+            else:
+                self.logs.append(system_prompt_step)
+
+        console.rule("[bold]New task", characters='=')
+        console.print(self.task)
+        self.logs.append(TaskStep(task=task))

         if oneshot:
             step_start_time = time.time()
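Taken together, these changes give `run` a conversational mode: `reset=True` clears the logs and re-seeds them with a fresh `SystemPromptStep`, while `reset=False` only swaps the prompt in slot 0 and appends a new `TaskStep`, so earlier turns stay visible to the LLM. A usage sketch, with the agent construction borrowed from this diff's example files and assuming `reset` defaults to `True` as the docstring suggests:

```python
from agents import CodeAgent, HfApiEngine

llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
agent = CodeAgent(tools=[], llm_engine=llm_engine)  # no extra tools for this sketch

agent.run("Draw me a picture of rivers and lakes")  # fresh conversation
agent.run("Redraw the previous picture as a watercolor", reset=False)  # continues from the first run
```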
@@ -676,20 +690,20 @@ class JsonAgent(ReactAgent):
         """
         agent_memory = self.write_inner_memory_from_logs()

-        self.prompt = agent_memory
+        self.prompt_messages = agent_memory

         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()

         if self.verbose:
             console.rule("[italic]Calling LLM engine with this last message:", align="left")
-            console.print(self.prompt[-1])
+            console.print(self.prompt_messages[-1])
             console.rule()

         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
             llm_output = self.llm_engine(
-                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
             )
             log_entry.llm_output = llm_output
         except Exception as e:
@@ -796,20 +810,20 @@ class CodeAgent(ReactAgent):
         """
         agent_memory = self.write_inner_memory_from_logs()

-        self.prompt = agent_memory.copy()
+        self.prompt_messages = agent_memory.copy()

         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()

         if self.verbose:
             console.rule("[italic]Calling LLM engine with these last messages:", align="left")
-            console.print(self.prompt[-2:])
+            console.print(self.prompt_messages[-2:])
             console.rule()

         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
             llm_output = self.llm_engine(
-                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
             )
             log_entry.llm_output = llm_output
         except Exception as e:
@@ -893,7 +907,7 @@ You have been submitted this task by your manager.
 Task:
 {task}
 ---
-You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer.
+You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible to give them a clear understanding of the answer.

 Your final_answer WILL HAVE to contain these parts:
 ### 1. Task outcome (short version):
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
 import json
 import math
 from dataclasses import dataclass
@@ -25,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces

 from transformers.utils import is_offline_mode
 from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
-from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
+from .tools import TOOL_CONFIG_FILE, Tool


 def custom_print(*args):
@@ -97,12 +96,6 @@ class PreTool:
     repo_id: str


-HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
-    "image-transformation",
-    "text-to-image",
-]
-
-
 def get_remote_tools(logger, organization="huggingface-tools"):
     if is_offline_mode():
         logger.info("You are in offline mode, so remote tools are not available.")
@@ -128,26 +121,6 @@ def get_remote_tools(logger, organization="huggingface-tools"):
     return tools


-def setup_default_tools():
-    default_tools = {}
-    main_module = importlib.import_module("transformers")
-    tools_module = main_module.agents
-
-    for task_name, tool_class_name in TOOL_MAPPING.items():
-        tool_class = getattr(tools_module, tool_class_name)
-        tool_instance = tool_class()
-        default_tools[tool_class.name] = PreTool(
-            name=tool_instance.name,
-            inputs=tool_instance.inputs,
-            output_type=tool_instance.output_type,
-            task=task_name,
-            description=tool_instance.description,
-            repo_id=None,
-        )
-
-    return default_tools
-
-
 class PythonInterpreterTool(Tool):
     name = "python_interpreter"
     description = "This is a tool that evaluates python code. It can be used to perform calculations."
@@ -1,88 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import re
-
-import numpy as np
-import torch
-
-from transformers import AutoProcessor, VisionEncoderDecoderModel
-from transformers.utils import is_vision_available
-from .tools import PipelineTool
-
-
-if is_vision_available():
-    from PIL import Image
-
-
-class DocumentQuestionAnsweringTool(PipelineTool):
-    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
-    description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
-    name = "document_qa"
-    pre_processor_class = AutoProcessor
-    model_class = VisionEncoderDecoderModel
-
-    inputs = {
-        "document": {
-            "type": "image",
-            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
-        },
-        "question": {"type": "string", "description": "The question in English"},
-    }
-    output_type = "string"
-
-    def __init__(self, *args, **kwargs):
-        if not is_vision_available():
-            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
-
-        super().__init__(*args, **kwargs)
-
-    def encode(self, document: "Image", question: str):
-        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
-        prompt = task_prompt.replace("{user_input}", question)
-        decoder_input_ids = self.pre_processor.tokenizer(
-            prompt, add_special_tokens=False, return_tensors="pt"
-        ).input_ids
-        if isinstance(document, str):
-            img = Image.open(document).convert("RGB")
-            img_array = np.array(img).transpose(2, 0, 1)
-            document = torch.from_numpy(img_array)
-        pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
-
-        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
-
-    def forward(self, inputs):
-        return self.model.generate(
-            inputs["pixel_values"].to(self.device),
-            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
-            max_length=self.model.decoder.config.max_position_embeddings,
-            early_stopping=True,
-            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
-            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
-            use_cache=True,
-            num_beams=1,
-            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
-            return_dict_in_generate=True,
-        ).sequences
-
-    def decode(self, outputs):
-        sequence = self.pre_processor.batch_decode(outputs)[0]
-        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
-        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
-        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
-        sequence = self.pre_processor.token2json(sequence)
-
-        return sequence["answer"]
@@ -1,58 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from PIL import Image
-
-from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
-from transformers.utils import requires_backends
-from .tools import PipelineTool
-
-
-class ImageQuestionAnsweringTool(PipelineTool):
-    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
-    description = (
-        "This is a tool that answers a question about an image. It "
-        "returns a text that is the answer to the question."
-    )
-    name = "image_qa"
-    pre_processor_class = AutoProcessor
-    model_class = AutoModelForVisualQuestionAnswering
-
-    inputs = {
-        "image": {
-            "type": "image",
-            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
-        },
-        "question": {"type": "string", "description": "The question in English"},
-    }
-    output_type = "string"
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["vision"])
-        super().__init__(*args, **kwargs)
-
-    def encode(self, image: "Image", question: str):
-        return self.pre_processor(image, question, return_tensors="pt")
-
-    def forward(self, inputs):
-        with torch.no_grad():
-            return self.model(**inputs).logits
-
-    def decode(self, outputs):
-        idx = outputs.argmax(-1).item()
-        return self.model.config.id2label[idx]
@@ -17,20 +17,8 @@
 from .agent_types import AgentAudio, AgentImage, AgentText
 from .utils import console


 def pull_message(step_log: dict, test_mode: bool = True):
-    try:
-        from gradio import ChatMessage
-    except ImportError:
-        if test_mode:
-
-            class ChatMessage:
-                def __init__(self, role, content, metadata=None):
-                    self.role = role
-                    self.content = content
-                    self.metadata = metadata
-        else:
-            raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+    from gradio import ChatMessage

     if step_log.get("rationale"):
         yield ChatMessage(role="assistant", content=step_log["rationale"])
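Dropping the try/except stub is consistent with the dependency change at the bottom of this diff: `gradio = "^5.8.0"` becomes a declared dependency in pyproject.toml, so `pull_message` and `stream_to_gradio` can assume `gradio.ChatMessage` is importable instead of faking it in test mode.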
@@ -54,23 +42,11 @@ def pull_message(step_log: dict, test_mode: bool = True):
         )


-def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
+def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool = False, **kwargs):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
-    try:
-        from gradio import ChatMessage
-    except ImportError:
-        if test_mode:
-
-            class ChatMessage:
-                def __init__(self, role, content, metadata=None):
-                    self.role = role
-                    self.content = content
-                    self.metadata = metadata
-        else:
-            raise ImportError("Gradio should be installed in order to launch a gradio demo.")
-
-    for step_log in agent.run(task, stream=True, **kwargs):
+    from gradio import ChatMessage
+
+    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
         if isinstance(step_log, dict):
             for message in pull_message(step_log, test_mode=test_mode):
                 yield message
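The new `reset_agent_memory` flag is simply forwarded to `agent.run(..., reset=...)`, which is what lets a gradio chatbot carry one continuing conversation. A minimal consumption sketch (the agent argument is assumed to be built as in the demo file added at the end of this diff):

```python
from agents import stream_to_gradio


def chat_once(agent, question):
    # Each yielded item is a gradio ChatMessage; reset_agent_memory=False keeps
    # the agent's previous turns in its logs between calls.
    for message in stream_to_gradio(agent, task=question, reset_agent_memory=False):
        print(message.role, message.content)
```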
@@ -395,7 +395,7 @@ Do not add anything else."""
 SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.

 Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
-This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
 Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
 After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""

@@ -466,7 +466,7 @@ Here is the up to date list of facts that you know:
 ```

 Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
-This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
 Beware that you have {remaining_steps} steps remaining.
 Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
 After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import WhisperForConditionalGeneration, WhisperProcessor
-from .tools import PipelineTool
-
-
-class SpeechToTextTool(PipelineTool):
-    default_checkpoint = "distil-whisper/distil-large-v3"
-    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
-    name = "transcriber"
-    pre_processor_class = WhisperProcessor
-    model_class = WhisperForConditionalGeneration
-
-    inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
-    output_type = "string"
-
-    def encode(self, audio):
-        return self.pre_processor(audio, return_tensors="pt")
-
-    def forward(self, inputs):
-        return self.model.generate(inputs["input_features"])
-
-    def decode(self, outputs):
-        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
@@ -1,67 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
-from transformers.utils import is_datasets_available
-from .tools import PipelineTool
-
-
-if is_datasets_available():
-    from datasets import load_dataset
-
-
-class TextToSpeechTool(PipelineTool):
-    default_checkpoint = "microsoft/speecht5_tts"
-    description = (
-        "This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
-    )
-    name = "text_to_speech"
-    pre_processor_class = SpeechT5Processor
-    model_class = SpeechT5ForTextToSpeech
-    post_processor_class = SpeechT5HifiGan
-
-    inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
-    output_type = "audio"
-
-    def setup(self):
-        if self.post_processor is None:
-            self.post_processor = "microsoft/speecht5_hifigan"
-        super().setup()
-
-    def encode(self, text, speaker_embeddings=None):
-        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
-
-        if speaker_embeddings is None:
-            if not is_datasets_available():
-                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
-
-            embeddings_dataset = load_dataset(
-                "Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True
-            )
-            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
-
-        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
-
-    def forward(self, inputs):
-        with torch.no_grad():
-            return self.model.generate_speech(**inputs)
-
-    def decode(self, outputs):
-        with torch.no_grad():
-            return self.post_processor(outputs).cpu().detach()
@@ -80,6 +80,19 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
     return "space"


+def setup_default_tools():
+    default_tools = {}
+    main_module = importlib.import_module("transformers")
+    tools_module = main_module.agents
+
+    for task_name, tool_class_name in TOOL_MAPPING.items():
+        tool_class = getattr(tools_module, tool_class_name)
+        tool_instance = tool_class()
+        default_tools[tool_class.name] = tool_instance
+
+    return default_tools
+
+
 # docstyle-ignore
 APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
 from {module_name} import {class_name}
@@ -811,11 +824,6 @@ def launch_gradio_demo(tool_class: Tool):


 TOOL_MAPPING = {
-    "document_question_answering": "DocumentQuestionAnsweringTool",
-    "image_question_answering": "ImageQuestionAnsweringTool",
-    "speech_to_text": "SpeechToTextTool",
-    "text_to_speech": "TextToSpeechTool",
-    "translation": "TranslationTool",
     "python_interpreter": "PythonInterpreterTool",
     "web_search": "DuckDuckGoSearchTool",
 }
@@ -1018,18 +1026,14 @@ class Toolbox:
         self._tools = {tool.name: tool for tool in tools}
         if add_base_tools:
             self.add_base_tools()
-        # self._load_tools_if_needed()

     def add_base_tools(self, add_python_interpreter: bool = False):
-        global _tools_are_initialized
         global HUGGINGFACE_DEFAULT_TOOLS
-        if not _tools_are_initialized:
+        if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
             HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
-            _tools_are_initialized = True
         for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
             if tool.name != "python_interpreter" or add_python_interpreter:
                 self.add_tool(tool)
-        # self._load_tools_if_needed()

         @property
         def tools(self) -> Dict[str, Tool]:
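Replacing the `_tools_are_initialized` sentinel with an emptiness check removes one piece of global state: the cache dict itself now encodes whether setup has run. The pattern in isolation (a standalone sketch, not the repo's code):

```python
_DEFAULT_TOOLS = {}


def get_default_tools(factory):
    """Populate the module-level cache on first use; reuse it afterwards."""
    global _DEFAULT_TOOLS
    if len(_DEFAULT_TOOLS) == 0:  # emptiness doubles as the "not yet set up" flag
        _DEFAULT_TOOLS = factory()
    return _DEFAULT_TOOLS
```

One small behavioral difference worth noting: if the factory legitimately returned an empty dict, this check would re-run setup on every call, whereas the old boolean flag would not. That cannot happen here since `TOOL_MAPPING` is non-empty.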
@@ -1,4 +1,4 @@
-from agents import load_tool, ReactCodeAgent, HfApiEngine
+from agents import load_tool, CodeAgent, HfApiEngine

 # Import tool from Hub
 image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
@@ -10,7 +10,7 @@ search_tool = DuckDuckGoSearchTool()

 llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
 # Initialize the agent with both tools
-agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)
+agent = CodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)

 # Run it!
 result = agent.run(
@@ -0,0 +1,29 @@
+from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent
+
+image_generation_tool = load_tool("m-ric/text-to-image")
+
+llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
+
+agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
+
+import gradio as gr
+
+
+def interact_with_agent(prompt, messages):
+    messages.append(gr.ChatMessage(role="user", content=prompt))
+    yield messages
+    for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False):
+        messages.append(msg)
+        yield messages
+    yield messages
+
+
+with gr.Blocks() as demo:
+    stored_message = gr.State([])
+    chatbot = gr.Chatbot(label="Agent",
+                         type="messages",
+                         avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
+    text_input = gr.Textbox(lines=1, label="Chat Message")
+    text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])
+
+demo.launch()
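In this demo, each submit first copies the textbox contents into `stored_message` and clears the box, then `interact_with_agent` streams the growing message list back into the chatbot. The `reset_agent_memory=False` on the `stream_to_gradio` call is what turns the widget into a continued discussion rather than a fresh agent run per message.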
File diff suppressed because it is too large.
@@ -68,8 +68,12 @@ jinja2 = "^3.1.4"
 pillow = "^11.0.0"
 llama-cpp-python = "^0.3.4"
 markdownify = "^0.14.1"
+gradio = "^5.8.0"


+[tool.poetry.group.dev.dependencies]
+ipykernel = "^6.29.5"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"