Support gradio chatbot with continued discussion
parent 23ab4a9df3
commit 0ada2ebc27
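Before the diff: a minimal usage sketch of the continued-discussion behavior this commit introduces. It assumes the `agents` package and the `m-ric/text-to-image` tool used in the example scripts further down, and that `run(..., reset=False)` keeps the previous logs so a follow-up prompt builds on the earlier exchange (the `reset` parameter name comes from the diff; the surrounding setup and the default of starting fresh are assumptions).

```python
from agents import CodeAgent, HfApiEngine, load_tool

# Illustrative setup, mirroring the example scripts changed in this commit.
llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
agent = CodeAgent(tools=[load_tool("m-ric/text-to-image")], llm_engine=llm_engine)

# First run starts a fresh conversation (assuming reset defaults to a clean start).
agent.run("Generate an image of a mountain lake.")

# reset=False reuses the agent's existing logs, so the new task is answered
# in the context of the previous one.
agent.run("Now produce the same scene at sunset.", reset=False)
```

The Gradio path exposes the same switch through `stream_to_gradio(agent, task, reset_agent_memory=False)`, as in the demo script added at the bottom of this diff.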
@@ -144,7 +144,7 @@ class AgentImage(AgentType, ImageType):
        if self._raw is not None:
            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
-            self._raw.save(self._path)
+            self._raw.save(self._path, format="png")
            return self._path

        if self._tensor is not None:
@@ -155,12 +155,11 @@ class AgentImage(AgentType, ImageType):

            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")

            img.save(self._path, format="png")

            return self._path

-    def save(self, output_bytes, format = None, **params):
+    def save(self, output_bytes, format : str = None, **params):
        """
        Saves the image to a file.
        Args:
@@ -22,7 +22,7 @@ from rich.syntax import Syntax
from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
from .agent_types import AgentAudio, AgentImage
-from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
+from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
from .llm_engine import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (
@@ -42,13 +42,11 @@ from .tools import (
    DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
    Tool,
    get_tool_description_with_args,
    load_tool,
    Toolbox,
)


-HUGGINGFACE_DEFAULT_TOOLS = {}
-_tools_are_initialized = False

class AgentError(Exception):
    """Base class for other agent-related exceptions"""
@@ -101,9 +99,12 @@ class PlanningStep:

@dataclass
class TaskStep:
-    system_prompt: str
    task: str

+@dataclass
+class SystemPromptStep:
+    system_prompt: str

def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
    tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
    prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
@@ -189,7 +190,7 @@ class BaseAgent:
            self._toolbox, self.system_prompt_template, self.tool_description_template
        )
        self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
-        self.prompt = None
+        self.prompt_messages = None
        self.logs = []
        self.task = None
        self.verbose = verbose
@@ -208,8 +209,7 @@ class BaseAgent:
        """Get the toolbox currently available to the agent"""
        return self._toolbox

-    def initialize_for_run(self):
-        self.token_count = 0
+    def initialize_system_prompt(self):
        self.system_prompt = format_prompt_with_tools(
            self._toolbox,
            self.system_prompt_template,
@@ -220,27 +220,25 @@ class BaseAgent:
        self.system_prompt = format_prompt_with_imports(
            self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
        )
-        self.logs = [TaskStep(system_prompt=self.system_prompt, task=self.task)]
-        console.rule("[bold]New task", characters='=')
-        console.print(self.task)

+        return self.system_prompt

    def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
        """
        Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
        that can be used as input to the LLM.
        """
-        prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0].system_prompt}
-        task_message = {
-            "role": MessageRole.USER,
-            "content": "Task: " + self.logs[0].task,
+        memory = []
+        for i, step_log in enumerate(self.logs):
+            if isinstance(step_log, SystemPromptStep):
+                if not summary_mode:
+                    thought_message = {
+                        "role": MessageRole.SYSTEM,
+                        "content": step_log.system_prompt.strip(),
+                    }
-        if summary_mode:
-            memory = [task_message]
-        else:
-            memory = [prompt_message, task_message]
-        for i, step_log in enumerate(self.logs[1:]):
+                    memory.append(thought_message)

-            if isinstance(step_log, PlanningStep):
+            elif isinstance(step_log, PlanningStep):
                thought_message = {
                    "role": MessageRole.ASSISTANT,
                    "content": "[FACTS LIST]:\n" + step_log.facts.strip(),
@@ -398,21 +396,21 @@ class ReactAgent(BaseAgent):
        """
        This method provides a final answer to the task, based on the logs of the agent's interactions.
        """
-        self.prompt = [
+        self.prompt_messages = [
            {
                "role": MessageRole.SYSTEM,
                "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
            }
        ]
-        self.prompt += self.write_inner_memory_from_logs()[1:]
-        self.prompt += [
+        self.prompt_messages += self.write_inner_memory_from_logs()[1:]
+        self.prompt_messages += [
            {
                "role": MessageRole.USER,
                "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
            }
        ]
        try:
-            return self.llm_engine(self.prompt)
+            return self.llm_engine(self.prompt_messages)
        except Exception as e:
            error_msg = f"Error in generating final LLM output: {e}."
            console.print(f"[bold red]{error_msg}[/bold red]")
@@ -423,7 +421,10 @@ class ReactAgent(BaseAgent):
        Runs the agent for the given task.

        Args:
-            task (`str`): The task to perform
+            task (`str`): The task to perform.
+            stream (`bool`): Wether to run in a streaming way.
+            reset (`bool`): Wether to reset the conversation or keep it going from previous run.
+            oneshot (`bool`): Should the agent run in one shot or multi-step fashion?

        Example:
        ```py
@@ -436,9 +437,22 @@ class ReactAgent(BaseAgent):
        if len(kwargs) > 0:
            self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
        self.state = kwargs.copy()

+        self.initialize_system_prompt()
+        system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)

+        if reset:
-            self.initialize_for_run()
+            self.token_count = 0
+            self.logs = []
+            self.logs.append(system_prompt_step)
+        else:
+            if len(self.logs) > 0:
+                self.logs[0] = system_prompt_step
+            else:
+                self.logs.append(system_prompt_step)

+        console.rule("[bold]New task", characters='=')
+        console.print(self.task)
+        self.logs.append(TaskStep(task=task))

        if oneshot:
@@ -676,20 +690,20 @@ class JsonAgent(ReactAgent):
        """
        agent_memory = self.write_inner_memory_from_logs()

-        self.prompt = agent_memory
+        self.prompt_messages = agent_memory

        # Add new step in logs
        log_entry.agent_memory = agent_memory.copy()

        if self.verbose:
            console.rule("[italic]Calling LLM engine with this last message:", align="left")
-            console.print(self.prompt[-1])
+            console.print(self.prompt_messages[-1])
            console.rule()

        try:
            additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
            llm_output = self.llm_engine(
-                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
            )
            log_entry.llm_output = llm_output
        except Exception as e:
@@ -796,20 +810,20 @@ class CodeAgent(ReactAgent):
        """
        agent_memory = self.write_inner_memory_from_logs()

-        self.prompt = agent_memory.copy()
+        self.prompt_messages = agent_memory.copy()

        # Add new step in logs
        log_entry.agent_memory = agent_memory.copy()

        if self.verbose:
            console.rule("[italic]Calling LLM engine with these last messages:", align="left")
-            console.print(self.prompt[-2:])
+            console.print(self.prompt_messages[-2:])
            console.rule()

        try:
            additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
            llm_output = self.llm_engine(
-                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
            )
            log_entry.llm_output = llm_output
        except Exception as e:
@@ -893,7 +907,7 @@ You have been submitted this task by your manager.
Task:
{task}
---
-You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer.
+You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible to give them a clear understanding of the answer.

Your final_answer WILL HAVE to contain these parts:
### 1. Task outcome (short version):
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib.util
import json
import math
from dataclasses import dataclass
@@ -25,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces

from transformers.utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
-from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
+from .tools import TOOL_CONFIG_FILE, Tool


def custom_print(*args):
@@ -97,12 +96,6 @@ class PreTool:
    repo_id: str


-HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
-    "image-transformation",
-    "text-to-image",
-]
-
-
def get_remote_tools(logger, organization="huggingface-tools"):
    if is_offline_mode():
        logger.info("You are in offline mode, so remote tools are not available.")
@@ -128,26 +121,6 @@ def get_remote_tools(logger, organization="huggingface-tools"):
    return tools


-def setup_default_tools():
-    default_tools = {}
-    main_module = importlib.import_module("transformers")
-    tools_module = main_module.agents
-
-    for task_name, tool_class_name in TOOL_MAPPING.items():
-        tool_class = getattr(tools_module, tool_class_name)
-        tool_instance = tool_class()
-        default_tools[tool_class.name] = PreTool(
-            name=tool_instance.name,
-            inputs=tool_instance.inputs,
-            output_type=tool_instance.output_type,
-            task=task_name,
-            description=tool_instance.description,
-            repo_id=None,
-        )
-
-    return default_tools
-
-
class PythonInterpreterTool(Tool):
    name = "python_interpreter"
    description = "This is a tool that evaluates python code. It can be used to perform calculations."
@@ -1,88 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import re
-
-import numpy as np
-import torch
-
-from transformers import AutoProcessor, VisionEncoderDecoderModel
-from transformers.utils import is_vision_available
-from .tools import PipelineTool
-
-
-if is_vision_available():
-    from PIL import Image
-
-
-class DocumentQuestionAnsweringTool(PipelineTool):
-    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
-    description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
-    name = "document_qa"
-    pre_processor_class = AutoProcessor
-    model_class = VisionEncoderDecoderModel
-
-    inputs = {
-        "document": {
-            "type": "image",
-            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
-        },
-        "question": {"type": "string", "description": "The question in English"},
-    }
-    output_type = "string"
-
-    def __init__(self, *args, **kwargs):
-        if not is_vision_available():
-            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
-
-        super().__init__(*args, **kwargs)
-
-    def encode(self, document: "Image", question: str):
-        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
-        prompt = task_prompt.replace("{user_input}", question)
-        decoder_input_ids = self.pre_processor.tokenizer(
-            prompt, add_special_tokens=False, return_tensors="pt"
-        ).input_ids
-        if isinstance(document, str):
-            img = Image.open(document).convert("RGB")
-            img_array = np.array(img).transpose(2, 0, 1)
-            document = torch.from_numpy(img_array)
-        pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
-
-        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
-
-    def forward(self, inputs):
-        return self.model.generate(
-            inputs["pixel_values"].to(self.device),
-            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
-            max_length=self.model.decoder.config.max_position_embeddings,
-            early_stopping=True,
-            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
-            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
-            use_cache=True,
-            num_beams=1,
-            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
-            return_dict_in_generate=True,
-        ).sequences
-
-    def decode(self, outputs):
-        sequence = self.pre_processor.batch_decode(outputs)[0]
-        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
-        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
-        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
-        sequence = self.pre_processor.token2json(sequence)
-
-        return sequence["answer"]
@@ -1,58 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from PIL import Image
-
-from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
-from transformers.utils import requires_backends
-from .tools import PipelineTool
-
-
-class ImageQuestionAnsweringTool(PipelineTool):
-    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
-    description = (
-        "This is a tool that answers a question about an image. It "
-        "returns a text that is the answer to the question."
-    )
-    name = "image_qa"
-    pre_processor_class = AutoProcessor
-    model_class = AutoModelForVisualQuestionAnswering
-
-    inputs = {
-        "image": {
-            "type": "image",
-            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
-        },
-        "question": {"type": "string", "description": "The question in English"},
-    }
-    output_type = "string"
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["vision"])
-        super().__init__(*args, **kwargs)
-
-    def encode(self, image: "Image", question: str):
-        return self.pre_processor(image, question, return_tensors="pt")
-
-    def forward(self, inputs):
-        with torch.no_grad():
-            return self.model(**inputs).logits
-
-    def decode(self, outputs):
-        idx = outputs.argmax(-1).item()
-        return self.model.config.id2label[idx]
@@ -17,20 +17,8 @@
from .agent_types import AgentAudio, AgentImage, AgentText
from .utils import console


def pull_message(step_log: dict, test_mode: bool = True):
-    try:
-        from gradio import ChatMessage
-    except ImportError:
-        if test_mode:
-
-            class ChatMessage:
-                def __init__(self, role, content, metadata=None):
-                    self.role = role
-                    self.content = content
-                    self.metadata = metadata
-        else:
-            raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    if step_log.get("rationale"):
        yield ChatMessage(role="assistant", content=step_log["rationale"])
@@ -54,23 +42,11 @@ def pull_message(step_log: dict, test_mode: bool = True):
    )


-def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
+def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""

-    try:
-        from gradio import ChatMessage
-    except ImportError:
-        if test_mode:
-
-            class ChatMessage:
-                def __init__(self, role, content, metadata=None):
-                    self.role = role
-                    self.content = content
-                    self.metadata = metadata
-        else:
-            raise ImportError("Gradio should be installed in order to launch a gradio demo.")

-    for step_log in agent.run(task, stream=True, **kwargs):
+    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
        if isinstance(step_log, dict):
            for message in pull_message(step_log, test_mode=test_mode):
                yield message
@@ -395,7 +395,7 @@ Do not add anything else."""
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.

Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
-This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""

@@ -466,7 +466,7 @@ Here is the up to date list of facts that you know:
```

Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
-This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
Beware that you have {remaining_steps} steps remaining.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import WhisperForConditionalGeneration, WhisperProcessor
-from .tools import PipelineTool
-
-
-class SpeechToTextTool(PipelineTool):
-    default_checkpoint = "distil-whisper/distil-large-v3"
-    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
-    name = "transcriber"
-    pre_processor_class = WhisperProcessor
-    model_class = WhisperForConditionalGeneration
-
-    inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
-    output_type = "string"
-
-    def encode(self, audio):
-        return self.pre_processor(audio, return_tensors="pt")
-
-    def forward(self, inputs):
-        return self.model.generate(inputs["input_features"])
-
-    def decode(self, outputs):
-        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
@@ -1,67 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
-from transformers.utils import is_datasets_available
-from .tools import PipelineTool
-
-
-if is_datasets_available():
-    from datasets import load_dataset
-
-
-class TextToSpeechTool(PipelineTool):
-    default_checkpoint = "microsoft/speecht5_tts"
-    description = (
-        "This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
-    )
-    name = "text_to_speech"
-    pre_processor_class = SpeechT5Processor
-    model_class = SpeechT5ForTextToSpeech
-    post_processor_class = SpeechT5HifiGan
-
-    inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
-    output_type = "audio"
-
-    def setup(self):
-        if self.post_processor is None:
-            self.post_processor = "microsoft/speecht5_hifigan"
-        super().setup()
-
-    def encode(self, text, speaker_embeddings=None):
-        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
-
-        if speaker_embeddings is None:
-            if not is_datasets_available():
-                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
-
-            embeddings_dataset = load_dataset(
-                "Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True
-            )
-            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
-
-        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
-
-    def forward(self, inputs):
-        with torch.no_grad():
-            return self.model.generate_speech(**inputs)
-
-    def decode(self, outputs):
-        with torch.no_grad():
-            return self.post_processor(outputs).cpu().detach()
@@ -80,6 +80,19 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
        return "space"


+def setup_default_tools():
+    default_tools = {}
+    main_module = importlib.import_module("transformers")
+    tools_module = main_module.agents
+
+    for task_name, tool_class_name in TOOL_MAPPING.items():
+        tool_class = getattr(tools_module, tool_class_name)
+        tool_instance = tool_class()
+        default_tools[tool_class.name] = tool_instance
+
+    return default_tools
+
+
# docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name}
@@ -811,11 +824,6 @@ def launch_gradio_demo(tool_class: Tool):


TOOL_MAPPING = {
-    "document_question_answering": "DocumentQuestionAnsweringTool",
-    "image_question_answering": "ImageQuestionAnsweringTool",
-    "speech_to_text": "SpeechToTextTool",
-    "text_to_speech": "TextToSpeechTool",
-    "translation": "TranslationTool",
    "python_interpreter": "PythonInterpreterTool",
    "web_search": "DuckDuckGoSearchTool",
}
@@ -1018,18 +1026,14 @@ class Toolbox:
        self._tools = {tool.name: tool for tool in tools}
        if add_base_tools:
            self.add_base_tools()
-        # self._load_tools_if_needed()

    def add_base_tools(self, add_python_interpreter: bool = False):
-        global _tools_are_initialized
        global HUGGINGFACE_DEFAULT_TOOLS
-        if not _tools_are_initialized:
+        if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
            HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
-            _tools_are_initialized = True
        for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
            if tool.name != "python_interpreter" or add_python_interpreter:
                self.add_tool(tool)
-        # self._load_tools_if_needed()

    @property
    def tools(self) -> Dict[str, Tool]:
@@ -1,4 +1,4 @@
-from agents import load_tool, ReactCodeAgent, HfApiEngine
+from agents import load_tool, CodeAgent, HfApiEngine

# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
@@ -10,7 +10,7 @@ search_tool = DuckDuckGoSearchTool()

llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
# Initialize the agent with both tools
-agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)
+agent = CodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)

# Run it!
result = agent.run(
@@ -0,0 +1,29 @@
+from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent
+
+image_generation_tool = load_tool("m-ric/text-to-image")
+
+llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
+
+agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
+
+import gradio as gr
+
+
+def interact_with_agent(prompt, messages):
+    messages.append(gr.ChatMessage(role="user", content=prompt))
+    yield messages
+    for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False):
+        messages.append(msg)
+        yield messages
+    yield messages
+
+
+with gr.Blocks() as demo:
+    stored_message = gr.State([])
+    chatbot = gr.Chatbot(label="Agent",
+                         type="messages",
+                         avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
+    text_input = gr.Textbox(lines=1, label="Chat Message")
+    text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])
+
+demo.launch()
File diff suppressed because it is too large
@@ -68,8 +68,12 @@ jinja2 = "^3.1.4"
pillow = "^11.0.0"
llama-cpp-python = "^0.3.4"
markdownify = "^0.14.1"
+gradio = "^5.8.0"


[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.5"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"