Vastly simplify Model class ✨ (#146)

* Vastly simplify Model class by making only one __call__ method ✨

parent 36ed279c85
commit 5c33130fa4
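For orientation before the diff itself: the commit collapses the separate `generate` / `get_tool_call` methods into a single `__call__` that returns a chat message object. The sketch below is only an illustration of that interface as it appears in the changes that follow (it is not part of the commit); the prompts are made up, and it assumes an `HF_TOKEN` is available in the environment.

```py
from smolagents import HfApiModel, DuckDuckGoSearchTool

model = HfApiModel()  # illustrative: picks up HF_TOKEN from the environment if no token is passed

# Plain generation: the returned chat message exposes the completion via .content
message = model(
    [{"role": "user", "content": "Say hello in one short sentence."}],
    stop_sequences=["Observation:"],
)
print(message.content)

# Tool-calling generation: pass tools_to_call_from and read .tool_calls instead
tool_message = model(
    [{"role": "user", "content": "What is the weather in Paris?"}],
    tools_to_call_from=[DuckDuckGoSearchTool()],
    stop_sequences=["Observation:"],
)
call = tool_message.tool_calls[0]
print(call.function.name, call.function.arguments, call.id)
```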
@@ -113,8 +113,7 @@ The Python interpreter also doesn't allow imports by default outside of a safe list
 You can authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`CodeAgent`]:
 
 ```py
-from smolagents import CodeAgent
+model = HfApiModel()
 
 agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['requests', 'bs4'])
 agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
 ```
@@ -164,12 +163,12 @@ Transformers comes with a default toolbox for empowering agents, that you can add
 - **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
 - **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes an audio to text.
 
-You can manually use a tool by calling the [`load_tool`] function and a task to perform.
+You can manually use a tool by calling it with its arguments.
 
 ```python
-from smolagents import load_tool
+from smolagents import DuckDuckGoSearchTool
 
-search_tool = load_tool("web_search")
+search_tool = DuckDuckGoSearchTool()
 print(search_tool("Who's the current president of Russia?"))
 ```
 
@@ -776,11 +776,15 @@ class ToolCallingAgent(MultiStepAgent):
         log_entry.agent_memory = agent_memory.copy()

         try:
-            tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
+            model_message = self.model(
                 self.input_messages,
-                available_tools=list(self.tools.values()),
+                tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
+            tool_calls = model_message.tool_calls[0]
+            tool_arguments = tool_calls.function.arguments
+            tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
         except Exception as e:
             raise AgentGenerationError(
                 f"Error in generating tool call with model:\n{e}"
@@ -913,7 +917,7 @@ class CodeAgent(MultiStepAgent):
                 self.input_messages,
                 stop_sequences=["<end_code>", "Observation:"],
                 **additional_args,
-            )
+            ).content
             log_entry.llm_output = llm_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating model output:\n{e}")
 
@@ -20,10 +20,16 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional
 
 import torch
-from huggingface_hub import InferenceClient
+from huggingface_hub import (
+    InferenceClient,
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -33,7 +39,6 @@ from transformers import (
 import openai
 
 from .tools import Tool
-from .utils import parse_json_tool_call
 
 logger = logging.getLogger(__name__)
 
@@ -234,63 +239,46 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
+        temperature: float = 0.5,
     ):
         super().__init__()
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
+        self.temperature = temperature
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        # Send messages to the Hugging Face Inference API
-        if grammar is not None:
-            output = self.client.chat_completion(
-                messages,
-                stop=stop_sequences,
-                response_format=grammar,
-                max_tokens=max_tokens,
-            )
-        else:
-            output = self.client.chat.completions.create(
-                messages, stop=stop_sequences, max_tokens=max_tokens
-            )
-
-        response = output.choices[0].message.content
-        self.last_input_token_count = output.usage.prompt_tokens
-        self.last_output_token_count = output.usage.completion_tokens
-        return response
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences,
-    ):
-        """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
+        if tools_to_call_from:
             response = self.client.chat.completions.create(
                 messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+            )
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
             )
-        tool_call = response.choices[0].message.tool_calls[0]
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return tool_call.function.name, tool_call.function.arguments, tool_call.id
+        return response.choices[0].message
 
 
 class TransformersModel(Model):
@@ -354,18 +342,27 @@ class TransformersModel(Model):
 
         return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
 
-        # Get LLM output
+        if tools_to_call_from is not None:
+            prompt_tensor = self.tokenizer.apply_chat_template(
+                messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                return_tensors="pt",
+                return_dict=True,
+                add_generation_prompt=True,
+            )
+        else:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
                 return_tensors="pt",
@@ -382,56 +379,31 @@ class TransformersModel(Model):
             ),
         )
         generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
         self.last_input_token_count = count_prompt_tokens
         self.last_output_token_count = len(generated_tokens)
 
         if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-        return response
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, None], str]:
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        prompt = self.tokenizer.apply_chat_template(
-            messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            return_tensors="pt",
-            return_dict=True,
-            add_generation_prompt=True,
-        )
-        prompt = prompt.to(self.model.device)
-        count_prompt_tokens = prompt["input_ids"].shape[1]
-
-        out = self.model.generate(
-            **prompt,
-            max_new_tokens=max_tokens,
-            stopping_criteria=(
-                self.make_stopping_criteria(stop_sequences) if stop_sequences else None
-            ),
-        )
-        generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
-        self.last_input_token_count = count_prompt_tokens
-        self.last_output_token_count = len(generated_tokens)
-
-        if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-
-        tool_name, tool_input = parse_json_tool_call(response)
-        call_id = "".join(random.choices("0123456789", k=5))
-
-        return tool_name, tool_input, call_id
+            output = remove_stop_sequences(output, stop_sequences)
+
+        if tools_to_call_from is None:
+            return ChatCompletionOutputMessage(role="assistant", content=output)
+        else:
+            tool_name, tool_arguments = json.load(output)
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="".join(random.choices("0123456789", k=5)),
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name=tool_name, arguments=tool_arguments
+                        ),
+                    )
+                ],
+            )
 
 
 class LiteLLMModel(Model):
@@ -460,38 +432,16 @@ class LiteLLMModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
+        if tools_to_call_from:
             response = litellm.completion(
                 model=self.model_id,
                 messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            api_base=self.api_base,
-            api_key=self.api_key,
-            **self.kwargs,
-        )
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 1500,
-    ):
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        response = litellm.completion(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
                 max_tokens=max_tokens,
@@ -499,11 +449,19 @@ class LiteLLMModel(Model):
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
             )
-        tool_calls = response.choices[0].message.tool_calls[0]
+        else:
+            response = litellm.completion(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        arguments = json.loads(tool_calls.function.arguments)
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message
 
 
 class OpenAIServerModel(Model):
@@ -539,64 +497,40 @@ class OpenAIServerModel(Model):
         self.temperature = temperature
         self.kwargs = kwargs
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
+        if tools_to_call_from:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            temperature=self.temperature,
-            **self.kwargs,
-        )
-
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, Dict], str]:
-        """Generates a tool call for the given message list"""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        response = self.client.chat.completions.create(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
                 max_tokens=max_tokens,
                 temperature=self.temperature,
                 **self.kwargs,
             )
-
-        tool_calls = response.choices[0].message.tool_calls[0]
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        try:
-            arguments = json.loads(tool_calls.function.arguments)
-        except json.JSONDecodeError:
-            arguments = tool_calls.function.arguments
-
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message
 
 
 __all__ = [
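With the model classes above converged on one `__call__`, any custom model only has to be a callable with that signature. As a hedged sketch of the resulting contract (not part of the commit, mirroring the fake models used in the tests below and assuming the `ChatCompletionOutput*` classes from `huggingface_hub`):

```py
from huggingface_hub import (
    ChatCompletionOutputFunctionDefinition,
    ChatCompletionOutputMessage,
    ChatCompletionOutputToolCall,
)


class EchoModel:
    """Toy stand-in model: always answers with final_answer, as text or as a tool call."""

    def __call__(self, messages, stop_sequences=None, grammar=None, tools_to_call_from=None):
        if tools_to_call_from is None:
            # Code-agent path: raw text goes in .content
            return ChatCompletionOutputMessage(role="assistant", content="final_answer('ok')")
        # Tool-calling path: a structured call goes in .tool_calls
        return ChatCompletionOutputMessage(
            role="assistant",
            content="",
            tool_calls=[
                ChatCompletionOutputToolCall(
                    id="call_0",
                    type="function",
                    function=ChatCompletionOutputFunctionDefinition(
                        name="final_answer", arguments={"answer": "ok"}
                    ),
                )
            ],
        )
```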
@@ -30,6 +30,11 @@ from smolagents.agents import (
 from smolagents.default_tools import PythonInterpreterTool
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
 
 
 def get_new_path(suffix="") -> str:
@@ -38,54 +43,106 @@ def get_new_path(suffix="") -> str:
 
 
 class FakeToolCallModel:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return "python_interpreter", {"code": "2*3.6452"}, "call_0"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="python_interpreter", arguments={"code": "2*3.6452"}
+                        ),
+                    )
+                ],
+            )
         else:
-            return "final_answer", {"answer": "7.2904"}, "call_1"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "7.2904"}
+                        ),
+                    )
+                ],
+            )
 
 
 class FakeToolCallModelImage:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return (
-                "fake_image_generation_tool",
-                {"prompt": "An image of a cat"},
-                "call_0",
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="fake_image_generation_tool",
+                            arguments={"prompt": "An image of a cat"},
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments="image.png"
+                        ),
+                    )
+                ],
             )
-
-        else:  # We're at step 2
-            return "final_answer", "image.png", "call_1"
 
 
 def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = 2**3.6452
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer(7.2904)
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -94,21 +151,27 @@ b = a * 2
 print = 2
 print("Ok, calculation done!")
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -117,32 +180,41 @@ b = a * 2
 print("Failing due to unexpected indent")
 print("Ok, calculation done!")
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_import(messages, stop_sequences=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I can answer the question
 Code:
 ```py
 import numpy as np
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+    )
 
 
 def fake_code_functiondef(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: Let's define the function. special_marker
 Code:
 ```py
|
||||||
def moving_average(x, w):
|
def moving_average(x, w):
|
||||||
return np.convolve(x, np.ones(w), 'valid') / w
|
return np.convolve(x, np.ones(w), 'valid') / w
|
||||||
```<end_code>
|
```<end_code>
|
||||||
"""
|
""",
|
||||||
|
)
|
||||||
else: # We're at step 2
|
else: # We're at step 2
|
||||||
return """
|
return ChatCompletionOutputMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="""
|
||||||
Thought: I can now answer the initial question
|
Thought: I can now answer the initial question
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
|
@@ -161,29 +236,36 @@ x, w = [0, 1, 2, 3, 4, 5], 2
 res = moving_average(x, w)
 final_answer(res)
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 final_answer(result)
 ```
-"""
+""",
+    )
 
 
 def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 print(result)
 ```
-"""
+""",
+    )
 
 
 class AgentTests(unittest.TestCase):
@@ -360,52 +442,92 @@ class AgentTests(unittest.TestCase):
 
     def test_multiagents(self):
         class FakeModelMultiagentsManagerAgent:
-            def __call__(self, messages, stop_sequences=None, grammar=None):
-                if len(messages) < 3:
-                    return """
+            def __call__(
+                self,
+                messages,
+                stop_sequences=None,
+                grammar=None,
+                tools_to_call_from=None,
+            ):
+                if tools_to_call_from is not None:
+                    if len(messages) < 3:
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="search_agent",
+                                        arguments="Who is the current US president?",
+                                    ),
+                                )
+                            ],
+                        )
+                    else:
+                        assert "Report on the current US president" in str(messages)
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="final_answer", arguments="Final report."
+                                    ),
+                                )
+                            ],
+                        )
+                else:
+                    if len(messages) < 3:
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
 Thought: Let's call our search agent.
 Code:
 ```py
 result = search_agent("Who is the current US president?")
 ```<end_code>
-"""
-                else:
-                    assert "Report on the current US president" in str(messages)
-                    return """
+""",
+                        )
+                    else:
+                        assert "Report on the current US president" in str(messages)
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
 Thought: Let's return the report.
 Code:
 ```py
 final_answer("Final report.")
 ```<end_code>
-"""
-
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
-            ):
-                if len(messages) < 3:
-                    return (
-                        "search_agent",
-                        "Who is the current US president?",
-                        "call_0",
-                    )
-                else:
-                    assert "Report on the current US president" in str(messages)
-                    return (
-                        "final_answer",
-                        "Final report.",
-                        "call_0",
-                    )
+""",
+                        )
 
         manager_model = FakeModelMultiagentsManagerAgent()
 
         class FakeModelMultiagentsManagedAgent:
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
+            def __call__(
+                self,
+                messages,
+                tools_to_call_from=None,
+                stop_sequences=None,
+                grammar=None,
             ):
-                return (
-                    "final_answer",
-                    {"report": "Report on the current US president"},
-                    "call_0",
+                return ChatCompletionOutputMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatCompletionOutputToolCall(
+                            id="call_0",
+                            type="function",
+                            function=ChatCompletionOutputFunctionDefinition(
+                                name="final_answer",
+                                arguments="Report on the current US president",
+                            ),
+                        )
+                    ],
                 )
 
         managed_model = FakeModelMultiagentsManagedAgent()
@@ -443,13 +565,16 @@ final_answer("Final report.")
 
     def test_code_nontrivial_final_answer_works(self):
         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
-            return """Code:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""Code:
 ```py
 def nested_answer():
     final_answer("Correct!")
 
 nested_answer()
-```<end_code>"""
+```<end_code>""",
+            )
 
         agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
 
@@ -92,7 +92,6 @@ class TestDocs:
             raise ValueError(f"Docs directory not found at {cls.docs_dir}")
 
         load_dotenv()
-        cls.hf_token = os.getenv("HF_TOKEN")
 
         cls.md_files = list(cls.docs_dir.rglob("*.md"))
         if not cls.md_files:
@@ -115,6 +114,7 @@ class TestDocs:
             "from_langchain",  # Langchain is not a dependency
             "while llm_should_continue(memory):",  # This is pseudo code
             "ollama_chat/llama3.2",  # Exclude ollama building in guided tour
+            "model = TransformersModel(model_id=model_id)",  # Exclude testing with transformers model
         ]
         code_blocks = [
             block
|
@ -131,10 +131,15 @@ class TestDocs:
|
||||||
ast.parse(block)
|
ast.parse(block)
|
||||||
|
|
||||||
# Create and execute test script
|
# Create and execute test script
|
||||||
|
print("\n\nCollected code block:==========\n".join(code_blocks))
|
||||||
try:
|
try:
|
||||||
code_blocks = [
|
code_blocks = [
|
||||||
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
|
(
|
||||||
"{your_username}", "m-ric"
|
block.replace(
|
||||||
|
"<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
|
||||||
|
)
|
||||||
|
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
|
||||||
|
.replace("{your_username}", "m-ric")
|
||||||
)
|
)
|
||||||
for block in code_blocks
|
for block in code_blocks
|
||||||
]
|
]
|
||||||
|
|
|
@@ -22,42 +22,57 @@ from smolagents import (
     ToolCallingAgent,
     stream_to_gradio,
 )
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
 
 
-class MonitoringTester(unittest.TestCase):
-    def test_code_agent_metrics(self):
 class FakeLLMModel:
     def __init__(self):
         self.last_input_token_count = 10
         self.last_output_token_count = 20
 
-    def __call__(self, prompt, **kwargs):
-        return """
+    def __call__(self, prompt, tools_to_call_from=None, **kwargs):
+        if tools_to_call_from is not None:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="fake_id",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "image"}
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""
 Code:
 ```py
 final_answer('This is the final answer.')
-```"""
+```""",
+            )
 
 
+class MonitoringTester(unittest.TestCase):
+    def test_code_agent_metrics(self):
         agent = CodeAgent(
             tools=[],
             model=FakeLLMModel(),
             max_steps=1,
         )
 
         agent.run("Fake task")
 
         self.assertEqual(agent.monitor.total_input_token_count, 10)
         self.assertEqual(agent.monitor.total_output_token_count, 20)
 
     def test_json_agent_metrics(self):
-        class FakeLLMModel:
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
-            def get_tool_call(self, prompt, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLMModel(),
@@ -70,17 +85,19 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 20)
 
     def test_code_agent_metrics_max_steps(self):
-        class FakeLLMModel:
+        class FakeLLMModelMalformedAnswer:
             def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20
 
             def __call__(self, prompt, **kwargs):
-                return "Malformed answer"
+                return ChatCompletionOutputMessage(
+                    role="assistant", content="Malformed answer"
+                )
 
         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelMalformedAnswer(),
             max_steps=1,
         )
 
@@ -90,7 +107,7 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 40)
 
     def test_code_agent_metrics_generation_error(self):
-        class FakeLLMModel:
+        class FakeLLMModelGenerationException:
            def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20
@@ -102,7 +119,7 @@ final_answer('This is the final answer.')
 
         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelGenerationException(),
             max_steps=1,
         )
         agent.run("Fake task")
|
||||||
self.assertEqual(agent.monitor.total_output_token_count, 0)
|
self.assertEqual(agent.monitor.total_output_token_count, 0)
|
||||||
|
|
||||||
def test_streaming_agent_text_output(self):
|
def test_streaming_agent_text_output(self):
|
||||||
def dummy_model(prompt, **kwargs):
|
|
||||||
return """
|
|
||||||
Code:
|
|
||||||
```py
|
|
||||||
final_answer('This is the final answer.')
|
|
||||||
```"""
|
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
model=dummy_model,
|
model=FakeLLMModel(),
|
||||||
max_steps=1,
|
max_steps=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -135,16 +145,9 @@ final_answer('This is the final answer.')
         self.assertIn("This is the final answer.", final_message.content)
 
     def test_streaming_agent_image_output(self):
-        class FakeLLM:
-            def __init__(self):
-                pass
-
-            def get_tool_call(self, messages, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
-            model=FakeLLM(),
+            model=FakeLLMModel(),
             max_steps=1,
         )
 