Fix additional args in stream_to_gradio (#221)
This commit is contained in:
parent
2a69f1574e
commit
fdf4fe49ba
|
@ -75,12 +75,12 @@ class ToolCall:
|
||||||
id: str
|
id: str
|
||||||
|
|
||||||
|
|
||||||
class AgentStep:
|
class AgentStepLog:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActionStep(AgentStep):
|
class ActionStep(AgentStepLog):
|
||||||
agent_memory: List[Dict[str, str]] | None = None
|
agent_memory: List[Dict[str, str]] | None = None
|
||||||
tool_calls: List[ToolCall] | None = None
|
tool_calls: List[ToolCall] | None = None
|
||||||
start_time: float | None = None
|
start_time: float | None = None
|
||||||
|
@ -94,18 +94,18 @@ class ActionStep(AgentStep):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PlanningStep(AgentStep):
|
class PlanningStep(AgentStepLog):
|
||||||
plan: str
|
plan: str
|
||||||
facts: str
|
facts: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TaskStep(AgentStep):
|
class TaskStep(AgentStepLog):
|
||||||
task: str
|
task: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SystemPromptStep(AgentStep):
|
class SystemPromptStep(AgentStepLog):
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,13 @@ import os
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from .agents import ActionStep, AgentStep, MultiStepAgent
|
from typing import Optional
|
||||||
|
|
||||||
|
from .agents import ActionStep, AgentStepLog, MultiStepAgent
|
||||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||||
|
|
||||||
|
|
||||||
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
|
||||||
"""Extract ChatMessage objects from agent steps"""
|
"""Extract ChatMessage objects from agent steps"""
|
||||||
if isinstance(step_log, ActionStep):
|
if isinstance(step_log, ActionStep):
|
||||||
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
|
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
|
||||||
|
@ -53,11 +55,13 @@ def stream_to_gradio(
|
||||||
task: str,
|
task: str,
|
||||||
test_mode: bool = False,
|
test_mode: bool = False,
|
||||||
reset_agent_memory: bool = False,
|
reset_agent_memory: bool = False,
|
||||||
**kwargs,
|
additional_args: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||||
|
|
||||||
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
|
for step_log in agent.run(
|
||||||
|
task, stream=True, reset=reset_agent_memory, additional_args=additional_args
|
||||||
|
):
|
||||||
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
|
@ -172,7 +176,7 @@ class GradioUI:
|
||||||
type="messages",
|
type="messages",
|
||||||
avatar_images=(
|
avatar_images=(
|
||||||
None,
|
None,
|
||||||
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
# If an upload folder is provided, enable the upload feature
|
# If an upload folder is provided, enable the upload feature
|
||||||
|
|
Loading…
Reference in New Issue