Remove unused and undocumented parameter (#273)

This commit is contained in:
Lucain 2025-01-20 10:45:55 +01:00 committed by GitHub
parent 0abd91cf72
commit 3178b18aab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 6 deletions

View File

@ -25,7 +25,7 @@ from .agents import ActionStep, AgentStepLog, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True): def pull_messages_from_step(step_log: AgentStepLog):
"""Extract ChatMessage objects from agent steps""" """Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep): if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "") yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
@ -53,14 +53,13 @@ def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
def stream_to_gradio( def stream_to_gradio(
agent, agent,
task: str, task: str,
test_mode: bool = False,
reset_agent_memory: bool = False, reset_agent_memory: bool = False,
additional_args: Optional[dict] = None, additional_args: Optional[dict] = None,
): ):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args): for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
for message in pull_messages_from_step(step_log, test_mode=test_mode): for message in pull_messages_from_step(step_log):
yield message yield message
final_answer = step_log # Last log is the run's final_answer final_answer = step_log # Last log is the run's final_answer

View File

@ -131,7 +131,7 @@ class MonitoringTester(unittest.TestCase):
) )
# Use stream_to_gradio to capture the output # Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True)) outputs = list(stream_to_gradio(agent, task="Test task"))
self.assertEqual(len(outputs), 4) self.assertEqual(len(outputs), 4)
final_message = outputs[-1] final_message = outputs[-1]
@ -151,7 +151,6 @@ class MonitoringTester(unittest.TestCase):
agent, agent,
task="Test task", task="Test task",
additional_args=dict(image=AgentImage(value="path.png")), additional_args=dict(image=AgentImage(value="path.png")),
test_mode=True,
) )
) )
@ -173,7 +172,7 @@ class MonitoringTester(unittest.TestCase):
) )
# Use stream_to_gradio to capture the output # Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True)) outputs = list(stream_to_gradio(agent, task="Test task"))
self.assertEqual(len(outputs), 5) self.assertEqual(len(outputs), 5)
final_message = outputs[-1] final_message = outputs[-1]