Add stdout to PythonInterpreterTool

This commit is contained in:
Aymeric 2024-12-26 00:49:45 +01:00
parent 5e82faf595
commit 8bd5144da1
3 changed files with 17 additions and 13 deletions

View File

@ -889,16 +889,16 @@ class CodeAgent(MultiStepAgent):
console.print_exception()
raise AgentGenerationError(f"Error in generating model output:\n{e}")
from rich.live import Live
from rich.markdown import Markdown
import time
# from rich.live import Live
# from rich.markdown import Markdown
# import time
with Live(console=console, vertical_overflow="visible") as live:
message = ""
for i in range(100):
time.sleep(0.02)
message += str(i)
live.update(Markdown(message))
# with Live(console=console, vertical_overflow="visible") as live:
# message = ""
# for i in range(100):
# time.sleep(0.02)
# message += str(i)
# live.update(Markdown(message))
if self.verbose:
console.print(

View File

@ -99,19 +99,21 @@ class PythonInterpreterTool(Tool):
),
}
}
self.base_python_tool = BASE_PYTHON_TOOLS
self.base_python_tools = BASE_PYTHON_TOOLS
self.python_evaluator = evaluate_python_code
super().__init__(*args, **kwargs)
def forward(self, code: str) -> str:
state = {}
output = str(
self.python_evaluator(
code,
static_tools=self.base_python_tool,
state=state,
static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports,
)
)
return output
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
class FinalAnswerTool(Tool):
@ -240,7 +242,7 @@ class GoogleSearchTool(Tool):
class VisitWebpageTool(Tool):
name = "visit_webpage"
description = "Visits a webpage at the given url and returns its content as a markdown string."
description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
inputs = {
"url": {
"type": "string",

View File

@ -258,6 +258,7 @@ class HfApiModel(HfModel):
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences,
):
"""Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
messages = get_clean_message_list(
@ -267,6 +268,7 @@ class HfApiModel(HfModel):
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tool_choice="auto",
stop=stop_sequences,
)
tool_call = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens