Log list of tool calls in ActionStep (#172)

This commit is contained in:
Aymeric Roucher 2025-01-13 12:08:27 +01:00 committed by GitHub
parent 5a62304c91
commit 289c06df0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 41 deletions

View File

@ -82,7 +82,7 @@ class AgentStep:
@dataclass
class ActionStep(AgentStep):
agent_memory: List[Dict[str, str]] | None = None
tool_call: ToolCall | None = None
tool_calls: List[ToolCall] | None = None
start_time: float | None = None
end_time: float | None = None
step: int | None = None
@ -302,25 +302,26 @@ class MultiStepAgent:
}
memory.append(thought_message)
if step_log.tool_call is not None:
if step_log.tool_calls is not None:
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": str(
[
{
"id": step_log.tool_call.id,
"id": tool_call.id,
"type": "function",
"function": {
"name": step_log.tool_call.name,
"arguments": step_log.tool_call.arguments,
"name": tool_call.name,
"arguments": tool_call.arguments,
},
}
for tool_call in step_log.tool_calls
]
),
}
memory.append(tool_call_message)
if step_log.tool_call is None and step_log.error is not None:
if step_log.tool_calls is None and step_log.error is not None:
message_content = (
"Error:\n"
+ str(step_log.error)
@ -330,7 +331,7 @@ class MultiStepAgent:
"role": MessageRole.ASSISTANT,
"content": message_content,
}
if step_log.tool_call is not None and (
if step_log.tool_calls is not None and (
step_log.error is not None or step_log.observations is not None
):
if step_log.error is not None:
@ -343,7 +344,7 @@ class MultiStepAgent:
message_content = f"Observation:\n{step_log.observations}"
tool_response_message = {
"role": MessageRole.TOOL_RESPONSE,
"content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n"
"content": f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n"
+ message_content,
}
memory.append(tool_response_message)
@ -814,9 +815,9 @@ class ToolCallingAgent(MultiStepAgent):
f"Error in generating tool call with model:\n{e}"
)
log_entry.tool_call = ToolCall(
name=tool_name, arguments=tool_arguments, id=tool_call_id
)
log_entry.tool_calls = [
ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
]
# Execute
self.logger.log(
@ -977,11 +978,13 @@ class CodeAgent(MultiStepAgent):
)
raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall(
log_entry.tool_calls = [
ToolCall(
name="python_interpreter",
arguments=code_action,
id=f"call_{len(self.logs)}",
)
]
# Execute
self.logger.log(

View File

@ -43,9 +43,7 @@ class Monitor:
def update_metrics(self, step_log):
step_duration = step_log.duration
self.step_durations.append(step_duration)
console_outputs = (
f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
)
console_outputs = f"[Step {len(self.step_durations) - 1}: Duration {step_duration:.2f} seconds"
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_model.last_input_token_count

View File

@ -179,12 +179,12 @@ class Tool:
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
)
for input_name, input_content in self.inputs.items():
assert isinstance(
input_content, dict
), f"Input '{input_name}' should be a dictionary."
assert (
"type" in input_content and "description" in input_content
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
assert isinstance(input_content, dict), (
f"Input '{input_name}' should be a dictionary."
)
assert "type" in input_content and "description" in input_content, (
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
)
if input_content["type"] not in AUTHORIZED_TYPES:
raise Exception(
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
@ -207,13 +207,13 @@ class Tool:
json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items():
if "nullable" in value:
assert (
key in json_schema and "nullable" in json_schema[key]
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
assert key in json_schema and "nullable" in json_schema[key], (
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
)
if key in json_schema and "nullable" in json_schema[key]:
assert (
"nullable" in value
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
assert "nullable" in value, (
f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
)
def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.")
@ -272,7 +272,7 @@ class Tool:
class {class_name}(Tool):
name = "{self.name}"
description = "{self.description}"
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
output_type = "{self.output_type}"
""").strip()
import re
@ -439,7 +439,9 @@ class Tool:
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
others will be passed along to its init.
"""
assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
assert trust_remote_code, (
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
)
hub_kwargs_names = [
"cache_dir",
@ -620,13 +622,10 @@ class Tool:
arg.save(temp_file.name)
arg = temp_file.name
if (
isinstance(arg, str)
and os.path.isfile(arg)
) or (
isinstance(arg, Path)
and arg.exists()
and arg.is_file()
) or is_http_url_like(arg):
(isinstance(arg, str) and os.path.isfile(arg))
or (isinstance(arg, Path) and arg.exists() and arg.is_file())
or is_http_url_like(arg)
):
arg = handle_file(arg)
return arg

View File

@ -99,7 +99,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
raise ValueError(
f"The JSON blob you used is invalid due to the following error: {e}.\n"
f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
f"'{json_blob[place-4:place+5]}'."
f"'{json_blob[place - 4 : place + 5]}'."
)
except Exception as e:
raise ValueError(f"Error in parsing the JSON blob: {e}")