From fdf4fe49ba063de880e3378d576d662b9cf161b8 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Thu, 16 Jan 2025 15:47:23 +0100 Subject: [PATCH] Fix additional args in stream_to_gradio (#221) --- src/smolagents/agents.py | 10 +++++----- src/smolagents/gradio_ui.py | 14 +++++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index cfa8a6f..f17daa9 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -75,12 +75,12 @@ class ToolCall: id: str -class AgentStep: +class AgentStepLog: pass @dataclass -class ActionStep(AgentStep): +class ActionStep(AgentStepLog): agent_memory: List[Dict[str, str]] | None = None tool_calls: List[ToolCall] | None = None start_time: float | None = None @@ -94,18 +94,18 @@ class ActionStep(AgentStep): @dataclass -class PlanningStep(AgentStep): +class PlanningStep(AgentStepLog): plan: str facts: str @dataclass -class TaskStep(AgentStep): +class TaskStep(AgentStepLog): task: str @dataclass -class SystemPromptStep(AgentStep): +class SystemPromptStep(AgentStepLog): system_prompt: str diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index 2eaeda6..26c2998 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -19,11 +19,13 @@ import os import mimetypes import re -from .agents import ActionStep, AgentStep, MultiStepAgent +from typing import Optional + +from .agents import ActionStep, AgentStepLog, MultiStepAgent from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types -def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): +def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True): """Extract ChatMessage objects from agent steps""" if isinstance(step_log, ActionStep): yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "") @@ -53,11 +55,13 @@ def stream_to_gradio( task: str, test_mode: bool = False, reset_agent_memory: bool = False, - **kwargs, + additional_args: Optional[dict] = None, ): """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" - for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs): + for step_log in agent.run( + task, stream=True, reset=reset_agent_memory, additional_args=additional_args + ): for message in pull_messages_from_step(step_log, test_mode=test_mode): yield message @@ -172,7 +176,7 @@ class GradioUI: type="messages", avatar_images=( None, - "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png", ), ) # If an upload folder is provided, enable the upload feature