Log list of tool calls in ActionStep (#172)

This commit is contained in:
Aymeric Roucher 2025-01-13 12:08:27 +01:00 committed by GitHub
parent 5a62304c91
commit 289c06df0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 41 deletions

View File

@ -82,7 +82,7 @@ class AgentStep:
@dataclass @dataclass
class ActionStep(AgentStep): class ActionStep(AgentStep):
agent_memory: List[Dict[str, str]] | None = None agent_memory: List[Dict[str, str]] | None = None
tool_call: ToolCall | None = None tool_calls: List[ToolCall] | None = None
start_time: float | None = None start_time: float | None = None
end_time: float | None = None end_time: float | None = None
step: int | None = None step: int | None = None
@ -302,25 +302,26 @@ class MultiStepAgent:
} }
memory.append(thought_message) memory.append(thought_message)
if step_log.tool_call is not None: if step_log.tool_calls is not None:
tool_call_message = { tool_call_message = {
"role": MessageRole.ASSISTANT, "role": MessageRole.ASSISTANT,
"content": str( "content": str(
[ [
{ {
"id": step_log.tool_call.id, "id": tool_call.id,
"type": "function", "type": "function",
"function": { "function": {
"name": step_log.tool_call.name, "name": tool_call.name,
"arguments": step_log.tool_call.arguments, "arguments": tool_call.arguments,
}, },
} }
for tool_call in step_log.tool_calls
] ]
), ),
} }
memory.append(tool_call_message) memory.append(tool_call_message)
if step_log.tool_call is None and step_log.error is not None: if step_log.tool_calls is None and step_log.error is not None:
message_content = ( message_content = (
"Error:\n" "Error:\n"
+ str(step_log.error) + str(step_log.error)
@ -330,7 +331,7 @@ class MultiStepAgent:
"role": MessageRole.ASSISTANT, "role": MessageRole.ASSISTANT,
"content": message_content, "content": message_content,
} }
if step_log.tool_call is not None and ( if step_log.tool_calls is not None and (
step_log.error is not None or step_log.observations is not None step_log.error is not None or step_log.observations is not None
): ):
if step_log.error is not None: if step_log.error is not None:
@ -343,7 +344,7 @@ class MultiStepAgent:
message_content = f"Observation:\n{step_log.observations}" message_content = f"Observation:\n{step_log.observations}"
tool_response_message = { tool_response_message = {
"role": MessageRole.TOOL_RESPONSE, "role": MessageRole.TOOL_RESPONSE,
"content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n" "content": f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n"
+ message_content, + message_content,
} }
memory.append(tool_response_message) memory.append(tool_response_message)
@ -814,9 +815,9 @@ class ToolCallingAgent(MultiStepAgent):
f"Error in generating tool call with model:\n{e}" f"Error in generating tool call with model:\n{e}"
) )
log_entry.tool_call = ToolCall( log_entry.tool_calls = [
name=tool_name, arguments=tool_arguments, id=tool_call_id ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
) ]
# Execute # Execute
self.logger.log( self.logger.log(
@ -977,11 +978,13 @@ class CodeAgent(MultiStepAgent):
) )
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall( log_entry.tool_calls = [
ToolCall(
name="python_interpreter", name="python_interpreter",
arguments=code_action, arguments=code_action,
id=f"call_{len(self.logs)}", id=f"call_{len(self.logs)}",
) )
]
# Execute # Execute
self.logger.log( self.logger.log(

View File

@ -43,9 +43,7 @@ class Monitor:
def update_metrics(self, step_log): def update_metrics(self, step_log):
step_duration = step_log.duration step_duration = step_log.duration
self.step_durations.append(step_duration) self.step_durations.append(step_duration)
console_outputs = ( console_outputs = f"[Step {len(self.step_durations) - 1}: Duration {step_duration:.2f} seconds"
f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
)
if getattr(self.tracked_model, "last_input_token_count", None) is not None: if getattr(self.tracked_model, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_model.last_input_token_count self.total_input_token_count += self.tracked_model.last_input_token_count

View File

@ -179,12 +179,12 @@ class Tool:
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
) )
for input_name, input_content in self.inputs.items(): for input_name, input_content in self.inputs.items():
assert isinstance( assert isinstance(input_content, dict), (
input_content, dict f"Input '{input_name}' should be a dictionary."
), f"Input '{input_name}' should be a dictionary." )
assert ( assert "type" in input_content and "description" in input_content, (
"type" in input_content and "description" in input_content f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." )
if input_content["type"] not in AUTHORIZED_TYPES: if input_content["type"] not in AUTHORIZED_TYPES:
raise Exception( raise Exception(
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}." f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
@ -207,13 +207,13 @@ class Tool:
json_schema = _convert_type_hints_to_json_schema(self.forward) json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items(): for key, value in self.inputs.items():
if "nullable" in value: if "nullable" in value:
assert ( assert key in json_schema and "nullable" in json_schema[key], (
key in json_schema and "nullable" in json_schema[key] f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." )
if key in json_schema and "nullable" in json_schema[key]: if key in json_schema and "nullable" in json_schema[key]:
assert ( assert "nullable" in value, (
"nullable" in value f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." )
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.") return NotImplementedError("Write this method in your subclass of `Tool`.")
@ -272,7 +272,7 @@ class Tool:
class {class_name}(Tool): class {class_name}(Tool):
name = "{self.name}" name = "{self.name}"
description = "{self.description}" description = "{self.description}"
inputs = {json.dumps(self.inputs, separators=(',', ':'))} inputs = {json.dumps(self.inputs, separators=(",", ":"))}
output_type = "{self.output_type}" output_type = "{self.output_type}"
""").strip() """).strip()
import re import re
@ -439,7 +439,9 @@ class Tool:
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
others will be passed along to its init. others will be passed along to its init.
""" """
assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." assert trust_remote_code, (
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
)
hub_kwargs_names = [ hub_kwargs_names = [
"cache_dir", "cache_dir",
@ -620,13 +622,10 @@ class Tool:
arg.save(temp_file.name) arg.save(temp_file.name)
arg = temp_file.name arg = temp_file.name
if ( if (
isinstance(arg, str) (isinstance(arg, str) and os.path.isfile(arg))
and os.path.isfile(arg) or (isinstance(arg, Path) and arg.exists() and arg.is_file())
) or ( or is_http_url_like(arg)
isinstance(arg, Path) ):
and arg.exists()
and arg.is_file()
) or is_http_url_like(arg):
arg = handle_file(arg) arg = handle_file(arg)
return arg return arg