Log list of tool calls in ActionStep (#172)
This commit is contained in:
parent
5a62304c91
commit
289c06df0f
|
@ -82,7 +82,7 @@ class AgentStep:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActionStep(AgentStep):
|
class ActionStep(AgentStep):
|
||||||
agent_memory: List[Dict[str, str]] | None = None
|
agent_memory: List[Dict[str, str]] | None = None
|
||||||
tool_call: ToolCall | None = None
|
tool_calls: List[ToolCall] | None = None
|
||||||
start_time: float | None = None
|
start_time: float | None = None
|
||||||
end_time: float | None = None
|
end_time: float | None = None
|
||||||
step: int | None = None
|
step: int | None = None
|
||||||
|
@ -302,25 +302,26 @@ class MultiStepAgent:
|
||||||
}
|
}
|
||||||
memory.append(thought_message)
|
memory.append(thought_message)
|
||||||
|
|
||||||
if step_log.tool_call is not None:
|
if step_log.tool_calls is not None:
|
||||||
tool_call_message = {
|
tool_call_message = {
|
||||||
"role": MessageRole.ASSISTANT,
|
"role": MessageRole.ASSISTANT,
|
||||||
"content": str(
|
"content": str(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": step_log.tool_call.id,
|
"id": tool_call.id,
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": step_log.tool_call.name,
|
"name": tool_call.name,
|
||||||
"arguments": step_log.tool_call.arguments,
|
"arguments": tool_call.arguments,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
for tool_call in step_log.tool_calls
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
memory.append(tool_call_message)
|
memory.append(tool_call_message)
|
||||||
|
|
||||||
if step_log.tool_call is None and step_log.error is not None:
|
if step_log.tool_calls is None and step_log.error is not None:
|
||||||
message_content = (
|
message_content = (
|
||||||
"Error:\n"
|
"Error:\n"
|
||||||
+ str(step_log.error)
|
+ str(step_log.error)
|
||||||
|
@ -330,7 +331,7 @@ class MultiStepAgent:
|
||||||
"role": MessageRole.ASSISTANT,
|
"role": MessageRole.ASSISTANT,
|
||||||
"content": message_content,
|
"content": message_content,
|
||||||
}
|
}
|
||||||
if step_log.tool_call is not None and (
|
if step_log.tool_calls is not None and (
|
||||||
step_log.error is not None or step_log.observations is not None
|
step_log.error is not None or step_log.observations is not None
|
||||||
):
|
):
|
||||||
if step_log.error is not None:
|
if step_log.error is not None:
|
||||||
|
@ -343,7 +344,7 @@ class MultiStepAgent:
|
||||||
message_content = f"Observation:\n{step_log.observations}"
|
message_content = f"Observation:\n{step_log.observations}"
|
||||||
tool_response_message = {
|
tool_response_message = {
|
||||||
"role": MessageRole.TOOL_RESPONSE,
|
"role": MessageRole.TOOL_RESPONSE,
|
||||||
"content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n"
|
"content": f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n"
|
||||||
+ message_content,
|
+ message_content,
|
||||||
}
|
}
|
||||||
memory.append(tool_response_message)
|
memory.append(tool_response_message)
|
||||||
|
@ -814,9 +815,9 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
f"Error in generating tool call with model:\n{e}"
|
f"Error in generating tool call with model:\n{e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
log_entry.tool_call = ToolCall(
|
log_entry.tool_calls = [
|
||||||
name=tool_name, arguments=tool_arguments, id=tool_call_id
|
ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
|
||||||
)
|
]
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
|
@ -977,11 +978,13 @@ class CodeAgent(MultiStepAgent):
|
||||||
)
|
)
|
||||||
raise AgentParsingError(error_msg)
|
raise AgentParsingError(error_msg)
|
||||||
|
|
||||||
log_entry.tool_call = ToolCall(
|
log_entry.tool_calls = [
|
||||||
|
ToolCall(
|
||||||
name="python_interpreter",
|
name="python_interpreter",
|
||||||
arguments=code_action,
|
arguments=code_action,
|
||||||
id=f"call_{len(self.logs)}",
|
id=f"call_{len(self.logs)}",
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
|
|
|
@ -43,9 +43,7 @@ class Monitor:
|
||||||
def update_metrics(self, step_log):
|
def update_metrics(self, step_log):
|
||||||
step_duration = step_log.duration
|
step_duration = step_log.duration
|
||||||
self.step_durations.append(step_duration)
|
self.step_durations.append(step_duration)
|
||||||
console_outputs = (
|
console_outputs = f"[Step {len(self.step_durations) - 1}: Duration {step_duration:.2f} seconds"
|
||||||
f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
|
|
||||||
)
|
|
||||||
|
|
||||||
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
|
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
|
||||||
self.total_input_token_count += self.tracked_model.last_input_token_count
|
self.total_input_token_count += self.tracked_model.last_input_token_count
|
||||||
|
|
|
@ -179,12 +179,12 @@ class Tool:
|
||||||
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
||||||
)
|
)
|
||||||
for input_name, input_content in self.inputs.items():
|
for input_name, input_content in self.inputs.items():
|
||||||
assert isinstance(
|
assert isinstance(input_content, dict), (
|
||||||
input_content, dict
|
f"Input '{input_name}' should be a dictionary."
|
||||||
), f"Input '{input_name}' should be a dictionary."
|
)
|
||||||
assert (
|
assert "type" in input_content and "description" in input_content, (
|
||||||
"type" in input_content and "description" in input_content
|
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||||
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
)
|
||||||
if input_content["type"] not in AUTHORIZED_TYPES:
|
if input_content["type"] not in AUTHORIZED_TYPES:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
|
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
|
||||||
|
@ -207,13 +207,13 @@ class Tool:
|
||||||
json_schema = _convert_type_hints_to_json_schema(self.forward)
|
json_schema = _convert_type_hints_to_json_schema(self.forward)
|
||||||
for key, value in self.inputs.items():
|
for key, value in self.inputs.items():
|
||||||
if "nullable" in value:
|
if "nullable" in value:
|
||||||
assert (
|
assert key in json_schema and "nullable" in json_schema[key], (
|
||||||
key in json_schema and "nullable" in json_schema[key]
|
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||||
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
)
|
||||||
if key in json_schema and "nullable" in json_schema[key]:
|
if key in json_schema and "nullable" in json_schema[key]:
|
||||||
assert (
|
assert "nullable" in value, (
|
||||||
"nullable" in value
|
f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
|
||||||
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
|
)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
||||||
|
@ -272,7 +272,7 @@ class Tool:
|
||||||
class {class_name}(Tool):
|
class {class_name}(Tool):
|
||||||
name = "{self.name}"
|
name = "{self.name}"
|
||||||
description = "{self.description}"
|
description = "{self.description}"
|
||||||
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
|
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
|
||||||
output_type = "{self.output_type}"
|
output_type = "{self.output_type}"
|
||||||
""").strip()
|
""").strip()
|
||||||
import re
|
import re
|
||||||
|
@ -439,7 +439,9 @@ class Tool:
|
||||||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
|
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
|
||||||
others will be passed along to its init.
|
others will be passed along to its init.
|
||||||
"""
|
"""
|
||||||
assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
|
assert trust_remote_code, (
|
||||||
|
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
|
||||||
|
)
|
||||||
|
|
||||||
hub_kwargs_names = [
|
hub_kwargs_names = [
|
||||||
"cache_dir",
|
"cache_dir",
|
||||||
|
@ -620,13 +622,10 @@ class Tool:
|
||||||
arg.save(temp_file.name)
|
arg.save(temp_file.name)
|
||||||
arg = temp_file.name
|
arg = temp_file.name
|
||||||
if (
|
if (
|
||||||
isinstance(arg, str)
|
(isinstance(arg, str) and os.path.isfile(arg))
|
||||||
and os.path.isfile(arg)
|
or (isinstance(arg, Path) and arg.exists() and arg.is_file())
|
||||||
) or (
|
or is_http_url_like(arg)
|
||||||
isinstance(arg, Path)
|
):
|
||||||
and arg.exists()
|
|
||||||
and arg.is_file()
|
|
||||||
) or is_http_url_like(arg):
|
|
||||||
arg = handle_file(arg)
|
arg = handle_file(arg)
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue