From 49c34f625cab89d6b137de34fa2cd9b4be4bebbd Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Tue, 28 Jan 2025 09:24:13 +0100
Subject: [PATCH] Gradio chatbot: step duration, number, token count, support
 nested thoughts (#384)

* Add enhanced Gradio UI with nested agents calls, execution-logs, errors

- Add virtual separation between steps
- Highlight final answer as required in internal discussion
- Show step numbers and token counts
- Include step duration tracking
- Improve message display structure

---------

Co-authored-by: Yuvraj Sharma <48665385+yvrjsharma@users.noreply.github.com>
---
 pyproject.toml              |   2 +-
 src/smolagents/gradio_ui.py | 129 ++++++++++++++++++++++++++++++------
 tests/test_monitoring.py    |   8 +--
 3 files changed, 112 insertions(+), 27 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 22c1252..42c5fc3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,7 +41,7 @@ e2b = [
     "python-dotenv>=1.0.1",
 ]
 gradio = [
-    "gradio>=5.8.0",
+    "gradio>=5.13.0",
 ]
 litellm = [
     "litellm>=1.55.10",
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py
index dc27c31..d7d1211 100644
--- a/src/smolagents/gradio_ui.py
+++ b/src/smolagents/gradio_ui.py
@@ -24,31 +24,102 @@ from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
 from .utils import _is_package_available
 
 
-def pull_messages_from_step(step_log: AgentStepLog):
-    """Extract ChatMessage objects from agent steps"""
+def pull_messages_from_step(
+    step_log: AgentStepLog,
+):
+    """Extract ChatMessage objects from agent steps with proper nesting"""
     import gradio as gr
 
     if isinstance(step_log, ActionStep):
-        yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
-        if step_log.tool_calls is not None:
+        # Output the step number
+        step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else ""
+        yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")
+
+        # First yield the thought/reasoning from the LLM
+        if hasattr(step_log, "llm_output") and step_log.llm_output is not None:
+            # Clean up the LLM output
+            llm_output = step_log.llm_output.strip()
+            # Remove any trailing <end_code> and extra backticks, handling multiple possible formats
+            llm_output = re.sub(r"```\s*<end_code>", "```", llm_output)  # handles ```<end_code>
+            llm_output = re.sub(r"<end_code>\s*```", "```", llm_output)  # handles <end_code>```
+            llm_output = re.sub(r"```\s*\n\s*<end_code>", "```", llm_output)  # handles ```\n<end_code>
+            llm_output = llm_output.strip()
+            yield gr.ChatMessage(role="assistant", content=llm_output)
+
+        # For tool calls, create a parent message
+        if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None:
             first_tool_call = step_log.tool_calls[0]
-            used_code = first_tool_call.name == "code interpreter"
-            content = first_tool_call.arguments
+            used_code = first_tool_call.name == "python_interpreter"
+            parent_id = f"call_{len(step_log.tool_calls)}"
+
+            # Tool call becomes the parent message with timing info
+            # First we will handle arguments based on type
+            args = first_tool_call.arguments
+            if isinstance(args, dict):
+                content = str(args.get("answer", str(args)))
+            else:
+                content = str(args).strip()
+
             if used_code:
-                content = f"```py\n{content}\n```"
-            yield gr.ChatMessage(
+                # Clean up the content by removing any end code tags
+                content = re.sub(r"```.*?\n", "", content)  # Remove existing code blocks
+                content = re.sub(r"\s*<end_code>\s*", "", content)  # Remove end_code tags
+                content = content.strip()
+                if not content.startswith("```python"):
+                    content = f"```python\n{content}\n```"
f"```python\n{content}\n```" + + parent_message_tool = gr.ChatMessage( role="assistant", - metadata={"title": f"🛠️ Used tool {first_tool_call.name}"}, - content=str(content), + content=content, + metadata={ + "title": f"🛠️ Used tool {first_tool_call.name}", + "id": parent_id, + "status": "pending", + }, ) - if step_log.observations is not None: - yield gr.ChatMessage(role="assistant", content=step_log.observations) - if step_log.error is not None: - yield gr.ChatMessage( - role="assistant", - content=str(step_log.error), - metadata={"title": "💥 Error"}, + yield parent_message_tool + + # Nesting execution logs under the tool call if they exist + if hasattr(step_log, "observations") and ( + step_log.observations is not None and step_log.observations.strip() + ): # Only yield execution logs if there's actual content + log_content = step_log.observations.strip() + if log_content: + log_content = re.sub(r"^Execution logs:\s*", "", log_content) + yield gr.ChatMessage( + role="assistant", + content=f"{log_content}", + metadata={"title": "📝 Execution Logs", "parent_id": parent_id, "status": "done"}, + ) + + # Nesting any errors under the tool call + if hasattr(step_log, "error") and step_log.error is not None: + yield gr.ChatMessage( + role="assistant", + content=str(step_log.error), + metadata={"title": "💥 Error", "parent_id": parent_id, "status": "done"}, + ) + + # Update parent message metadata to done status without yielding a new message + parent_message_tool.metadata["status"] = "done" + + # Handle standalone errors but not from tool calls + elif hasattr(step_log, "error") and step_log.error is not None: + yield gr.ChatMessage(role="assistant", content=str(step_log.error), metadata={"title": "💥 Error"}) + + # Calculate duration and token information + step_footnote = f"{step_number}" + if hasattr(step_log, "input_token_count") and hasattr(step_log, "output_token_count"): + token_str = ( + f" | Input-tokens:{step_log.input_token_count:,} | Output-tokens:{step_log.output_token_count:,}" ) + step_footnote += token_str + if hasattr(step_log, "duration"): + step_duration = f" | Duration: {round(float(step_log.duration), 2)}" if step_log.duration else None + step_footnote += step_duration + step_footnote = f"""{step_footnote} """ + yield gr.ChatMessage(role="assistant", content=f"{step_footnote}") + yield gr.ChatMessage(role="assistant", content="-----") def stream_to_gradio( @@ -60,12 +131,25 @@ def stream_to_gradio( """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" if not _is_package_available("gradio"): raise ModuleNotFoundError( - "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[audio]'`" + "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`" ) import gradio as gr + total_input_tokens = 0 + total_output_tokens = 0 + for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args): - for message in pull_messages_from_step(step_log): + # Track tokens if model provides them + if hasattr(agent.model, "last_input_token_count"): + total_input_tokens += agent.model.last_input_token_count + total_output_tokens += agent.model.last_output_token_count + if isinstance(step_log, ActionStep): + step_log.input_token_count = agent.model.last_input_token_count + step_log.output_token_count = agent.model.last_output_token_count + + for message in pull_messages_from_step( + step_log, + ): yield message final_answer = step_log # Last log is the run's 
@@ -87,7 +171,7 @@ def stream_to_gradio(
                 content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
             )
     else:
-        yield gr.ChatMessage(role="assistant", content=str(final_answer))
+        yield gr.ChatMessage(role="assistant", content=f"**Final answer:** {str(final_answer)}")
 
 
 class GradioUI:
@@ -176,7 +260,7 @@ class GradioUI:
     def launch(self, **kwargs):
         import gradio as gr
 
-        with gr.Blocks() as demo:
+        with gr.Blocks(fill_height=True) as demo:
             stored_messages = gr.State([])
             file_uploads_log = gr.State([])
             chatbot = gr.Chatbot(
@@ -187,6 +271,7 @@
                     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                 ),
                 resizeable=True,
+                scale=1,
             )
             # If an upload folder is provided, enable the upload feature
             if self.file_upload_folder is not None:
@@ -204,7 +289,7 @@
                 [stored_messages, text_input],
             ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])
 
-        demo.launch(**kwargs)
+        demo.launch(debug=True, share=True, **kwargs)
 
 
 __all__ = ["stream_to_gradio", "GradioUI"]
diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py
index 9fa30bb..abe1740 100644
--- a/tests/test_monitoring.py
+++ b/tests/test_monitoring.py
@@ -71,7 +71,7 @@ class MonitoringTester(unittest.TestCase):
         self.assertEqual(agent.monitor.total_input_token_count, 10)
         self.assertEqual(agent.monitor.total_output_token_count, 20)
 
-    def test_json_agent_metrics(self):
+    def test_toolcalling_agent_metrics(self):
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLMModel(),
@@ -134,7 +134,7 @@ class MonitoringTester(unittest.TestCase):
         # Use stream_to_gradio to capture the output
         outputs = list(stream_to_gradio(agent, task="Test task"))
 
-        self.assertEqual(len(outputs), 4)
+        self.assertEqual(len(outputs), 7)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("This is the final answer.", final_message.content)
@@ -155,7 +155,7 @@ class MonitoringTester(unittest.TestCase):
             )
         )
 
-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 5)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIsInstance(final_message.content, dict)
@@ -177,7 +177,7 @@ class MonitoringTester(unittest.TestCase):
         # Use stream_to_gradio to capture the output
         outputs = list(stream_to_gradio(agent, task="Test task"))
 
-        self.assertEqual(len(outputs), 5)
+        self.assertEqual(len(outputs), 9)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("Simulated agent error", final_message.content)
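
Usage note: a minimal sketch of how the patched UI might be driven end to end.
The HfApiModel choice and the empty tool list are illustrative assumptions, not
part of this patch; GradioUI, stream_to_gradio, and the smolagents[gradio]
extra are taken from the diff above.

    # Sketch only: assumes `pip install 'smolagents[gradio]'` and smolagents
    # at the state of this commit; model choice and toolset are placeholders.
    from smolagents import CodeAgent, HfApiModel
    from smolagents.gradio_ui import GradioUI, stream_to_gradio

    agent = CodeAgent(tools=[], model=HfApiModel())  # empty toolset, illustrative

    # Programmatic streaming: with this patch each step yields a bold step
    # number, the model's thoughts, a nested "Used tool ..." parent message
    # whose execution logs and errors are children, and a footnote carrying
    # token counts and step duration.
    for message in stream_to_gradio(agent, task="What is 2 + 2?"):
        print(message.content)

    # Or serve the full Blocks app (fill_height=True, as set in this patch):
    GradioUI(agent).launch()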