Vastly simplify Model class (#146)

* Vastly simplify Model class by making only one __call__ method 
This commit is contained in:
Aymeric Roucher 2025-01-10 12:30:59 +01:00 committed by GitHub
parent 36ed279c85
commit 5c33130fa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 364 additions and 294 deletions

View File

@ -113,8 +113,7 @@ The Python interpreter also doesn't allow imports by default outside of a safe l
You can authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`CodeAgent`]:
```py
from smolagents import CodeAgent
model = HfApiModel()
agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['requests', 'bs4'])
agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
```
@ -164,12 +163,12 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
- **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes an audio to text.
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
You can manually use a tool by calling it with its arguments.
```python
from smolagents import load_tool
from smolagents import DuckDuckGoSearchTool
search_tool = load_tool("web_search")
search_tool = DuckDuckGoSearchTool()
print(search_tool("Who's the current president of Russia?"))
```

View File

@ -776,11 +776,15 @@ class ToolCallingAgent(MultiStepAgent):
log_entry.agent_memory = agent_memory.copy()
try:
tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
model_message = self.model(
self.input_messages,
available_tools=list(self.tools.values()),
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
tool_calls = model_message.tool_calls[0]
tool_arguments = tool_calls.function.arguments
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
except Exception as e:
raise AgentGenerationError(
f"Error in generating tool call with model:\n{e}"
@ -913,7 +917,7 @@ class CodeAgent(MultiStepAgent):
self.input_messages,
stop_sequences=["<end_code>", "Observation:"],
**additional_args,
)
).content
log_entry.llm_output = llm_output
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}")

View File

@ -20,10 +20,16 @@ import os
import random
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional
import torch
from huggingface_hub import InferenceClient
from huggingface_hub import (
InferenceClient,
ChatCompletionOutputMessage,
ChatCompletionOutputToolCall,
ChatCompletionOutputFunctionDefinition,
)
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@ -33,7 +39,6 @@ from transformers import (
import openai
from .tools import Tool
from .utils import parse_json_tool_call
logger = logging.getLogger(__name__)
@ -234,63 +239,46 @@ class HfApiModel(Model):
model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
token: Optional[str] = None,
timeout: Optional[int] = 120,
temperature: float = 0.5,
):
super().__init__()
self.model_id = model_id
if token is None:
token = os.getenv("HF_TOKEN")
self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
self.temperature = temperature
def generate(
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> str:
"""Generates a text completion for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
# Send messages to the Hugging Face Inference API
if grammar is not None:
output = self.client.chat_completion(
messages,
stop=stop_sequences,
response_format=grammar,
max_tokens=max_tokens,
)
else:
output = self.client.chat.completions.create(
messages, stop=stop_sequences, max_tokens=max_tokens
)
response = output.choices[0].message.content
self.last_input_token_count = output.usage.prompt_tokens
self.last_output_token_count = output.usage.completion_tokens
return response
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences,
):
"""Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
if tools_to_call_from:
response = self.client.chat.completions.create(
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="auto",
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
)
else:
response = self.client.chat.completions.create(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
)
tool_call = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return tool_call.function.name, tool_call.function.arguments, tool_call.id
return response.choices[0].message
class TransformersModel(Model):
@ -354,18 +342,27 @@ class TransformersModel(Model):
return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])
def generate(
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
# Get LLM output
if tools_to_call_from is not None:
prompt_tensor = self.tokenizer.apply_chat_template(
messages,
tools=[get_json_schema(tool) for tool in tools_to_call_from],
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
)
else:
prompt_tensor = self.tokenizer.apply_chat_template(
messages,
return_tensors="pt",
@ -382,56 +379,31 @@ class TransformersModel(Model):
),
)
generated_tokens = out[0, count_prompt_tokens:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.last_input_token_count = count_prompt_tokens
self.last_output_token_count = len(generated_tokens)
if stop_sequences is not None:
response = remove_stop_sequences(response, stop_sequences)
return response
output = remove_stop_sequences(output, stop_sequences)
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
max_tokens: int = 500,
) -> Tuple[str, Union[str, None], str]:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
prompt = self.tokenizer.apply_chat_template(
messages,
tools=[get_json_schema(tool) for tool in available_tools],
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
)
prompt = prompt.to(self.model.device)
count_prompt_tokens = prompt["input_ids"].shape[1]
out = self.model.generate(
**prompt,
max_new_tokens=max_tokens,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
if tools_to_call_from is None:
return ChatCompletionOutputMessage(role="assistant", content=output)
else:
tool_name, tool_arguments = json.load(output)
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="".join(random.choices("0123456789", k=5)),
type="function",
function=ChatCompletionOutputFunctionDefinition(
name=tool_name, arguments=tool_arguments
),
)
generated_tokens = out[0, count_prompt_tokens:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.last_input_token_count = count_prompt_tokens
self.last_output_token_count = len(generated_tokens)
if stop_sequences is not None:
response = remove_stop_sequences(response, stop_sequences)
tool_name, tool_input = parse_json_tool_call(response)
call_id = "".join(random.choices("0123456789", k=5))
return tool_name, tool_input, call_id
],
)
class LiteLLMModel(Model):
@ -460,38 +432,16 @@ class LiteLLMModel(Model):
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
if tools_to_call_from:
response = litellm.completion(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
**self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return response.choices[0].message.content
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
max_tokens: int = 1500,
):
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
response = litellm.completion(
model=self.model_id,
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="required",
stop=stop_sequences,
max_tokens=max_tokens,
@ -499,11 +449,19 @@ class LiteLLMModel(Model):
api_key=self.api_key,
**self.kwargs,
)
tool_calls = response.choices[0].message.tool_calls[0]
else:
response = litellm.completion(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
**self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
arguments = json.loads(tool_calls.function.arguments)
return tool_calls.function.name, arguments, tool_calls.id
return response.choices[0].message
class OpenAIServerModel(Model):
@ -539,64 +497,40 @@ class OpenAIServerModel(Model):
self.temperature = temperature
self.kwargs = kwargs
def generate(
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> str:
"""Generates a text completion for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
if tools_to_call_from:
response = self.client.chat.completions.create(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
**self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return response.choices[0].message.content
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences: Optional[List[str]] = None,
max_tokens: int = 500,
) -> Tuple[str, Union[str, Dict], str]:
"""Generates a tool call for the given message list"""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
response = self.client.chat.completions.create(
model=self.model_id,
messages=messages,
tools=[get_json_schema(tool) for tool in available_tools],
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="auto",
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
**self.kwargs,
)
tool_calls = response.choices[0].message.tool_calls[0]
else:
response = self.client.chat.completions.create(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
temperature=self.temperature,
**self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
try:
arguments = json.loads(tool_calls.function.arguments)
except json.JSONDecodeError:
arguments = tool_calls.function.arguments
return tool_calls.function.name, arguments, tool_calls.id
return response.choices[0].message
__all__ = [

View File

@ -30,6 +30,11 @@ from smolagents.agents import (
from smolagents.default_tools import PythonInterpreterTool
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from huggingface_hub import (
ChatCompletionOutputMessage,
ChatCompletionOutputToolCall,
ChatCompletionOutputFunctionDefinition,
)
def get_new_path(suffix="") -> str:
@ -38,54 +43,106 @@ def get_new_path(suffix="") -> str:
class FakeToolCallModel:
def get_tool_call(
self, messages, available_tools, stop_sequences=None, grammar=None
def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
return "python_interpreter", {"code": "2*3.6452"}, "call_0"
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="python_interpreter", arguments={"code": "2*3.6452"}
),
)
],
)
else:
return "final_answer", {"answer": "7.2904"}, "call_1"
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="call_1",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="final_answer", arguments={"answer": "7.2904"}
),
)
],
)
class FakeToolCallModelImage:
def get_tool_call(
self, messages, available_tools, stop_sequences=None, grammar=None
def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
return (
"fake_image_generation_tool",
{"prompt": "An image of a cat"},
"call_0",
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="fake_image_generation_tool",
arguments={"prompt": "An image of a cat"},
),
)
],
)
else:
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="call_1",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="final_answer", arguments="image.png"
),
)
],
)
else: # We're at step 2
return "final_answer", "image.png", "call_1"
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = 2**3.6452
```<end_code>
"""
""",
)
else: # We're at step 2
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
Code:
```py
final_answer(7.2904)
```<end_code>
"""
""",
)
def fake_code_model_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
@ -94,21 +151,27 @@ b = a * 2
print = 2
print("Ok, calculation done!")
```<end_code>
"""
""",
)
else: # We're at step 2
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
Code:
```py
final_answer("got an error")
```<end_code>
"""
""",
)
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
@ -117,32 +180,41 @@ b = a * 2
print("Failing due to unexpected indent")
print("Ok, calculation done!")
```<end_code>
"""
""",
)
else: # We're at step 2
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
Code:
```py
final_answer("got an error")
```<end_code>
"""
""",
)
def fake_code_model_import(messages, stop_sequences=None) -> str:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I can answer the question
Code:
```py
import numpy as np
final_answer("got an error")
```<end_code>
"""
""",
)
def fake_code_functiondef(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: Let's define the function. special_marker
Code:
```py
@ -151,9 +223,12 @@ import numpy as np
def moving_average(x, w):
return np.convolve(x, np.ones(w), 'valid') / w
```<end_code>
"""
""",
)
else: # We're at step 2
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
Code:
```py
@ -161,29 +236,36 @@ x, w = [0, 1, 2, 3, 4, 5], 2
res = moving_average(x, w)
final_answer(res)
```<end_code>
"""
""",
)
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
final_answer(result)
```
"""
""",
)
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
print(result)
```
"""
""",
)
class AgentTests(unittest.TestCase):
@ -360,52 +442,92 @@ class AgentTests(unittest.TestCase):
def test_multiagents(self):
class FakeModelMultiagentsManagerAgent:
def __call__(self, messages, stop_sequences=None, grammar=None):
def __call__(
self,
messages,
stop_sequences=None,
grammar=None,
tools_to_call_from=None,
):
if tools_to_call_from is not None:
if len(messages) < 3:
return """
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="search_agent",
arguments="Who is the current US president?",
),
)
],
)
else:
assert "Report on the current US president" in str(messages)
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="final_answer", arguments="Final report."
),
)
],
)
else:
if len(messages) < 3:
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: Let's call our search agent.
Code:
```py
result = search_agent("Who is the current US president?")
```<end_code>
"""
""",
)
else:
assert "Report on the current US president" in str(messages)
return """
return ChatCompletionOutputMessage(
role="assistant",
content="""
Thought: Let's return the report.
Code:
```py
final_answer("Final report.")
```<end_code>
"""
def get_tool_call(
self, messages, available_tools, stop_sequences=None, grammar=None
):
if len(messages) < 3:
return (
"search_agent",
"Who is the current US president?",
"call_0",
)
else:
assert "Report on the current US president" in str(messages)
return (
"final_answer",
"Final report.",
"call_0",
""",
)
manager_model = FakeModelMultiagentsManagerAgent()
class FakeModelMultiagentsManagedAgent:
def get_tool_call(
self, messages, available_tools, stop_sequences=None, grammar=None
def __call__(
self,
messages,
tools_to_call_from=None,
stop_sequences=None,
grammar=None,
):
return (
"final_answer",
{"report": "Report on the current US president"},
"call_0",
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="final_answer",
arguments="Report on the current US president",
),
)
],
)
managed_model = FakeModelMultiagentsManagedAgent()
@ -443,13 +565,16 @@ final_answer("Final report.")
def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
return """Code:
return ChatCompletionOutputMessage(
role="assistant",
content="""Code:
```py
def nested_answer():
final_answer("Correct!")
nested_answer()
```<end_code>"""
```<end_code>""",
)
agent = CodeAgent(tools=[], model=fake_code_model_final_answer)

View File

@ -92,7 +92,6 @@ class TestDocs:
raise ValueError(f"Docs directory not found at {cls.docs_dir}")
load_dotenv()
cls.hf_token = os.getenv("HF_TOKEN")
cls.md_files = list(cls.docs_dir.rglob("*.md"))
if not cls.md_files:
@ -115,6 +114,7 @@ class TestDocs:
"from_langchain", # Langchain is not a dependency
"while llm_should_continue(memory):", # This is pseudo code
"ollama_chat/llama3.2", # Exclude ollama building in guided tour
"model = TransformersModel(model_id=model_id)", # Exclude testing with transformers model
]
code_blocks = [
block
@ -131,10 +131,15 @@ class TestDocs:
ast.parse(block)
# Create and execute test script
print("\n\nCollected code block:==========\n".join(code_blocks))
try:
code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
"{your_username}", "m-ric"
(
block.replace(
"<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
)
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
.replace("{your_username}", "m-ric")
)
for block in code_blocks
]

View File

@ -22,42 +22,57 @@ from smolagents import (
ToolCallingAgent,
stream_to_gradio,
)
from huggingface_hub import (
ChatCompletionOutputMessage,
ChatCompletionOutputToolCall,
ChatCompletionOutputFunctionDefinition,
)
class MonitoringTester(unittest.TestCase):
def test_code_agent_metrics(self):
class FakeLLMModel:
class FakeLLMModel:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
def __call__(self, prompt, **kwargs):
return """
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
if tools_to_call_from is not None:
return ChatCompletionOutputMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
id="fake_id",
type="function",
function=ChatCompletionOutputFunctionDefinition(
name="final_answer", arguments={"answer": "image"}
),
)
],
)
else:
return ChatCompletionOutputMessage(
role="assistant",
content="""
Code:
```py
final_answer('This is the final answer.')
```"""
```""",
)
class MonitoringTester(unittest.TestCase):
def test_code_agent_metrics(self):
agent = CodeAgent(
tools=[],
model=FakeLLMModel(),
max_steps=1,
)
agent.run("Fake task")
self.assertEqual(agent.monitor.total_input_token_count, 10)
self.assertEqual(agent.monitor.total_output_token_count, 20)
def test_json_agent_metrics(self):
class FakeLLMModel:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
def get_tool_call(self, prompt, **kwargs):
return "final_answer", {"answer": "image"}, "fake_id"
agent = ToolCallingAgent(
tools=[],
model=FakeLLMModel(),
@ -70,17 +85,19 @@ final_answer('This is the final answer.')
self.assertEqual(agent.monitor.total_output_token_count, 20)
def test_code_agent_metrics_max_steps(self):
class FakeLLMModel:
class FakeLLMModelMalformedAnswer:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
def __call__(self, prompt, **kwargs):
return "Malformed answer"
return ChatCompletionOutputMessage(
role="assistant", content="Malformed answer"
)
agent = CodeAgent(
tools=[],
model=FakeLLMModel(),
model=FakeLLMModelMalformedAnswer(),
max_steps=1,
)
@ -90,7 +107,7 @@ final_answer('This is the final answer.')
self.assertEqual(agent.monitor.total_output_token_count, 40)
def test_code_agent_metrics_generation_error(self):
class FakeLLMModel:
class FakeLLMModelGenerationException:
def __init__(self):
self.last_input_token_count = 10
self.last_output_token_count = 20
@ -102,7 +119,7 @@ final_answer('This is the final answer.')
agent = CodeAgent(
tools=[],
model=FakeLLMModel(),
model=FakeLLMModelGenerationException(),
max_steps=1,
)
agent.run("Fake task")
@ -113,16 +130,9 @@ final_answer('This is the final answer.')
self.assertEqual(agent.monitor.total_output_token_count, 0)
def test_streaming_agent_text_output(self):
def dummy_model(prompt, **kwargs):
return """
Code:
```py
final_answer('This is the final answer.')
```"""
agent = CodeAgent(
tools=[],
model=dummy_model,
model=FakeLLMModel(),
max_steps=1,
)
@ -135,16 +145,9 @@ final_answer('This is the final answer.')
self.assertIn("This is the final answer.", final_message.content)
def test_streaming_agent_image_output(self):
class FakeLLM:
def __init__(self):
pass
def get_tool_call(self, messages, **kwargs):
return "final_answer", {"answer": "image"}, "fake_id"
agent = ToolCallingAgent(
tools=[],
model=FakeLLM(),
model=FakeLLMModel(),
max_steps=1,
)