Fix json schema for final answer

This commit is contained in:
Aymeric 2024-12-26 01:19:56 +01:00
parent 8bd5144da1
commit 93569bd7c1
3 changed files with 18 additions and 12 deletions

View File

@ -539,6 +539,7 @@ class MultiStepAgent:
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
console.print(Text(f"Final answer: {final_answer}"))
final_step_log.action_output = final_answer
final_step_log.end_time = time.time()
final_step_log.duration = step_log.end_time - step_start_time
@ -588,6 +589,7 @@ class MultiStepAgent:
final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
console.print(Text(f"Final answer: {final_answer}"))
final_step_log.action_output = final_answer
final_step_log.duration = 0
for callback in self.step_callbacks:

View File

@ -105,15 +105,18 @@ class PythonInterpreterTool(Tool):
def forward(self, code: str) -> str:
state = {}
output = str(
self.python_evaluator(
code,
state=state,
static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports,
try:
output = str(
self.python_evaluator(
code,
state=state,
static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports,
)
)
)
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
except Exception as e:
return f"Error: {str(e)}"
class FinalAnswerTool(Tool):

View File

@ -66,6 +66,10 @@ tool_role_conversions = {
def get_json_schema(tool: Tool) -> Dict:
properties = deepcopy(tool.inputs)
for value in properties.values():
if value["type"] == "any":
value["type"] = "string"
return {
"type": "function",
"function": {
@ -73,10 +77,7 @@ def get_json_schema(tool: Tool) -> Dict:
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
k: {k2: v2.replace("any", "object") for k2, v2 in v.items()}
for k, v in tool.inputs.items()
},
"properties": properties,
"required": list(tool.inputs.keys()),
},
},