Improve type handling and logging observations for ToolCallingAgent

This commit is contained in:
Aymeric 2024-12-25 23:55:57 +01:00
parent c4f38850b2
commit 5e82faf595
2 changed files with 12 additions and 6 deletions

View File

@ -390,12 +390,12 @@ class MultiStepAgent:
try:
if isinstance(arguments, str):
observation = available_tools[tool_name](arguments)
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = available_tools[tool_name](**arguments)
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
@ -779,9 +779,14 @@ class ToolCallingAgent(MultiStepAgent):
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
log_entry.action_output = answer
return answer
final_answer = self.state[answer]
console.print(f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.")
else:
final_answer = answer
console.print(Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"))
log_entry.action_output = final_answer
return final_answer
else:
if tool_arguments is None:
tool_arguments = {}
@ -798,6 +803,7 @@ class ToolCallingAgent(MultiStepAgent):
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
console.print(f"Observations: {updated_information}")
log_entry.observations = updated_information
return None

View File

@ -249,7 +249,7 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, np.ndarray: AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, torch.Tensor: AgentAudio}
if is_torch_available():
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio