Fix additional args in stream_to_gradio (#221)

Author: Aymeric Roucher (committed by GitHub)
Date: 2025-01-16 15:47:23 +01:00
Parent: 2a69f1574e
Commit: fdf4fe49ba
2 changed files with 14 additions and 10 deletions


@@ -75,12 +75,12 @@ class ToolCall:
     id: str


-class AgentStep:
+class AgentStepLog:
     pass


 @dataclass
-class ActionStep(AgentStep):
+class ActionStep(AgentStepLog):
     agent_memory: List[Dict[str, str]] | None = None
     tool_calls: List[ToolCall] | None = None
     start_time: float | None = None
@@ -94,18 +94,18 @@ class ActionStep(AgentStep):
 @dataclass
-class PlanningStep(AgentStep):
+class PlanningStep(AgentStepLog):
     plan: str
     facts: str


 @dataclass
-class TaskStep(AgentStep):
+class TaskStep(AgentStepLog):
     task: str


 @dataclass
-class SystemPromptStep(AgentStep):
+class SystemPromptStep(AgentStepLog):
     system_prompt: str
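The hunks above rename the step-log base class from AgentStep to AgentStepLog and point every subclass at the new name. As a hedged illustration only (the top-level import path is an assumption, not part of this diff), downstream code that inspects step logs would now type against the renamed base class:

from smolagents.agents import ActionStep, AgentStepLog

def summarize_step(step_log: AgentStepLog) -> str:
    # ActionStep still derives from the (renamed) base class, so isinstance
    # checks against the concrete subclasses keep working unchanged.
    if isinstance(step_log, ActionStep):
        return f"action step, tool calls: {step_log.tool_calls}"
    return type(step_log).__name__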


@@ -19,11 +19,13 @@ import os
 import mimetypes
 import re
-from .agents import ActionStep, AgentStep, MultiStepAgent
+from typing import Optional
+from .agents import ActionStep, AgentStepLog, MultiStepAgent
 from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types


-def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
+def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
     """Extract ChatMessage objects from agent steps"""
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
@@ -53,11 +55,13 @@ def stream_to_gradio(
     task: str,
     test_mode: bool = False,
     reset_agent_memory: bool = False,
-    **kwargs,
+    additional_args: Optional[dict] = None,
 ):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
-    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
+    for step_log in agent.run(
+        task, stream=True, reset=reset_agent_memory, additional_args=additional_args
+    ):
         for message in pull_messages_from_step(step_log, test_mode=test_mode):
             yield message
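This hunk is the substance of the fix: stream_to_gradio no longer forwards arbitrary **kwargs, but takes an explicit additional_args dict and passes it to agent.run() under that name. A rough usage sketch under the new signature (the import path and the agent/dataframe names are assumptions for illustration, not taken from the diff):

from smolagents.gradio_ui import stream_to_gradio

def stream_with_extras(agent, dataframe):
    # Extra runtime objects now travel in an explicit dict instead of **kwargs
    # and are handed to agent.run(..., additional_args=...) unchanged.
    for message in stream_to_gradio(
        agent,
        task="Describe the main trends in the attached dataframe.",
        reset_agent_memory=True,
        additional_args={"source_df": dataframe},
    ):
        print(message.content)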
@@ -172,7 +176,7 @@ class GradioUI:
             type="messages",
             avatar_images=(
                 None,
-                "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
             ),
         )
         # If an upload folder is provided, enable the upload feature