diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 66bbdef..93c64fb 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -82,7 +82,7 @@ class AgentStep: @dataclass class ActionStep(AgentStep): agent_memory: List[Dict[str, str]] | None = None - tool_call: ToolCall | None = None + tool_calls: List[ToolCall] | None = None start_time: float | None = None end_time: float | None = None step: int | None = None @@ -302,25 +302,26 @@ class MultiStepAgent: } memory.append(thought_message) - if step_log.tool_call is not None: + if step_log.tool_calls is not None: tool_call_message = { "role": MessageRole.ASSISTANT, "content": str( [ { - "id": step_log.tool_call.id, + "id": tool_call.id, "type": "function", "function": { - "name": step_log.tool_call.name, - "arguments": step_log.tool_call.arguments, + "name": tool_call.name, + "arguments": tool_call.arguments, }, } + for tool_call in step_log.tool_calls ] ), } memory.append(tool_call_message) - if step_log.tool_call is None and step_log.error is not None: + if step_log.tool_calls is None and step_log.error is not None: message_content = ( "Error:\n" + str(step_log.error) @@ -330,7 +331,7 @@ class MultiStepAgent: "role": MessageRole.ASSISTANT, "content": message_content, } - if step_log.tool_call is not None and ( + if step_log.tool_calls is not None and ( step_log.error is not None or step_log.observations is not None ): if step_log.error is not None: @@ -343,7 +344,7 @@ class MultiStepAgent: message_content = f"Observation:\n{step_log.observations}" tool_response_message = { "role": MessageRole.TOOL_RESPONSE, - "content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n" + "content": f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n" + message_content, } memory.append(tool_response_message) @@ -814,9 +815,9 @@ class ToolCallingAgent(MultiStepAgent): f"Error in generating tool call with model:\n{e}" ) - log_entry.tool_call = ToolCall( - name=tool_name, arguments=tool_arguments, id=tool_call_id - ) + log_entry.tool_calls = [ + ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id) + ] # Execute self.logger.log( @@ -977,11 +978,13 @@ class CodeAgent(MultiStepAgent): ) raise AgentParsingError(error_msg) - log_entry.tool_call = ToolCall( - name="python_interpreter", - arguments=code_action, - id=f"call_{len(self.logs)}", - ) + log_entry.tool_calls = [ + ToolCall( + name="python_interpreter", + arguments=code_action, + id=f"call_{len(self.logs)}", + ) + ] # Execute self.logger.log( diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py index b6ba78f..13de796 100644 --- a/src/smolagents/monitoring.py +++ b/src/smolagents/monitoring.py @@ -43,9 +43,7 @@ class Monitor: def update_metrics(self, step_log): step_duration = step_log.duration self.step_durations.append(step_duration) - console_outputs = ( - f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds" - ) + console_outputs = f"[Step {len(self.step_durations) - 1}: Duration {step_duration:.2f} seconds" if getattr(self.tracked_model, "last_input_token_count", None) is not None: self.total_input_token_count += self.tracked_model.last_input_token_count diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index fd71662..d5ec6b0 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -179,12 +179,12 @@ class Tool: f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." ) for input_name, input_content in self.inputs.items(): - assert isinstance( - input_content, dict - ), f"Input '{input_name}' should be a dictionary." - assert ( - "type" in input_content and "description" in input_content - ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." + assert isinstance(input_content, dict), ( + f"Input '{input_name}' should be a dictionary." + ) + assert "type" in input_content and "description" in input_content, ( + f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." + ) if input_content["type"] not in AUTHORIZED_TYPES: raise Exception( f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}." @@ -207,13 +207,13 @@ class Tool: json_schema = _convert_type_hints_to_json_schema(self.forward) for key, value in self.inputs.items(): if "nullable" in value: - assert ( - key in json_schema and "nullable" in json_schema[key] - ), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." + assert key in json_schema and "nullable" in json_schema[key], ( + f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." + ) if key in json_schema and "nullable" in json_schema[key]: - assert ( - "nullable" in value - ), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." + assert "nullable" in value, ( + f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." + ) def forward(self, *args, **kwargs): return NotImplementedError("Write this method in your subclass of `Tool`.") @@ -272,7 +272,7 @@ class Tool: class {class_name}(Tool): name = "{self.name}" description = "{self.description}" - inputs = {json.dumps(self.inputs, separators=(',', ':'))} + inputs = {json.dumps(self.inputs, separators=(",", ":"))} output_type = "{self.output_type}" """).strip() import re @@ -439,7 +439,9 @@ class Tool: `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others will be passed along to its init. """ - assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." + assert trust_remote_code, ( + "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." + ) hub_kwargs_names = [ "cache_dir", @@ -620,13 +622,10 @@ class Tool: arg.save(temp_file.name) arg = temp_file.name if ( - isinstance(arg, str) - and os.path.isfile(arg) - ) or ( - isinstance(arg, Path) - and arg.exists() - and arg.is_file() - ) or is_http_url_like(arg): + (isinstance(arg, str) and os.path.isfile(arg)) + or (isinstance(arg, Path) and arg.exists() and arg.is_file()) + or is_http_url_like(arg) + ): arg = handle_file(arg) return arg diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index 3983b40..e3ea23c 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -99,7 +99,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: raise ValueError( f"The JSON blob you used is invalid due to the following error: {e}.\n" f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n" - f"'{json_blob[place-4:place+5]}'." + f"'{json_blob[place - 4 : place + 5]}'." ) except Exception as e: raise ValueError(f"Error in parsing the JSON blob: {e}")