Fix additional args in stream_to_gradio (#221)

This commit is contained in:
Aymeric Roucher 2025-01-16 15:47:23 +01:00 committed by GitHub
parent 2a69f1574e
commit fdf4fe49ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 10 deletions

View File

@ -75,12 +75,12 @@ class ToolCall:
id: str
class AgentStep:
class AgentStepLog:
pass
@dataclass
class ActionStep(AgentStep):
class ActionStep(AgentStepLog):
agent_memory: List[Dict[str, str]] | None = None
tool_calls: List[ToolCall] | None = None
start_time: float | None = None
@ -94,18 +94,18 @@ class ActionStep(AgentStep):
@dataclass
class PlanningStep(AgentStep):
class PlanningStep(AgentStepLog):
plan: str
facts: str
@dataclass
class TaskStep(AgentStep):
class TaskStep(AgentStepLog):
task: str
@dataclass
class SystemPromptStep(AgentStep):
class SystemPromptStep(AgentStepLog):
system_prompt: str

View File

@ -19,11 +19,13 @@ import os
import mimetypes
import re
from .agents import ActionStep, AgentStep, MultiStepAgent
from typing import Optional
from .agents import ActionStep, AgentStepLog, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
@ -53,11 +55,13 @@ def stream_to_gradio(
task: str,
test_mode: bool = False,
reset_agent_memory: bool = False,
**kwargs,
additional_args: Optional[dict] = None,
):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
for step_log in agent.run(
task, stream=True, reset=reset_agent_memory, additional_args=additional_args
):
for message in pull_messages_from_step(step_log, test_mode=test_mode):
yield message
@ -172,7 +176,7 @@ class GradioUI:
type="messages",
avatar_images=(
None,
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
),
)
# If an upload folder is provided, enable the upload feature