diff --git a/docs/source/en/guided_tour.md b/docs/source/en/guided_tour.md
index 4e2fe44..b446245 100644
--- a/docs/source/en/guided_tour.md
+++ b/docs/source/en/guided_tour.md
@@ -113,8 +113,7 @@ The Python interpreter also doesn't allow imports by default outside of a safe l
 You can authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`CodeAgent`]:
 
 ```py
-from smolagents import CodeAgent
-
+model = HfApiModel()
 agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['requests', 'bs4'])
 agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
 ```
@@ -164,12 +163,12 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
 - **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
 - **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes an audio to text.
 
-You can manually use a tool by calling the [`load_tool`] function and a task to perform.
+You can manually use a tool by calling it with its arguments.
 
 ```python
-from smolagents import load_tool
+from smolagents import DuckDuckGoSearchTool
 
-search_tool = load_tool("web_search")
+search_tool = DuckDuckGoSearchTool()
 print(search_tool("Who's the current president of Russia?"))
 ```
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index ab320e5..4d7ddf6 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -776,11 +776,15 @@ class ToolCallingAgent(MultiStepAgent):
         log_entry.agent_memory = agent_memory.copy()
 
         try:
-            tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
+            model_message = self.model(
                 self.input_messages,
-                available_tools=list(self.tools.values()),
+                tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
+            tool_calls = model_message.tool_calls[0]
+            tool_arguments = tool_calls.function.arguments
+            tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
+
         except Exception as e:
             raise AgentGenerationError(
                 f"Error in generating tool call with model:\n{e}"
@@ -913,7 +917,7 @@ class CodeAgent(MultiStepAgent):
                 self.input_messages,
                 stop_sequences=["", "Observation:"],
                 **additional_args,
-            )
+            ).content
             log_entry.llm_output = llm_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating model output:\n{e}")
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index a7bea46..403e9fa 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -20,10 +20,16 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional
 
 import torch
-from huggingface_hub import InferenceClient
+from huggingface_hub import (
+    InferenceClient,
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
+
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -33,7 +39,6 @@ from transformers import (
 import openai
 
 from .tools import Tool
-from .utils import parse_json_tool_call
 
 logger = logging.getLogger(__name__)
@@ -234,63 +239,46 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
+        temperature: float = 0.5,
     ):
         super().__init__()
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
+        self.temperature = temperature
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-
-        # Send messages to the Hugging Face Inference API
-        if grammar is not None:
-            output = self.client.chat_completion(
-                messages,
+        if tools_to_call_from:
+            response = self.client.chat.completions.create(
+                messages=messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tool_choice="auto",
                 stop=stop_sequences,
-                response_format=grammar,
                 max_tokens=max_tokens,
+                temperature=self.temperature,
             )
         else:
-            output = self.client.chat.completions.create(
-                messages, stop=stop_sequences, max_tokens=max_tokens
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
             )
-
-        response = output.choices[0].message.content
-        self.last_input_token_count = output.usage.prompt_tokens
-        self.last_output_token_count = output.usage.completion_tokens
-        return response
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences,
-    ):
-        """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        response = self.client.chat.completions.create(
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            tool_choice="auto",
-            stop=stop_sequences,
-        )
-        tool_call = response.choices[0].message.tool_calls[0]
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return tool_call.function.name, tool_call.function.arguments, tool_call.id
+        return response.choices[0].message
 
 
 class TransformersModel(Model):
@@ -354,23 +342,32 @@ class TransformersModel(Model):
 
         return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-        # Get LLM output
-        prompt_tensor = self.tokenizer.apply_chat_template(
-            messages,
-            return_tensors="pt",
-            return_dict=True,
-        )
+        if tools_to_call_from is not None:
+            prompt_tensor = self.tokenizer.apply_chat_template(
+                messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                return_tensors="pt",
+                return_dict=True,
+                add_generation_prompt=True,
+            )
+        else:
+            prompt_tensor = self.tokenizer.apply_chat_template(
+                messages,
+                return_tensors="pt",
+                return_dict=True,
+            )
 
         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
@@ -382,56 +379,31 @@ class TransformersModel(Model):
             ),
         )
         generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
         self.last_output_token_count = len(generated_tokens)
 
         if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-        return response
+            output = remove_stop_sequences(output, stop_sequences)
 
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, None], str]:
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        prompt = self.tokenizer.apply_chat_template(
-            messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            return_tensors="pt",
-            return_dict=True,
-            add_generation_prompt=True,
-        )
-        prompt = prompt.to(self.model.device)
-        count_prompt_tokens = prompt["input_ids"].shape[1]
-
-        out = self.model.generate(
-            **prompt,
-            max_new_tokens=max_tokens,
-            stopping_criteria=(
-                self.make_stopping_criteria(stop_sequences) if stop_sequences else None
-            ),
-        )
-        generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
-        self.last_input_token_count = count_prompt_tokens
-        self.last_output_token_count = len(generated_tokens)
-
-        if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-
-        tool_name, tool_input = parse_json_tool_call(response)
-        call_id = "".join(random.choices("0123456789", k=5))
-
-        return tool_name, tool_input, call_id
+        if tools_to_call_from is None:
+            return ChatCompletionOutputMessage(role="assistant", content=output)
+        else:
+            # Parse the generated JSON tool call; assumes the chat template makes the model emit {"name": ..., "arguments": ...}
+            tool_call_dict = json.loads(output)
+            tool_name, tool_arguments = tool_call_dict["name"], tool_call_dict["arguments"]
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="".join(random.choices("0123456789", k=5)),
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name=tool_name, arguments=tool_arguments
+                        ),
+                    )
+                ],
+            )
 
 
 class LiteLLMModel(Model):
@@ -460,50 +432,36 @@ class LiteLLMModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
    ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-
-        response = litellm.completion(
-            model=self.model_id,
-            messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            api_base=self.api_base,
-            api_key=self.api_key,
-            **self.kwargs,
-        )
+        if tools_to_call_from:
+            response = litellm.completion(
+                model=self.model_id,
+                messages=messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tool_choice="required",
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                **self.kwargs,
+            )
+        else:
+            response = litellm.completion(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 1500,
-    ):
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        response = litellm.completion(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            tool_choice="required",
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            api_base=self.api_base,
-            api_key=self.api_key,
-            **self.kwargs,
-        )
-        tool_calls = response.choices[0].message.tool_calls[0]
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-        arguments = json.loads(tool_calls.function.arguments)
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message
 
 
 class OpenAIServerModel(Model):
@@ -539,64 +497,40 @@ class OpenAIServerModel(Model):
         self.temperature = temperature
         self.kwargs = kwargs
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-
-        response = self.client.chat.completions.create(
-            model=self.model_id,
-            messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            temperature=self.temperature,
-            **self.kwargs,
-        )
-
+        if tools_to_call_from:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tool_choice="auto",
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+                **self.kwargs,
+            )
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, Dict], str]:
-        """Generates a tool call for the given message list"""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        response = self.client.chat.completions.create(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            tool_choice="auto",
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            temperature=self.temperature,
-            **self.kwargs,
-        )
-
-        tool_calls = response.choices[0].message.tool_calls[0]
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-
-        try:
-            arguments = json.loads(tool_calls.function.arguments)
-        except json.JSONDecodeError:
-            arguments = tool_calls.function.arguments
-
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message
 
 
 __all__ = [
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 9327285..f51ce9f 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -30,6 +30,11 @@ from smolagents.agents import (
 from smolagents.default_tools import PythonInterpreterTool
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
 
 
 def get_new_path(suffix="") -> str:
@@ -38,54 +43,106 @@ def get_new_path(suffix="") -> str:
 
 
 class FakeToolCallModel:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return "python_interpreter", {"code": "2*3.6452"}, "call_0"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="python_interpreter", arguments={"code": "2*3.6452"}
+                        ),
+                    )
+                ],
+            )
         else:
-            return "final_answer", {"answer": "7.2904"}, "call_1"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "7.2904"}
+                        ),
+                    )
+                ],
+            )
 
 
 class FakeToolCallModelImage:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return (
-                "fake_image_generation_tool",
-                {"prompt": "An image of a cat"},
-                "call_0",
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="fake_image_generation_tool",
+                            arguments={"prompt": "An image of a cat"},
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments="image.png"
+                        ),
+                    )
+                ],
             )
-
-        else:  # We're at step 2
-            return "final_answer", "image.png", "call_1"
 
 
 def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = 2**3.6452
 ```
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer(7.2904)
 ```
-"""
+""",
+        )
 
 
 def fake_code_model_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -94,21 +151,27 @@ b = a * 2
 print = 2
 print("Ok, calculation done!")
 ```
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```
-"""
+""",
+        )
 
 
 def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -117,32 +180,41 @@ b = a * 2
     print("Failing due to unexpected indent")
 print("Ok, calculation done!")
 ```
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```
-"""
+""",
+        )
 
 
 def fake_code_model_import(messages, stop_sequences=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I can answer the question
 Code:
 ```py
 import numpy as np
 final_answer("got an error")
 ```
-"""
+""",
+    )
 
 
 def fake_code_functiondef(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: Let's define the function. special_marker
 Code:
 ```py
@@ -151,9 +223,12 @@ import numpy as np
 def moving_average(x, w):
     return np.convolve(x, np.ones(w), 'valid') / w
 ```
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
@@ -161,29 +236,36 @@ x, w = [0, 1, 2, 3, 4, 5], 2
 res = moving_average(x, w)
 final_answer(res)
 ```
-"""
+""",
+        )
 
 
 def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 final_answer(result)
 ```
-"""
+""",
+    )
 
 
 def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 print(result)
 ```
-"""
+""",
+    )
 
 
 class AgentTests(unittest.TestCase):
@@ -360,52 +442,92 @@ class AgentTests(unittest.TestCase):
     def test_multiagents(self):
         class FakeModelMultiagentsManagerAgent:
-            def __call__(self, messages, stop_sequences=None, grammar=None):
-                if len(messages) < 3:
-                    return """
+            def __call__(
+                self,
+                messages,
+                stop_sequences=None,
+                grammar=None,
+                tools_to_call_from=None,
+            ):
+                if tools_to_call_from is not None:
+                    if len(messages) < 3:
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="search_agent",
+                                        arguments="Who is the current US president?",
+                                    ),
+                                )
+                            ],
+                        )
+                    else:
+                        assert "Report on the current US president" in str(messages)
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="final_answer", arguments="Final report."
+                                    ),
+                                )
+                            ],
+                        )
+                else:
+                    if len(messages) < 3:
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
 Thought: Let's call our search agent.
 Code:
 ```py
 result = search_agent("Who is the current US president?")
 ```
-"""
-                else:
-                    assert "Report on the current US president" in str(messages)
-                    return """
+""",
+                        )
+                    else:
+                        assert "Report on the current US president" in str(messages)
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
 Thought: Let's return the report.
 Code:
 ```py
 final_answer("Final report.")
 ```
-"""
-
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
-            ):
-                if len(messages) < 3:
-                    return (
-                        "search_agent",
-                        "Who is the current US president?",
-                        "call_0",
-                    )
-                else:
-                    assert "Report on the current US president" in str(messages)
-                    return (
-                        "final_answer",
-                        "Final report.",
-                        "call_0",
-                    )
+""",
+                        )
 
         manager_model = FakeModelMultiagentsManagerAgent()
 
         class FakeModelMultiagentsManagedAgent:
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
+            def __call__(
+                self,
+                messages,
+                tools_to_call_from=None,
+                stop_sequences=None,
+                grammar=None,
             ):
-                return (
-                    "final_answer",
-                    {"report": "Report on the current US president"},
-                    "call_0",
+                return ChatCompletionOutputMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatCompletionOutputToolCall(
+                            id="call_0",
+                            type="function",
+                            function=ChatCompletionOutputFunctionDefinition(
+                                name="final_answer",
+                                arguments="Report on the current US president",
+                            ),
+                        )
+                    ],
                 )
 
         managed_model = FakeModelMultiagentsManagedAgent()
@@ -443,13 +565,16 @@ def test_code_nontrivial_final_answer_works(self):
         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
-            return """Code:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""Code:
 ```py
 def nested_answer():
     final_answer("Correct!")
 nested_answer()
-```"""
+```""",
+            )
 
         agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py
index 3433352..3ad901d 100644
--- a/tests/test_all_docs.py
+++ b/tests/test_all_docs.py
@@ -92,7 +92,6 @@ class TestDocs:
             raise ValueError(f"Docs directory not found at {cls.docs_dir}")
 
         load_dotenv()
-        cls.hf_token = os.getenv("HF_TOKEN")
 
         cls.md_files = list(cls.docs_dir.rglob("*.md"))
         if not cls.md_files:
@@ -115,6 +114,7 @@ class TestDocs:
             "from_langchain",  # Langchain is not a dependency
             "while llm_should_continue(memory):",  # This is pseudo code
             "ollama_chat/llama3.2",  # Exclude ollama building in guided tour
+            "model = TransformersModel(model_id=model_id)",  # Exclude testing with transformers model
         ]
         code_blocks = [
             block
@@ -131,10 +131,15 @@ class TestDocs:
             ast.parse(block)
 
         # Create and execute test script
+        print("\n\nCollected code block:==========\n".join(code_blocks))
        try:
            code_blocks = [
-                block.replace("", self.hf_token).replace(
-                    "{your_username}", "m-ric"
+                (
+                    block.replace(
+                        "", os.getenv("HF_TOKEN")
+                    )
+                    .replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
+                    .replace("{your_username}", "m-ric")
                )
                for block in code_blocks
            ]
diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py
index 5f2401d..11594e7 100644
--- a/tests/test_monitoring.py
+++ b/tests/test_monitoring.py
@@ -22,42 +22,57 @@ from smolagents import (
     ToolCallingAgent,
     stream_to_gradio,
 )
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
+
+
+class FakeLLMModel:
+    def __init__(self):
+        self.last_input_token_count = 10
+        self.last_output_token_count = 20
+
+    def __call__(self, prompt, tools_to_call_from=None, **kwargs):
+        if tools_to_call_from is not None:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="fake_id",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "image"}
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""
+Code:
+```py
+final_answer('This is the final answer.')
+```""",
+            )
 
 
 class MonitoringTester(unittest.TestCase):
     def test_code_agent_metrics(self):
-        class FakeLLMModel:
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
-            def __call__(self, prompt, **kwargs):
-                return """
-Code:
-```py
-final_answer('This is the final answer.')
-```"""
-
         agent = CodeAgent(
             tools=[],
             model=FakeLLMModel(),
             max_steps=1,
         )
-
         agent.run("Fake task")
 
         self.assertEqual(agent.monitor.total_input_token_count, 10)
         self.assertEqual(agent.monitor.total_output_token_count, 20)
 
     def test_json_agent_metrics(self):
-        class FakeLLMModel:
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
-            def get_tool_call(self, prompt, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLMModel(),
@@ -70,17 +85,19 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 20)
 
     def test_code_agent_metrics_max_steps(self):
-        class FakeLLMModel:
+        class FakeLLMModelMalformedAnswer:
             def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20
 
             def __call__(self, prompt, **kwargs):
-                return "Malformed answer"
+                return ChatCompletionOutputMessage(
+                    role="assistant", content="Malformed answer"
+                )
 
         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelMalformedAnswer(),
             max_steps=1,
         )
 
@@ -90,7 +107,7 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 40)
 
     def test_code_agent_metrics_generation_error(self):
-        class FakeLLMModel:
+        class FakeLLMModelGenerationException:
             def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20
@@ -102,7 +119,7 @@ final_answer('This is the final answer.')
 
         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelGenerationException(),
             max_steps=1,
         )
         agent.run("Fake task")
@@ -113,16 +130,9 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 0)
 
     def test_streaming_agent_text_output(self):
-        def dummy_model(prompt, **kwargs):
-            return """
-Code:
-```py
-final_answer('This is the final answer.')
-```"""
-
         agent = CodeAgent(
             tools=[],
-            model=dummy_model,
+            model=FakeLLMModel(),
             max_steps=1,
         )
 
@@ -135,16 +145,9 @@ final_answer('This is the final answer.')
         self.assertIn("This is the final answer.", final_message.content)
 
     def test_streaming_agent_image_output(self):
-        class FakeLLM:
-            def __init__(self):
-                pass
-
-            def get_tool_call(self, messages, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
-            model=FakeLLM(),
+            model=FakeLLMModel(),
             max_steps=1,
         )