Vastly simplify Model class ✨ (#146)
* Vastly simplify Model class by making only one __call__ method ✨
This commit is contained in:
parent
36ed279c85
commit
5c33130fa4
|
@ -113,8 +113,7 @@ The Python interpreter also doesn't allow imports by default outside of a safe l
|
|||
You can authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`CodeAgent`]:
|
||||
|
||||
```py
|
||||
from smolagents import CodeAgent
|
||||
|
||||
model = HfApiModel()
|
||||
agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['requests', 'bs4'])
|
||||
agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
|
||||
```
|
||||
|
@ -164,12 +163,12 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
|
|||
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
|
||||
- **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes an audio to text.
|
||||
|
||||
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
|
||||
You can manually use a tool by calling it with its arguments.
|
||||
|
||||
```python
|
||||
from smolagents import load_tool
|
||||
from smolagents import DuckDuckGoSearchTool
|
||||
|
||||
search_tool = load_tool("web_search")
|
||||
search_tool = DuckDuckGoSearchTool()
|
||||
print(search_tool("Who's the current president of Russia?"))
|
||||
```
|
||||
|
||||
|
|
|
@ -776,11 +776,15 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
log_entry.agent_memory = agent_memory.copy()
|
||||
|
||||
try:
|
||||
tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
|
||||
model_message = self.model(
|
||||
self.input_messages,
|
||||
available_tools=list(self.tools.values()),
|
||||
tools_to_call_from=list(self.tools.values()),
|
||||
stop_sequences=["Observation:"],
|
||||
)
|
||||
tool_calls = model_message.tool_calls[0]
|
||||
tool_arguments = tool_calls.function.arguments
|
||||
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
||||
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(
|
||||
f"Error in generating tool call with model:\n{e}"
|
||||
|
@ -913,7 +917,7 @@ class CodeAgent(MultiStepAgent):
|
|||
self.input_messages,
|
||||
stop_sequences=["<end_code>", "Observation:"],
|
||||
**additional_args,
|
||||
)
|
||||
).content
|
||||
log_entry.llm_output = llm_output
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(f"Error in generating model output:\n{e}")
|
||||
|
|
|
@ -20,10 +20,16 @@ import os
|
|||
import random
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub import (
|
||||
InferenceClient,
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionOutputToolCall,
|
||||
ChatCompletionOutputFunctionDefinition,
|
||||
)
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
|
@ -33,7 +39,6 @@ from transformers import (
|
|||
import openai
|
||||
|
||||
from .tools import Tool
|
||||
from .utils import parse_json_tool_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -234,63 +239,46 @@ class HfApiModel(Model):
|
|||
model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
|
||||
token: Optional[str] = None,
|
||||
timeout: Optional[int] = 120,
|
||||
temperature: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_id = model_id
|
||||
if token is None:
|
||||
token = os.getenv("HF_TOKEN")
|
||||
self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
|
||||
self.temperature = temperature
|
||||
|
||||
def generate(
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> str:
|
||||
"""Generates a text completion for the given message list"""
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
# Send messages to the Hugging Face Inference API
|
||||
if grammar is not None:
|
||||
output = self.client.chat_completion(
|
||||
messages,
|
||||
stop=stop_sequences,
|
||||
response_format=grammar,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
else:
|
||||
output = self.client.chat.completions.create(
|
||||
messages, stop=stop_sequences, max_tokens=max_tokens
|
||||
)
|
||||
|
||||
response = output.choices[0].message.content
|
||||
self.last_input_token_count = output.usage.prompt_tokens
|
||||
self.last_output_token_count = output.usage.completion_tokens
|
||||
return response
|
||||
|
||||
def get_tool_call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
available_tools: List[Tool],
|
||||
stop_sequences,
|
||||
):
|
||||
"""Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
if tools_to_call_from:
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
tools=[get_json_schema(tool) for tool in available_tools],
|
||||
tools=[get_json_schema(tool) for tool in tools_to_call_from],
|
||||
tool_choice="auto",
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return tool_call.function.name, tool_call.function.arguments, tool_call.id
|
||||
return response.choices[0].message
|
||||
|
||||
|
||||
class TransformersModel(Model):
|
||||
|
@ -354,18 +342,27 @@ class TransformersModel(Model):
|
|||
|
||||
return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])
|
||||
|
||||
def generate(
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> str:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
# Get LLM output
|
||||
if tools_to_call_from is not None:
|
||||
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=[get_json_schema(tool) for tool in tools_to_call_from],
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
return_tensors="pt",
|
||||
|
@ -382,56 +379,31 @@ class TransformersModel(Model):
|
|||
),
|
||||
)
|
||||
generated_tokens = out[0, count_prompt_tokens:]
|
||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
self.last_input_token_count = count_prompt_tokens
|
||||
self.last_output_token_count = len(generated_tokens)
|
||||
|
||||
if stop_sequences is not None:
|
||||
response = remove_stop_sequences(response, stop_sequences)
|
||||
return response
|
||||
output = remove_stop_sequences(output, stop_sequences)
|
||||
|
||||
def get_tool_call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
available_tools: List[Tool],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
max_tokens: int = 500,
|
||||
) -> Tuple[str, Union[str, None], str]:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=[get_json_schema(tool) for tool in available_tools],
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt = prompt.to(self.model.device)
|
||||
count_prompt_tokens = prompt["input_ids"].shape[1]
|
||||
|
||||
out = self.model.generate(
|
||||
**prompt,
|
||||
max_new_tokens=max_tokens,
|
||||
stopping_criteria=(
|
||||
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
|
||||
if tools_to_call_from is None:
|
||||
return ChatCompletionOutputMessage(role="assistant", content=output)
|
||||
else:
|
||||
tool_name, tool_arguments = json.load(output)
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="".join(random.choices("0123456789", k=5)),
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name=tool_name, arguments=tool_arguments
|
||||
),
|
||||
)
|
||||
generated_tokens = out[0, count_prompt_tokens:]
|
||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
self.last_input_token_count = count_prompt_tokens
|
||||
self.last_output_token_count = len(generated_tokens)
|
||||
|
||||
if stop_sequences is not None:
|
||||
response = remove_stop_sequences(response, stop_sequences)
|
||||
|
||||
tool_name, tool_input = parse_json_tool_call(response)
|
||||
call_id = "".join(random.choices("0123456789", k=5))
|
||||
|
||||
return tool_name, tool_input, call_id
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class LiteLLMModel(Model):
|
||||
|
@ -460,38 +432,16 @@ class LiteLLMModel(Model):
|
|||
stop_sequences: Optional[List[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> str:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
if tools_to_call_from:
|
||||
response = litellm.completion(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self.api_base,
|
||||
api_key=self.api_key,
|
||||
**self.kwargs,
|
||||
)
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return response.choices[0].message.content
|
||||
|
||||
def get_tool_call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
available_tools: List[Tool],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
max_tokens: int = 1500,
|
||||
):
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
response = litellm.completion(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=[get_json_schema(tool) for tool in available_tools],
|
||||
tools=[get_json_schema(tool) for tool in tools_to_call_from],
|
||||
tool_choice="required",
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
|
@ -499,11 +449,19 @@ class LiteLLMModel(Model):
|
|||
api_key=self.api_key,
|
||||
**self.kwargs,
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls[0]
|
||||
else:
|
||||
response = litellm.completion(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self.api_base,
|
||||
api_key=self.api_key,
|
||||
**self.kwargs,
|
||||
)
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
arguments = json.loads(tool_calls.function.arguments)
|
||||
return tool_calls.function.name, arguments, tool_calls.id
|
||||
return response.choices[0].message
|
||||
|
||||
|
||||
class OpenAIServerModel(Model):
|
||||
|
@ -539,64 +497,40 @@ class OpenAIServerModel(Model):
|
|||
self.temperature = temperature
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> str:
|
||||
"""Generates a text completion for the given message list"""
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
if tools_to_call_from:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
temperature=self.temperature,
|
||||
**self.kwargs,
|
||||
)
|
||||
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return response.choices[0].message.content
|
||||
|
||||
def get_tool_call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
available_tools: List[Tool],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
max_tokens: int = 500,
|
||||
) -> Tuple[str, Union[str, Dict], str]:
|
||||
"""Generates a tool call for the given message list"""
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=[get_json_schema(tool) for tool in available_tools],
|
||||
tools=[get_json_schema(tool) for tool in tools_to_call_from],
|
||||
tool_choice="auto",
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
temperature=self.temperature,
|
||||
**self.kwargs,
|
||||
)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
temperature=self.temperature,
|
||||
**self.kwargs,
|
||||
)
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
|
||||
try:
|
||||
arguments = json.loads(tool_calls.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = tool_calls.function.arguments
|
||||
|
||||
return tool_calls.function.name, arguments, tool_calls.id
|
||||
return response.choices[0].message
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -30,6 +30,11 @@ from smolagents.agents import (
|
|||
from smolagents.default_tools import PythonInterpreterTool
|
||||
from smolagents.tools import tool
|
||||
from smolagents.types import AgentImage, AgentText
|
||||
from huggingface_hub import (
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionOutputToolCall,
|
||||
ChatCompletionOutputFunctionDefinition,
|
||||
)
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
|
@ -38,54 +43,106 @@ def get_new_path(suffix="") -> str:
|
|||
|
||||
|
||||
class FakeToolCallModel:
|
||||
def get_tool_call(
|
||||
self, messages, available_tools, stop_sequences=None, grammar=None
|
||||
def __call__(
|
||||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
||||
):
|
||||
if len(messages) < 3:
|
||||
return "python_interpreter", {"code": "2*3.6452"}, "call_0"
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="python_interpreter", arguments={"code": "2*3.6452"}
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
return "final_answer", {"answer": "7.2904"}, "call_1"
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="final_answer", arguments={"answer": "7.2904"}
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FakeToolCallModelImage:
|
||||
def get_tool_call(
|
||||
self, messages, available_tools, stop_sequences=None, grammar=None
|
||||
def __call__(
|
||||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
||||
):
|
||||
if len(messages) < 3:
|
||||
return (
|
||||
"fake_image_generation_tool",
|
||||
{"prompt": "An image of a cat"},
|
||||
"call_0",
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="fake_image_generation_tool",
|
||||
arguments={"prompt": "An image of a cat"},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="final_answer", arguments="image.png"
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
else: # We're at step 2
|
||||
return "final_answer", "image.png", "call_1"
|
||||
|
||||
|
||||
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = 2**3.6452
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer(7.2904)
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def fake_code_model_error(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
|
@ -94,21 +151,27 @@ b = a * 2
|
|||
print = 2
|
||||
print("Ok, calculation done!")
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer("got an error")
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
|
@ -117,32 +180,41 @@ b = a * 2
|
|||
print("Failing due to unexpected indent")
|
||||
print("Ok, calculation done!")
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer("got an error")
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def fake_code_model_import(messages, stop_sequences=None) -> str:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can answer the question
|
||||
Code:
|
||||
```py
|
||||
import numpy as np
|
||||
final_answer("got an error")
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def fake_code_functiondef(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's define the function. special_marker
|
||||
Code:
|
||||
```py
|
||||
|
@ -151,9 +223,12 @@ import numpy as np
|
|||
def moving_average(x, w):
|
||||
return np.convolve(x, np.ones(w), 'valid') / w
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
|
@ -161,29 +236,36 @@ x, w = [0, 1, 2, 3, 4, 5], 2
|
|||
res = moving_average(x, w)
|
||||
final_answer(res)
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
final_answer(result)
|
||||
```
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
print(result)
|
||||
```
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class AgentTests(unittest.TestCase):
|
||||
|
@ -360,52 +442,92 @@ class AgentTests(unittest.TestCase):
|
|||
|
||||
def test_multiagents(self):
|
||||
class FakeModelMultiagentsManagerAgent:
|
||||
def __call__(self, messages, stop_sequences=None, grammar=None):
|
||||
def __call__(
|
||||
self,
|
||||
messages,
|
||||
stop_sequences=None,
|
||||
grammar=None,
|
||||
tools_to_call_from=None,
|
||||
):
|
||||
if tools_to_call_from is not None:
|
||||
if len(messages) < 3:
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="search_agent",
|
||||
arguments="Who is the current US president?",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="final_answer", arguments="Final report."
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
if len(messages) < 3:
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's call our search agent.
|
||||
Code:
|
||||
```py
|
||||
result = search_agent("Who is the current US president?")
|
||||
```<end_code>
|
||||
"""
|
||||
""",
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return """
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's return the report.
|
||||
Code:
|
||||
```py
|
||||
final_answer("Final report.")
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
def get_tool_call(
|
||||
self, messages, available_tools, stop_sequences=None, grammar=None
|
||||
):
|
||||
if len(messages) < 3:
|
||||
return (
|
||||
"search_agent",
|
||||
"Who is the current US president?",
|
||||
"call_0",
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return (
|
||||
"final_answer",
|
||||
"Final report.",
|
||||
"call_0",
|
||||
""",
|
||||
)
|
||||
|
||||
manager_model = FakeModelMultiagentsManagerAgent()
|
||||
|
||||
class FakeModelMultiagentsManagedAgent:
|
||||
def get_tool_call(
|
||||
self, messages, available_tools, stop_sequences=None, grammar=None
|
||||
def __call__(
|
||||
self,
|
||||
messages,
|
||||
tools_to_call_from=None,
|
||||
stop_sequences=None,
|
||||
grammar=None,
|
||||
):
|
||||
return (
|
||||
"final_answer",
|
||||
{"report": "Report on the current US president"},
|
||||
"call_0",
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="final_answer",
|
||||
arguments="Report on the current US president",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
managed_model = FakeModelMultiagentsManagedAgent()
|
||||
|
@ -443,13 +565,16 @@ final_answer("Final report.")
|
|||
|
||||
def test_code_nontrivial_final_answer_works(self):
|
||||
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
|
||||
return """Code:
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""Code:
|
||||
```py
|
||||
def nested_answer():
|
||||
final_answer("Correct!")
|
||||
|
||||
nested_answer()
|
||||
```<end_code>"""
|
||||
```<end_code>""",
|
||||
)
|
||||
|
||||
agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
|
||||
|
||||
|
|
|
@ -92,7 +92,6 @@ class TestDocs:
|
|||
raise ValueError(f"Docs directory not found at {cls.docs_dir}")
|
||||
|
||||
load_dotenv()
|
||||
cls.hf_token = os.getenv("HF_TOKEN")
|
||||
|
||||
cls.md_files = list(cls.docs_dir.rglob("*.md"))
|
||||
if not cls.md_files:
|
||||
|
@ -115,6 +114,7 @@ class TestDocs:
|
|||
"from_langchain", # Langchain is not a dependency
|
||||
"while llm_should_continue(memory):", # This is pseudo code
|
||||
"ollama_chat/llama3.2", # Exclude ollama building in guided tour
|
||||
"model = TransformersModel(model_id=model_id)", # Exclude testing with transformers model
|
||||
]
|
||||
code_blocks = [
|
||||
block
|
||||
|
@ -131,10 +131,15 @@ class TestDocs:
|
|||
ast.parse(block)
|
||||
|
||||
# Create and execute test script
|
||||
print("\n\nCollected code block:==========\n".join(code_blocks))
|
||||
try:
|
||||
code_blocks = [
|
||||
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
|
||||
"{your_username}", "m-ric"
|
||||
(
|
||||
block.replace(
|
||||
"<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
|
||||
)
|
||||
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
|
||||
.replace("{your_username}", "m-ric")
|
||||
)
|
||||
for block in code_blocks
|
||||
]
|
||||
|
|
|
@ -22,42 +22,57 @@ from smolagents import (
|
|||
ToolCallingAgent,
|
||||
stream_to_gradio,
|
||||
)
|
||||
from huggingface_hub import (
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionOutputToolCall,
|
||||
ChatCompletionOutputFunctionDefinition,
|
||||
)
|
||||
|
||||
|
||||
class MonitoringTester(unittest.TestCase):
|
||||
def test_code_agent_metrics(self):
|
||||
class FakeLLMModel:
|
||||
class FakeLLMModel:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return """
|
||||
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
|
||||
if tools_to_call_from is not None:
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
id="fake_id",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
name="final_answer", arguments={"answer": "image"}
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Code:
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
```""",
|
||||
)
|
||||
|
||||
|
||||
class MonitoringTester(unittest.TestCase):
|
||||
def test_code_agent_metrics(self):
|
||||
agent = CodeAgent(
|
||||
tools=[],
|
||||
model=FakeLLMModel(),
|
||||
max_steps=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_json_agent_metrics(self):
|
||||
class FakeLLMModel:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def get_tool_call(self, prompt, **kwargs):
|
||||
return "final_answer", {"answer": "image"}, "fake_id"
|
||||
|
||||
agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=FakeLLMModel(),
|
||||
|
@ -70,17 +85,19 @@ final_answer('This is the final answer.')
|
|||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_code_agent_metrics_max_steps(self):
|
||||
class FakeLLMModel:
|
||||
class FakeLLMModelMalformedAnswer:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return "Malformed answer"
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant", content="Malformed answer"
|
||||
)
|
||||
|
||||
agent = CodeAgent(
|
||||
tools=[],
|
||||
model=FakeLLMModel(),
|
||||
model=FakeLLMModelMalformedAnswer(),
|
||||
max_steps=1,
|
||||
)
|
||||
|
||||
|
@ -90,7 +107,7 @@ final_answer('This is the final answer.')
|
|||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
|
||||
def test_code_agent_metrics_generation_error(self):
|
||||
class FakeLLMModel:
|
||||
class FakeLLMModelGenerationException:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
@ -102,7 +119,7 @@ final_answer('This is the final answer.')
|
|||
|
||||
agent = CodeAgent(
|
||||
tools=[],
|
||||
model=FakeLLMModel(),
|
||||
model=FakeLLMModelGenerationException(),
|
||||
max_steps=1,
|
||||
)
|
||||
agent.run("Fake task")
|
||||
|
@ -113,16 +130,9 @@ final_answer('This is the final answer.')
|
|||
self.assertEqual(agent.monitor.total_output_token_count, 0)
|
||||
|
||||
def test_streaming_agent_text_output(self):
|
||||
def dummy_model(prompt, **kwargs):
|
||||
return """
|
||||
Code:
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
|
||||
agent = CodeAgent(
|
||||
tools=[],
|
||||
model=dummy_model,
|
||||
model=FakeLLMModel(),
|
||||
max_steps=1,
|
||||
)
|
||||
|
||||
|
@ -135,16 +145,9 @@ final_answer('This is the final answer.')
|
|||
self.assertIn("This is the final answer.", final_message.content)
|
||||
|
||||
def test_streaming_agent_image_output(self):
|
||||
class FakeLLM:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_tool_call(self, messages, **kwargs):
|
||||
return "final_answer", {"answer": "image"}, "fake_id"
|
||||
|
||||
agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=FakeLLM(),
|
||||
model=FakeLLMModel(),
|
||||
max_steps=1,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue