Support multi-agent
This commit is contained in:
parent
43a3f46835
commit
23ab4a9df3
|
@ -92,6 +92,7 @@ class ActionStep:
|
||||||
final_answer: Any = None
|
final_answer: Any = None
|
||||||
error: AgentError | None = None
|
error: AgentError | None = None
|
||||||
step_duration: float | None = None
|
step_duration: float | None = None
|
||||||
|
llm_output: str | None = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PlanningStep:
|
class PlanningStep:
|
||||||
|
@ -440,9 +441,6 @@ class ReactAgent(BaseAgent):
|
||||||
else:
|
else:
|
||||||
self.logs.append(TaskStep(task=task))
|
self.logs.append(TaskStep(task=task))
|
||||||
|
|
||||||
with console.status(
|
|
||||||
"Agent is running...", spinner="aesthetic"
|
|
||||||
):
|
|
||||||
if oneshot:
|
if oneshot:
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(start_time=step_start_time)
|
step_log = ActionStep(start_time=step_start_time)
|
||||||
|
@ -468,6 +466,9 @@ class ReactAgent(BaseAgent):
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
|
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
|
||||||
try:
|
try:
|
||||||
|
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
||||||
|
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
||||||
|
console.rule("[bold]New step")
|
||||||
self.step(step_log)
|
self.step(step_log)
|
||||||
if step_log.final_answer is not None:
|
if step_log.final_answer is not None:
|
||||||
final_answer = step_log.final_answer
|
final_answer = step_log.final_answer
|
||||||
|
@ -484,7 +485,6 @@ class ReactAgent(BaseAgent):
|
||||||
|
|
||||||
if final_answer is None and iteration == self.max_iterations:
|
if final_answer is None and iteration == self.max_iterations:
|
||||||
error_message = "Reached max iterations."
|
error_message = "Reached max iterations."
|
||||||
console.print(f"[bold red]{error_message}")
|
|
||||||
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
final_answer = self.provide_final_answer(task)
|
final_answer = self.provide_final_answer(task)
|
||||||
|
@ -509,6 +509,7 @@ class ReactAgent(BaseAgent):
|
||||||
try:
|
try:
|
||||||
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
||||||
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
||||||
|
console.rule("[bold]New step")
|
||||||
self.step(step_log)
|
self.step(step_log)
|
||||||
if step_log.final_answer is not None:
|
if step_log.final_answer is not None:
|
||||||
final_answer = step_log.final_answer
|
final_answer = step_log.final_answer
|
||||||
|
@ -527,7 +528,6 @@ class ReactAgent(BaseAgent):
|
||||||
error_message = "Reached max iterations."
|
error_message = "Reached max iterations."
|
||||||
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
console.print(f"[bold red]{error_message}")
|
|
||||||
final_answer = self.provide_final_answer(task)
|
final_answer = self.provide_final_answer(task)
|
||||||
final_step_log.final_answer = final_answer
|
final_step_log.final_answer = final_answer
|
||||||
final_step_log.step_duration = 0
|
final_step_log.step_duration = 0
|
||||||
|
@ -677,7 +677,6 @@ class JsonAgent(ReactAgent):
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
|
|
||||||
self.prompt = agent_memory
|
self.prompt = agent_memory
|
||||||
console.rule("New step")
|
|
||||||
|
|
||||||
# Add new step in logs
|
# Add new step in logs
|
||||||
log_entry.agent_memory = agent_memory.copy()
|
log_entry.agent_memory = agent_memory.copy()
|
||||||
|
@ -692,11 +691,13 @@ class JsonAgent(ReactAgent):
|
||||||
llm_output = self.llm_engine(
|
llm_output = self.llm_engine(
|
||||||
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
||||||
)
|
)
|
||||||
|
log_entry.llm_output = llm_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
console.rule("Output message of the LLM")
|
|
||||||
|
if self.verbose:
|
||||||
|
console.rule("[italic]Output message of the LLM:")
|
||||||
console.print(llm_output)
|
console.print(llm_output)
|
||||||
log_entry.llm_output = llm_output
|
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
|
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
|
||||||
|
@ -796,7 +797,6 @@ class CodeAgent(ReactAgent):
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
|
|
||||||
self.prompt = agent_memory.copy()
|
self.prompt = agent_memory.copy()
|
||||||
console.rule("New step")
|
|
||||||
|
|
||||||
# Add new step in logs
|
# Add new step in logs
|
||||||
log_entry.agent_memory = agent_memory.copy()
|
log_entry.agent_memory = agent_memory.copy()
|
||||||
|
@ -811,13 +811,13 @@ class CodeAgent(ReactAgent):
|
||||||
llm_output = self.llm_engine(
|
llm_output = self.llm_engine(
|
||||||
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
||||||
)
|
)
|
||||||
|
log_entry.llm_output = llm_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
console.rule("[italic]Output message of the LLM:")
|
console.rule("[italic]Output message of the LLM:")
|
||||||
console.print(Syntax(llm_output, lexer='markdown', background_color='default'))
|
console.print(Syntax(llm_output, lexer='markdown', background_color='default'))
|
||||||
log_entry.llm_output = llm_output
|
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -185,3 +185,13 @@ class FinalAnswerTool(Tool):
|
||||||
|
|
||||||
def forward(self, answer):
|
def forward(self, answer):
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
class UserInputTool(Tool):
|
||||||
|
name = "user_input"
|
||||||
|
description = "Asks for user's input on a specific question"
|
||||||
|
inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
|
||||||
|
output_type = "string"
|
||||||
|
|
||||||
|
def forward(self, question):
|
||||||
|
user_input = input(f"{question} => ")
|
||||||
|
return user_input
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
import re
|
||||||
|
import requests
|
||||||
|
from markdownify import markdownify as md
|
||||||
|
from requests.exceptions import RequestException
|
||||||
|
|
||||||
|
from agents import (
|
||||||
|
tool,
|
||||||
|
CodeAgent,
|
||||||
|
JsonAgent,
|
||||||
|
HfApiEngine,
|
||||||
|
ManagedAgent,
|
||||||
|
)
|
||||||
|
from agents.default_tools import UserInputTool
|
||||||
|
from agents.search import DuckDuckGoSearchTool
|
||||||
|
from agents.utils import console
|
||||||
|
|
||||||
|
model = "Qwen/Qwen2.5-72B-Instruct"
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def visit_webpage(url: str) -> str:
|
||||||
|
"""Visits a webpage at the given URL and returns its content as a markdown string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL of the webpage to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The content of the webpage converted to Markdown, or an error message if the request fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Send a GET request to the URL
|
||||||
|
response = requests.get(url)
|
||||||
|
response.raise_for_status() # Raise an exception for bad status codes
|
||||||
|
|
||||||
|
# Convert the HTML content to Markdown
|
||||||
|
markdown_content = md(response.text).strip()
|
||||||
|
|
||||||
|
# Remove multiple line breaks
|
||||||
|
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
|
||||||
|
|
||||||
|
return markdown_content
|
||||||
|
|
||||||
|
except RequestException as e:
|
||||||
|
return f"Error fetching the webpage: {str(e)}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"An unexpected error occurred: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
llm_engine = HfApiEngine(model)
|
||||||
|
|
||||||
|
web_agent = JsonAgent(
|
||||||
|
tools=[DuckDuckGoSearchTool(), visit_webpage],
|
||||||
|
llm_engine=llm_engine,
|
||||||
|
max_iterations=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
managed_web_agent = ManagedAgent(
|
||||||
|
agent=web_agent,
|
||||||
|
name="search",
|
||||||
|
description="Runs web searches for you. Give it your query as an argument.",
|
||||||
|
)
|
||||||
|
|
||||||
|
manager_agent = CodeAgent(
|
||||||
|
tools=[UserInputTool()],
|
||||||
|
llm_engine=llm_engine,
|
||||||
|
managed_agents=[managed_web_agent],
|
||||||
|
additional_authorized_imports=["time", "datetime"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with console.status(
|
||||||
|
"Agent is running...", spinner="aesthetic"
|
||||||
|
):
|
||||||
|
manager_agent.run("""How many years ago was Stripe founded?
|
||||||
|
You should ask for user input on wether the answer is correct before returning your final answer.""")
|
|
@ -1,5 +1,26 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "beautifulsoup4"
|
||||||
|
version = "4.12.3"
|
||||||
|
description = "Screen-scraping library"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6.0"
|
||||||
|
files = [
|
||||||
|
{file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"},
|
||||||
|
{file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
soupsieve = ">1.2"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
cchardet = ["cchardet"]
|
||||||
|
chardet = ["chardet"]
|
||||||
|
charset-normalizer = ["charset-normalizer"]
|
||||||
|
html5lib = ["html5lib"]
|
||||||
|
lxml = ["lxml"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "certifi"
|
name = "certifi"
|
||||||
version = "2024.8.30"
|
version = "2024.8.30"
|
||||||
|
@ -338,6 +359,21 @@ profiling = ["gprof2dot"]
|
||||||
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
|
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
|
||||||
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "markdownify"
|
||||||
|
version = "0.14.1"
|
||||||
|
description = "Convert HTML to markdown."
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "markdownify-0.14.1-py3-none-any.whl", hash = "sha256:4c46a6c0c12c6005ddcd49b45a5a890398b002ef51380cd319db62df5e09bc2a"},
|
||||||
|
{file = "markdownify-0.14.1.tar.gz", hash = "sha256:a62a7a216947ed0b8dafb95b99b2ef4a0edd1e18d5653c656f68f03db2bfb2f1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
beautifulsoup4 = ">=4.9,<5"
|
||||||
|
six = ">=1.15,<2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "markupsafe"
|
name = "markupsafe"
|
||||||
version = "3.0.2"
|
version = "3.0.2"
|
||||||
|
@ -1096,6 +1132,17 @@ files = [
|
||||||
{file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
|
{file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "soupsieve"
|
||||||
|
version = "2.6"
|
||||||
|
description = "A modern CSS selector implementation for Beautiful Soup."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"},
|
||||||
|
{file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokenizers"
|
name = "tokenizers"
|
||||||
version = "0.21.0"
|
version = "0.21.0"
|
||||||
|
@ -1301,4 +1348,4 @@ zstd = ["zstandard (>=0.18.0)"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "6c3841968936d66bf70e11c6c8e0a16fec6c2f4d88d79cd8ac5a412225e7cf56"
|
content-hash = "3a0896faf882952a0d780efcc862017989612fcb421a6ee01e4eec0ba6c0f638"
|
||||||
|
|
|
@ -67,6 +67,7 @@ pandas = "^2.2.3"
|
||||||
jinja2 = "^3.1.4"
|
jinja2 = "^3.1.4"
|
||||||
pillow = "^11.0.0"
|
pillow = "^11.0.0"
|
||||||
llama-cpp-python = "^0.3.4"
|
llama-cpp-python = "^0.3.4"
|
||||||
|
markdownify = "^0.14.1"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|
Loading…
Reference in New Issue