Vastly simplify Model class (#146)

* Vastly simplify Model class by making only one __call__ method 
Aymeric Roucher, 2025-01-10 12:30:59 +01:00 (committed by GitHub)
parent 36ed279c85
commit 5c33130fa4
6 changed files with 364 additions and 294 deletions
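
In short, every model now exposes a single entry point. Below is a minimal usage sketch of the unified interface, assuming the post-commit smolagents API shown in the diffs that follow (the messages and tool choice are illustrative):

```py
from smolagents import HfApiModel, DuckDuckGoSearchTool

model = HfApiModel()

# Plain generation: __call__ returns a chat message object; read .content
message = model([{"role": "user", "content": "Say hello."}])
print(message.content)

# Tool calling: same __call__, just pass tools_to_call_from; read .tool_calls
tool_message = model(
    [{"role": "user", "content": "What is 2*3.6452?"}],
    tools_to_call_from=[DuckDuckGoSearchTool()],
)
tool_call = tool_message.tool_calls[0]
print(tool_call.function.name, tool_call.function.arguments)
```

Passing `tools_to_call_from` switches the same call from plain text generation to tool calling; the caller then reads `.content` or `.tool_calls` off the returned message instead of using separate `generate`/`get_tool_call` methods.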


@@ -113,8 +113,7 @@ The Python interpreter also doesn't allow imports by default outside of a safe list
 You can authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`CodeAgent`]:
 ```py
-from smolagents import CodeAgent
+model = HfApiModel()
 agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['requests', 'bs4'])
 agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
 ```
@@ -164,12 +163,12 @@ Transformers comes with a default toolbox for empowering agents, that you can add
 - **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
 - **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes an audio to text.
-You can manually use a tool by calling the [`load_tool`] function and a task to perform.
+You can manually use a tool by calling it with its arguments.
 ```python
-from smolagents import load_tool
-search_tool = load_tool("web_search")
+from smolagents import DuckDuckGoSearchTool
+search_tool = DuckDuckGoSearchTool()
 print(search_tool("Who's the current president of Russia?"))
 ```


@@ -776,11 +776,15 @@ class ToolCallingAgent(MultiStepAgent):
         log_entry.agent_memory = agent_memory.copy()
         try:
-            tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
+            model_message = self.model(
                 self.input_messages,
-                available_tools=list(self.tools.values()),
+                tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
+            tool_calls = model_message.tool_calls[0]
+            tool_arguments = tool_calls.function.arguments
+            tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
         except Exception as e:
             raise AgentGenerationError(
                 f"Error in generating tool call with model:\n{e}"
@@ -913,7 +917,7 @@ class CodeAgent(MultiStepAgent):
                 self.input_messages,
                 stop_sequences=["<end_code>", "Observation:"],
                 **additional_args,
-            )
+            ).content
             log_entry.llm_output = llm_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating model output:\n{e}")


@@ -20,10 +20,16 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional

 import torch
-from huggingface_hub import InferenceClient
+from huggingface_hub import (
+    InferenceClient,
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -33,7 +39,6 @@ from transformers import (
 import openai

 from .tools import Tool
-from .utils import parse_json_tool_call

 logger = logging.getLogger(__name__)
@@ -234,63 +239,46 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
+        temperature: float = 0.5,
     ):
         super().__init__()
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
+        self.temperature = temperature

-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        # Send messages to the Hugging Face Inference API
-        if grammar is not None:
-            output = self.client.chat_completion(
-                messages,
-                stop=stop_sequences,
-                response_format=grammar,
-                max_tokens=max_tokens,
-            )
-        else:
-            output = self.client.chat.completions.create(
-                messages, stop=stop_sequences, max_tokens=max_tokens
-            )
-        response = output.choices[0].message.content
-        self.last_input_token_count = output.usage.prompt_tokens
-        self.last_output_token_count = output.usage.completion_tokens
-        return response
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences,
-    ):
-        """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-        response = self.client.chat.completions.create(
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            tool_choice="auto",
-            stop=stop_sequences,
-        )
-        tool_call = response.choices[0].message.tool_calls[0]
+        if tools_to_call_from:
+            response = self.client.chat.completions.create(
+                messages=messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tool_choice="auto",
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+            )
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return tool_call.function.name, tool_call.function.arguments, tool_call.id
+        return response.choices[0].message

 class TransformersModel(Model):
@@ -354,18 +342,27 @@ class TransformersModel(Model):
         return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])

-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-        # Get LLM output
-        prompt_tensor = self.tokenizer.apply_chat_template(
-            messages,
-            return_tensors="pt",
+        if tools_to_call_from is not None:
+            prompt_tensor = self.tokenizer.apply_chat_template(
+                messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                return_tensors="pt",
+                return_dict=True,
+                add_generation_prompt=True,
+            )
+        else:
+            prompt_tensor = self.tokenizer.apply_chat_template(
+                messages,
+                return_tensors="pt",
@@ -382,56 +379,31 @@ class TransformersModel(Model):
             ),
         )
         generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
         self.last_output_token_count = len(generated_tokens)
         if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-        return response
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, None], str]:
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        prompt = self.tokenizer.apply_chat_template(
-            messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            return_tensors="pt",
-            return_dict=True,
-            add_generation_prompt=True,
-        )
-        prompt = prompt.to(self.model.device)
-        count_prompt_tokens = prompt["input_ids"].shape[1]
-        out = self.model.generate(
-            **prompt,
-            max_new_tokens=max_tokens,
-            stopping_criteria=(
-                self.make_stopping_criteria(stop_sequences) if stop_sequences else None
-            ),
-        )
-        generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-        self.last_input_token_count = count_prompt_tokens
-        self.last_output_token_count = len(generated_tokens)
-        if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-        tool_name, tool_input = parse_json_tool_call(response)
-        call_id = "".join(random.choices("0123456789", k=5))
-        return tool_name, tool_input, call_id
+            output = remove_stop_sequences(output, stop_sequences)
+
+        if tools_to_call_from is None:
+            return ChatCompletionOutputMessage(role="assistant", content=output)
+        else:
+            tool_name, tool_arguments = json.load(output)
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="".join(random.choices("0123456789", k=5)),
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name=tool_name, arguments=tool_arguments
+                        ),
+                    )
+                ],
+            )

 class LiteLLMModel(Model):
@@ -460,38 +432,16 @@ class LiteLLMModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-        response = litellm.completion(
-            model=self.model_id,
-            messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            api_base=self.api_base,
-            api_key=self.api_key,
-            **self.kwargs,
-        )
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 1500,
-    ):
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        response = litellm.completion(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
+        if tools_to_call_from:
+            response = litellm.completion(
+                model=self.model_id,
+                messages=messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
                 max_tokens=max_tokens,
@@ -499,11 +449,19 @@ class LiteLLMModel(Model):
                 api_key=self.api_key,
                 **self.kwargs,
             )
-        tool_calls = response.choices[0].message.tool_calls[0]
+        else:
+            response = litellm.completion(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        arguments = json.loads(tool_calls.function.arguments)
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message

 class OpenAIServerModel(Model):
@@ -539,64 +497,40 @@ class OpenAIServerModel(Model):
         self.temperature = temperature
         self.kwargs = kwargs

-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-        response = self.client.chat.completions.create(
-            model=self.model_id,
-            messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            temperature=self.temperature,
-            **self.kwargs,
-        )
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, Dict], str]:
-        """Generates a tool call for the given message list"""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        response = self.client.chat.completions.create(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            tool_choice="auto",
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            temperature=self.temperature,
-            **self.kwargs,
-        )
-        tool_calls = response.choices[0].message.tool_calls[0]
+        if tools_to_call_from:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tool_choice="auto",
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+                **self.kwargs,
+            )
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        try:
-            arguments = json.loads(tool_calls.function.arguments)
-        except json.JSONDecodeError:
-            arguments = tool_calls.function.arguments
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message

 __all__ = [
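
All four model classes above now return the same message object. A minimal sketch of the two shapes callers can expect, using the huggingface_hub output classes added to the imports above (the values are illustrative, mirroring the tests below):

```py
from huggingface_hub import (
    ChatCompletionOutputMessage,
    ChatCompletionOutputToolCall,
    ChatCompletionOutputFunctionDefinition,
)

# Plain generation result: the text lives in .content
text_message = ChatCompletionOutputMessage(role="assistant", content="Hello!")

# Tool-calling result: empty content, one entry in .tool_calls
tool_message = ChatCompletionOutputMessage(
    role="assistant",
    content="",
    tool_calls=[
        ChatCompletionOutputToolCall(
            id="call_0",
            type="function",
            function=ChatCompletionOutputFunctionDefinition(
                name="final_answer", arguments={"answer": "7.2904"}
            ),
        )
    ],
)
```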


@@ -30,6 +30,11 @@ from smolagents.agents import (
 from smolagents.default_tools import PythonInterpreterTool
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)

 def get_new_path(suffix="") -> str:
@@ -38,54 +43,106 @@ def get_new_path(suffix="") -> str:
 class FakeToolCallModel:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return "python_interpreter", {"code": "2*3.6452"}, "call_0"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="python_interpreter", arguments={"code": "2*3.6452"}
+                        ),
+                    )
+                ],
+            )
         else:
-            return "final_answer", {"answer": "7.2904"}, "call_1"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "7.2904"}
+                        ),
+                    )
+                ],
+            )

 class FakeToolCallModelImage:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return (
-                "fake_image_generation_tool",
-                {"prompt": "An image of a cat"},
-                "call_0",
-            )
-        else:  # We're at step 2
-            return "final_answer", "image.png", "call_1"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="fake_image_generation_tool",
+                            arguments={"prompt": "An image of a cat"},
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments="image.png"
+                        ),
+                    )
+                ],
+            )

 def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = 2**3.6452
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer(7.2904)
 ```<end_code>
-"""
+""",
+        )

 def fake_code_model_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -94,21 +151,27 @@ b = a * 2
 print = 2
 print("Ok, calculation done!")
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+        )

 def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -117,32 +180,41 @@ b = a * 2
     print("Failing due to unexpected indent")
 print("Ok, calculation done!")
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+        )

 def fake_code_model_import(messages, stop_sequences=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I can answer the question
 Code:
 ```py
 import numpy as np
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+    )

 def fake_code_functiondef(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: Let's define the function. special_marker
 Code:
 ```py
@@ -151,9 +223,12 @@ import numpy as np
 def moving_average(x, w):
     return np.convolve(x, np.ones(w), 'valid') / w
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
@@ -161,29 +236,36 @@ x, w = [0, 1, 2, 3, 4, 5], 2
 res = moving_average(x, w)
 final_answer(res)
 ```<end_code>
-"""
+""",
+        )

 def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 final_answer(result)
 ```
-"""
+""",
+    )

 def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 print(result)
 ```
-"""
+""",
+    )

 class AgentTests(unittest.TestCase):
@@ -360,52 +442,92 @@ class AgentTests(unittest.TestCase):
     def test_multiagents(self):
         class FakeModelMultiagentsManagerAgent:
-            def __call__(self, messages, stop_sequences=None, grammar=None):
-                if len(messages) < 3:
-                    return """
-Thought: Let's call our search agent.
-Code:
-```py
-result = search_agent("Who is the current US president?")
-```<end_code>
-"""
-                else:
-                    assert "Report on the current US president" in str(messages)
-                    return """
-Thought: Let's return the report.
-Code:
-```py
-final_answer("Final report.")
-```<end_code>
-"""
-
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
-            ):
-                if len(messages) < 3:
-                    return (
-                        "search_agent",
-                        "Who is the current US president?",
-                        "call_0",
-                    )
-                else:
-                    assert "Report on the current US president" in str(messages)
-                    return (
-                        "final_answer",
-                        "Final report.",
-                        "call_0",
-                    )
+            def __call__(
+                self,
+                messages,
+                stop_sequences=None,
+                grammar=None,
+                tools_to_call_from=None,
+            ):
+                if tools_to_call_from is not None:
+                    if len(messages) < 3:
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="search_agent",
+                                        arguments="Who is the current US president?",
+                                    ),
+                                )
+                            ],
+                        )
+                    else:
+                        assert "Report on the current US president" in str(messages)
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="final_answer", arguments="Final report."
+                                    ),
+                                )
+                            ],
+                        )
+                else:
+                    if len(messages) < 3:
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
+Thought: Let's call our search agent.
+Code:
+```py
+result = search_agent("Who is the current US president?")
+```<end_code>
+""",
+                        )
+                    else:
+                        assert "Report on the current US president" in str(messages)
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
+Thought: Let's return the report.
+Code:
+```py
+final_answer("Final report.")
+```<end_code>
+""",
+                        )

         manager_model = FakeModelMultiagentsManagerAgent()

         class FakeModelMultiagentsManagedAgent:
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
+            def __call__(
+                self,
+                messages,
+                tools_to_call_from=None,
+                stop_sequences=None,
+                grammar=None,
             ):
-                return (
-                    "final_answer",
-                    {"report": "Report on the current US president"},
-                    "call_0",
-                )
+                return ChatCompletionOutputMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatCompletionOutputToolCall(
+                            id="call_0",
+                            type="function",
+                            function=ChatCompletionOutputFunctionDefinition(
+                                name="final_answer",
+                                arguments="Report on the current US president",
+                            ),
+                        )
+                    ],
+                )

         managed_model = FakeModelMultiagentsManagedAgent()
@@ -443,13 +565,16 @@ final_answer("Final report.")
     def test_code_nontrivial_final_answer_works(self):
         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
-            return """Code:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""Code:
 ```py
 def nested_answer():
     final_answer("Correct!")

 nested_answer()
-```<end_code>"""
+```<end_code>""",
+            )

         agent = CodeAgent(tools=[], model=fake_code_model_final_answer)


@@ -92,7 +92,6 @@ class TestDocs:
             raise ValueError(f"Docs directory not found at {cls.docs_dir}")

         load_dotenv()
-        cls.hf_token = os.getenv("HF_TOKEN")

         cls.md_files = list(cls.docs_dir.rglob("*.md"))
         if not cls.md_files:
@@ -115,6 +114,7 @@ class TestDocs:
             "from_langchain",  # Langchain is not a dependency
             "while llm_should_continue(memory):",  # This is pseudo code
             "ollama_chat/llama3.2",  # Exclude ollama building in guided tour
+            "model = TransformersModel(model_id=model_id)",  # Exclude testing with transformers model
         ]
         code_blocks = [
             block
@@ -131,10 +131,15 @@ class TestDocs:
                 ast.parse(block)

         # Create and execute test script
+        print("\n\nCollected code block:==========\n".join(code_blocks))
         try:
             code_blocks = [
-                block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
-                    "{your_username}", "m-ric"
-                )
+                (
+                    block.replace(
+                        "<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
+                    )
+                    .replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
+                    .replace("{your_username}", "m-ric")
+                )
                 for block in code_blocks
             ]


@@ -22,42 +22,57 @@ from smolagents import (
     ToolCallingAgent,
     stream_to_gradio,
 )
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)

-class MonitoringTester(unittest.TestCase):
-    def test_code_agent_metrics(self):
-        class FakeLLMModel:
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
-            def __call__(self, prompt, **kwargs):
-                return """
-Code:
-```py
-final_answer('This is the final answer.')
-```"""
+class FakeLLMModel:
+    def __init__(self):
+        self.last_input_token_count = 10
+        self.last_output_token_count = 20
+
+    def __call__(self, prompt, tools_to_call_from=None, **kwargs):
+        if tools_to_call_from is not None:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="fake_id",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "image"}
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""
+Code:
+```py
+final_answer('This is the final answer.')
+```""",
+            )

+class MonitoringTester(unittest.TestCase):
+    def test_code_agent_metrics(self):
         agent = CodeAgent(
             tools=[],
             model=FakeLLMModel(),
             max_steps=1,
         )
         agent.run("Fake task")

         self.assertEqual(agent.monitor.total_input_token_count, 10)
         self.assertEqual(agent.monitor.total_output_token_count, 20)

     def test_json_agent_metrics(self):
-        class FakeLLMModel:
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
-            def get_tool_call(self, prompt, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLMModel(),
@@ -70,17 +85,19 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 20)

     def test_code_agent_metrics_max_steps(self):
-        class FakeLLMModel:
+        class FakeLLMModelMalformedAnswer:
             def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20

             def __call__(self, prompt, **kwargs):
-                return "Malformed answer"
+                return ChatCompletionOutputMessage(
+                    role="assistant", content="Malformed answer"
+                )

         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelMalformedAnswer(),
             max_steps=1,
         )
@@ -90,7 +107,7 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 40)

     def test_code_agent_metrics_generation_error(self):
-        class FakeLLMModel:
+        class FakeLLMModelGenerationException:
             def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20
@@ -102,7 +119,7 @@ final_answer('This is the final answer.')
         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelGenerationException(),
             max_steps=1,
         )
         agent.run("Fake task")
@@ -113,16 +130,9 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 0)

     def test_streaming_agent_text_output(self):
-        def dummy_model(prompt, **kwargs):
-            return """
-Code:
-```py
-final_answer('This is the final answer.')
-```"""
-
         agent = CodeAgent(
             tools=[],
-            model=dummy_model,
+            model=FakeLLMModel(),
             max_steps=1,
         )
@@ -135,16 +145,9 @@ final_answer('This is the final answer.')
         self.assertIn("This is the final answer.", final_message.content)

     def test_streaming_agent_image_output(self):
-        class FakeLLM:
-            def __init__(self):
-                pass
-
-            def get_tool_call(self, messages, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
-            model=FakeLLM(),
+            model=FakeLLMModel(),
             max_steps=1,
         )