Remove unused and undocumented parameter (#273)

This commit is contained in:
Lucain 2025-01-20 10:45:55 +01:00 committed by GitHub
parent 0abd91cf72
commit 3178b18aab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 6 deletions

View File

@ -25,7 +25,7 @@ from .agents import ActionStep, AgentStepLog, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
def pull_messages_from_step(step_log: AgentStepLog):
"""Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
@ -53,14 +53,13 @@ def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
def stream_to_gradio(
agent,
task: str,
test_mode: bool = False,
reset_agent_memory: bool = False,
additional_args: Optional[dict] = None,
):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
for message in pull_messages_from_step(step_log, test_mode=test_mode):
for message in pull_messages_from_step(step_log):
yield message
final_answer = step_log # Last log is the run's final_answer

View File

@ -131,7 +131,7 @@ class MonitoringTester(unittest.TestCase):
)
# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
outputs = list(stream_to_gradio(agent, task="Test task"))
self.assertEqual(len(outputs), 4)
final_message = outputs[-1]
@ -151,7 +151,6 @@ class MonitoringTester(unittest.TestCase):
agent,
task="Test task",
additional_args=dict(image=AgentImage(value="path.png")),
test_mode=True,
)
)
@ -173,7 +172,7 @@ class MonitoringTester(unittest.TestCase):
)
# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
outputs = list(stream_to_gradio(agent, task="Test task"))
self.assertEqual(len(outputs), 5)
final_message = outputs[-1]