Support multi-agent

This commit is contained in:
Aymeric 2024-12-09 22:58:30 +01:00
parent 43a3f46835
commit 23ab4a9df3
5 changed files with 155 additions and 24 deletions

View File

@@ -92,6 +92,7 @@ class ActionStep:
     final_answer: Any = None
     error: AgentError | None = None
     step_duration: float | None = None
+    llm_output: str | None = None

 @dataclass
 class PlanningStep:
@@ -440,9 +441,6 @@ class ReactAgent(BaseAgent):
         else:
             self.logs.append(TaskStep(task=task))
-        with console.status(
-            "Agent is running...", spinner="aesthetic"
-        ):
         if oneshot:
             step_start_time = time.time()
             step_log = ActionStep(start_time=step_start_time)
@@ -468,6 +466,9 @@ class ReactAgent(BaseAgent):
             step_start_time = time.time()
             step_log = ActionStep(iteration=iteration, start_time=step_start_time)
             try:
+                if self.planning_interval is not None and iteration % self.planning_interval == 0:
+                    self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
+                console.rule("[bold]New step")
                 self.step(step_log)
                 if step_log.final_answer is not None:
                     final_answer = step_log.final_answer
@@ -484,7 +485,6 @@ class ReactAgent(BaseAgent):
         if final_answer is None and iteration == self.max_iterations:
             error_message = "Reached max iterations."
-            console.print(f"[bold red]{error_message}")
             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
             self.logs.append(final_step_log)
             final_answer = self.provide_final_answer(task)
@@ -509,6 +509,7 @@ class ReactAgent(BaseAgent):
             try:
                 if self.planning_interval is not None and iteration % self.planning_interval == 0:
                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
+                console.rule("[bold]New step")
                 self.step(step_log)
                 if step_log.final_answer is not None:
                     final_answer = step_log.final_answer
@@ -527,7 +528,6 @@ class ReactAgent(BaseAgent):
             error_message = "Reached max iterations."
             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
             self.logs.append(final_step_log)
-            console.print(f"[bold red]{error_message}")
             final_answer = self.provide_final_answer(task)
             final_step_log.final_answer = final_answer
             final_step_log.step_duration = 0
@@ -677,7 +677,6 @@ class JsonAgent(ReactAgent):
         agent_memory = self.write_inner_memory_from_logs()
         self.prompt = agent_memory
-        console.rule("New step")

         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()
@@ -692,11 +691,13 @@ class JsonAgent(ReactAgent):
             llm_output = self.llm_engine(
                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
             )
+            log_entry.llm_output = llm_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating llm output: {e}.")

-        console.rule("Output message of the LLM")
+        if self.verbose:
+            console.rule("[italic]Output message of the LLM:")
             console.print(llm_output)
-        log_entry.llm_output = llm_output

         # Parse
         rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
@@ -796,7 +797,6 @@ class CodeAgent(ReactAgent):
         agent_memory = self.write_inner_memory_from_logs()
         self.prompt = agent_memory.copy()
-        console.rule("New step")

         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()
@@ -811,13 +811,13 @@ class CodeAgent(ReactAgent):
             llm_output = self.llm_engine(
                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
             )
+            log_entry.llm_output = llm_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating llm output: {e}.")

         if self.verbose:
             console.rule("[italic]Output message of the LLM:")
             console.print(Syntax(llm_output, lexer='markdown', background_color='default'))
-        log_entry.llm_output = llm_output

         # Parse
         try:
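Two behavioral notes on the changes above: log_entry.llm_output is now assigned inside the try block, so the raw LLM message is recorded before parsing can fail, and planning runs whenever iteration % planning_interval == 0. A minimal sketch of that cadence, using a hypothetical interval value (the names below are illustrative, not from the diff):

# Planning cadence sketch: with planning_interval = 3, a plan is (re)written
# on iterations 0, 3, 6, ... and a regular step runs on every iteration.
planning_interval = 3  # hypothetical value
for iteration in range(7):
    if planning_interval is not None and iteration % planning_interval == 0:
        print(f"iteration {iteration}: planning_step(is_first_step={iteration == 0})")
    print(f"iteration {iteration}: step()")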

View File

@@ -185,3 +185,13 @@ class FinalAnswerTool(Tool):
     def forward(self, answer):
         return answer
+
+class UserInputTool(Tool):
+    name = "user_input"
+    description = "Asks for user's input on a specific question"
+    inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
+    output_type = "string"
+
+    def forward(self, question):
+        user_input = input(f"{question} => ")
+        return user_input
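For reference, a minimal sketch of the new UserInputTool in isolation, assuming the Tool base class simply dispatches to forward() with the declared input:

# Hypothetical interactive session with the new tool.
tool = UserInputTool()
reply = tool.forward("Is the computed answer correct?")
# Prints "Is the computed answer correct? => " and returns the typed reply as a string.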

examples/orchestrator.py Normal file (+73)
View File

@@ -0,0 +1,73 @@
+import re
+
+import requests
+from markdownify import markdownify as md
+from requests.exceptions import RequestException
+
+from agents import (
+    tool,
+    CodeAgent,
+    JsonAgent,
+    HfApiEngine,
+    ManagedAgent,
+)
+from agents.default_tools import UserInputTool
+from agents.search import DuckDuckGoSearchTool
+from agents.utils import console
+
+model = "Qwen/Qwen2.5-72B-Instruct"
+
+
+@tool
+def visit_webpage(url: str) -> str:
+    """Visits a webpage at the given URL and returns its content as a markdown string.
+
+    Args:
+        url: The URL of the webpage to visit.
+
+    Returns:
+        The content of the webpage converted to Markdown, or an error message if the request fails.
+    """
+    try:
+        # Send a GET request to the URL
+        response = requests.get(url)
+        response.raise_for_status()  # Raise an exception for bad status codes
+
+        # Convert the HTML content to Markdown
+        markdown_content = md(response.text).strip()
+
+        # Remove multiple line breaks
+        markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
+
+        return markdown_content
+
+    except RequestException as e:
+        return f"Error fetching the webpage: {str(e)}"
+    except Exception as e:
+        return f"An unexpected error occurred: {str(e)}"
+
+
+llm_engine = HfApiEngine(model)
+
+web_agent = JsonAgent(
+    tools=[DuckDuckGoSearchTool(), visit_webpage],
+    llm_engine=llm_engine,
+    max_iterations=10,
+)
+
+managed_web_agent = ManagedAgent(
+    agent=web_agent,
+    name="search",
+    description="Runs web searches for you. Give it your query as an argument.",
+)
+
+manager_agent = CodeAgent(
+    tools=[UserInputTool()],
+    llm_engine=llm_engine,
+    managed_agents=[managed_web_agent],
+    additional_authorized_imports=["time", "datetime"],
+)
+
+with console.status(
+    "Agent is running...", spinner="aesthetic"
+):
+    manager_agent.run("""How many years ago was Stripe founded?
+You should ask for user input on whether the answer is correct before returning your final answer.""")
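In this setup the manager CodeAgent can delegate to the managed "search" agent and call the user_input tool from the code it generates. The exact calling convention is not shown in this diff, so the snippet below is only an assumed illustration:

# Hypothetical code the manager agent might generate (calling convention assumed;
# the names "search" and "user_input" come from the definitions above).
founding = search("When was Stripe founded?")  # delegates the query to web_agent
check = user_input(question=f"Based on {founding}, Stripe was founded in 2010. Correct?")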

poetry.lock generated (49 changes)
View File

@@ -1,5 +1,26 @@
 # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.

+[[package]]
+name = "beautifulsoup4"
+version = "4.12.3"
+description = "Screen-scraping library"
+optional = false
+python-versions = ">=3.6.0"
+files = [
+    {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"},
+    {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"},
+]
+
+[package.dependencies]
+soupsieve = ">1.2"
+
+[package.extras]
+cchardet = ["cchardet"]
+chardet = ["chardet"]
+charset-normalizer = ["charset-normalizer"]
+html5lib = ["html5lib"]
+lxml = ["lxml"]
+
 [[package]]
 name = "certifi"
 version = "2024.8.30"
@@ -338,6 +359,21 @@ profiling = ["gprof2dot"]
 rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
 testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]

+[[package]]
+name = "markdownify"
+version = "0.14.1"
+description = "Convert HTML to markdown."
+optional = false
+python-versions = "*"
+files = [
+    {file = "markdownify-0.14.1-py3-none-any.whl", hash = "sha256:4c46a6c0c12c6005ddcd49b45a5a890398b002ef51380cd319db62df5e09bc2a"},
+    {file = "markdownify-0.14.1.tar.gz", hash = "sha256:a62a7a216947ed0b8dafb95b99b2ef4a0edd1e18d5653c656f68f03db2bfb2f1"},
+]
+
+[package.dependencies]
+beautifulsoup4 = ">=4.9,<5"
+six = ">=1.15,<2"
+
 [[package]]
 name = "markupsafe"
 version = "3.0.2"
@@ -1096,6 +1132,17 @@ files = [
     {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
 ]

+[[package]]
+name = "soupsieve"
+version = "2.6"
+description = "A modern CSS selector implementation for Beautiful Soup."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"},
+    {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"},
+]
+
 [[package]]
 name = "tokenizers"
 version = "0.21.0"
@@ -1301,4 +1348,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "6c3841968936d66bf70e11c6c8e0a16fec6c2f4d88d79cd8ac5a412225e7cf56"
+content-hash = "3a0896faf882952a0d780efcc862017989612fcb421a6ee01e4eec0ba6c0f638"

View File

@@ -67,6 +67,7 @@ pandas = "^2.2.3"
 jinja2 = "^3.1.4"
 pillow = "^11.0.0"
 llama-cpp-python = "^0.3.4"
+markdownify = "^0.14.1"

 [build-system]