Fix additional args in stream_to_gradio (#221)
This commit is contained in:
parent
2a69f1574e
commit
fdf4fe49ba
|
@ -75,12 +75,12 @@ class ToolCall:
|
|||
id: str
|
||||
|
||||
|
||||
class AgentStep:
|
||||
class AgentStepLog:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionStep(AgentStep):
|
||||
class ActionStep(AgentStepLog):
|
||||
agent_memory: List[Dict[str, str]] | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
start_time: float | None = None
|
||||
|
@ -94,18 +94,18 @@ class ActionStep(AgentStep):
|
|||
|
||||
|
||||
@dataclass
|
||||
class PlanningStep(AgentStep):
|
||||
class PlanningStep(AgentStepLog):
|
||||
plan: str
|
||||
facts: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskStep(AgentStep):
|
||||
class TaskStep(AgentStepLog):
|
||||
task: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemPromptStep(AgentStep):
|
||||
class SystemPromptStep(AgentStepLog):
|
||||
system_prompt: str
|
||||
|
||||
|
||||
|
|
|
@ -19,11 +19,13 @@ import os
|
|||
import mimetypes
|
||||
import re
|
||||
|
||||
from .agents import ActionStep, AgentStep, MultiStepAgent
|
||||
from typing import Optional
|
||||
|
||||
from .agents import ActionStep, AgentStepLog, MultiStepAgent
|
||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||
|
||||
|
||||
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||
def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
|
||||
"""Extract ChatMessage objects from agent steps"""
|
||||
if isinstance(step_log, ActionStep):
|
||||
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
|
||||
|
@ -53,11 +55,13 @@ def stream_to_gradio(
|
|||
task: str,
|
||||
test_mode: bool = False,
|
||||
reset_agent_memory: bool = False,
|
||||
**kwargs,
|
||||
additional_args: Optional[dict] = None,
|
||||
):
|
||||
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||
|
||||
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
|
||||
for step_log in agent.run(
|
||||
task, stream=True, reset=reset_agent_memory, additional_args=additional_args
|
||||
):
|
||||
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
||||
yield message
|
||||
|
||||
|
@ -172,7 +176,7 @@ class GradioUI:
|
|||
type="messages",
|
||||
avatar_images=(
|
||||
None,
|
||||
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
|
||||
),
|
||||
)
|
||||
# If an upload folder is provided, enable the upload feature
|
||||
|
|
Loading…
Reference in New Issue