TransformersModel auto-detects VLMs (#378)

* TransformersModel auto-detects VLMs
Aymeric Roucher, 2025-01-27 20:09:14 +01:00, committed by GitHub
parent a5290590c8
commit 4579a6f7cc
5 changed files with 42 additions and 19 deletions


@@ -9,7 +9,7 @@ from smolagents.agents import CodeAgent, ToolCallingAgent
 available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
 chosen_inference = "transformers"
-print(f"Chose model {chosen_inference}")
+print(f"Chose model: '{chosen_inference}'")
 if chosen_inference == "hf_api":
     model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")


@@ -798,7 +798,7 @@ class ToolCallingAgent(MultiStepAgent):
             tool_arguments = tool_call.function.arguments
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
+            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
         log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
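The only functional change here is the `from e` suffix: it chains the underlying exception as `__cause__`, so the original model failure stays visible in the traceback instead of being hidden by the re-raise. A minimal standalone sketch of the effect (not smolagents code; the error message is illustrative):

    class AgentGenerationError(Exception):
        pass

    try:
        try:
            raise KeyError("model output had no tool call")  # the underlying failure
        except Exception as e:
            # "from e" records the KeyError as __cause__ of the wrapper error
            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}") from e
    except AgentGenerationError as err:
        assert isinstance(err.__cause__, KeyError)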


@@ -418,9 +418,6 @@ class TransformersModel(Model):
             The torch_dtype to initialize your model with.
         trust_remote_code (bool, default `False`):
             Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
-        flatten_messages_as_text (`bool`, default `True`):
-            Whether to flatten messages as text: this must be set to False to use VLMs (as opposed to LLMs, for which this flag can be ignored).
-            Caution: this parameter is experimental and will be removed in an upcoming PR as we auto-detect VLMs.
         kwargs (dict, *optional*):
             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
         **kwargs:
@@ -449,7 +446,6 @@ class TransformersModel(Model):
         device_map: Optional[str] = None,
         torch_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
-        flatten_messages_as_text: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -469,6 +465,7 @@
         if device_map is None:
             device_map = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device_map}")
+        self._is_vlm = False
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
@@ -481,6 +478,7 @@
             if "Unrecognized configuration class" in str(e):
                 self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
                 self.processor = AutoProcessor.from_pretrained(model_id)
+                self._is_vlm = True
             else:
                 raise e
         except Exception as e:
@@ -490,7 +488,6 @@
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
-        self.flatten_messages_as_text = flatten_messages_as_text
 
     def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
         from transformers import StoppingCriteria, StoppingCriteriaList
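Taken together, the constructor hunks above implement the auto-detection: load the checkpoint as a causal LM first, and if transformers rejects the configuration class, fall back to the image-text-to-text class and remember the outcome in `_is_vlm`. A condensed standalone sketch of the same pattern (function name and simplified return are illustrative, not the library's API):

    from transformers import (
        AutoModelForCausalLM,
        AutoModelForImageTextToText,
        AutoProcessor,
        AutoTokenizer,
    )

    def load_model_auto(model_id: str, device_map: str = "cpu"):
        is_vlm = False
        try:
            model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
            tokenizer, processor = AutoTokenizer.from_pretrained(model_id), None
        except ValueError as e:
            # transformers raises "Unrecognized configuration class ..." when the
            # checkpoint is not a plain causal LM, e.g. for vision-language models.
            if "Unrecognized configuration class" not in str(e):
                raise
            model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
            tokenizer, processor = None, AutoProcessor.from_pretrained(model_id)
            is_vlm = True
        return model, tokenizer, processor, is_vlm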
@@ -526,8 +523,7 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
-            flatten_messages_as_text=self.flatten_messages_as_text,
+            flatten_messages_as_text=(not self._is_vlm),
             **kwargs,
         )
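With the flag gone, flattening is derived from the detected model type: a text-only tokenizer's chat template expects plain string content, while a VLM processor needs the structured list form so image entries survive. Roughly what flattening means in practice (`flatten_messages` is a hypothetical helper for illustration, not the library's function):

    def flatten_messages(messages: list) -> list:
        flat = []
        for message in messages:
            content = message["content"]
            if isinstance(content, list):
                # Keep only text parts; a VLM needs the image entries left intact.
                content = "\n".join(part["text"] for part in content if part.get("type") == "text")
            flat.append({"role": message["role"], "content": content})
        return flat

    messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    assert flatten_messages(messages) == [{"role": "user", "content": "Hello!"}]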
@@ -595,9 +591,19 @@
         else:
             if "Action:" in output:
                 output = output.split("Action:", 1)[1].strip()
-            parsed_output = json.loads(output)
-            tool_name = parsed_output.get("tool_name")
-            tool_arguments = parsed_output.get("tool_arguments")
+            try:
+                start_index = output.index("{")
+                end_index = output.rindex("}")
+                output = output[start_index : end_index + 1]
+            except Exception as e:
+                raise Exception("No json blob found in output!") from e
+            try:
+                parsed_output = json.loads(output)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
+            tool_name = parsed_output.get("name")
+            tool_arguments = parsed_output.get("arguments")
         return ChatMessage(
             role="assistant",
             content="",
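Besides renaming the expected keys from `tool_name`/`tool_arguments` to `name`/`arguments`, the new code tolerates chatty models: it slices from the first `{` to the last `}` before calling `json.loads`, so prose wrapped around the tool call no longer breaks parsing. A standalone sketch of that extraction (helper name is illustrative):

    import json

    def extract_tool_call(output: str):
        # Slice from the first "{" to the last "}" so surrounding prose is ignored.
        try:
            blob = output[output.index("{") : output.rindex("}") + 1]
        except ValueError as e:
            raise Exception("No json blob found in output!") from e
        try:
            parsed = json.loads(blob)
        except json.JSONDecodeError as e:
            raise ValueError(f"Tool call '{blob}' has an invalid JSON structure: {e}")
        return parsed.get("name"), parsed.get("arguments")

    text = 'Sure! Action: {"name": "get_weather", "arguments": {"location": "Paris"}}'
    assert extract_tool_call(text) == ("get_weather", {"location": "Paris"})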


@@ -28,11 +28,7 @@ from smolagents.agents import (
     ToolCallingAgent,
 )
 from smolagents.default_tools import PythonInterpreterTool
-from smolagents.models import (
-    ChatMessage,
-    ChatMessageToolCall,
-    ChatMessageToolCallDefinition,
-)
+from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
 from smolagents.utils import BASE_BUILTIN_MODULES
@@ -620,3 +616,26 @@ nested_answer()
         output = agent.run("Count to 3")
         assert output == "Correct!"
+
+    def test_transformers_toolcalling_agent(self):
+        @tool
+        def get_weather(location: str, celsius: bool = False) -> str:
+            """
+            Get weather in the next days at given location.
+            Secretly this tool does not care about the location, it hates the weather everywhere.
+
+            Args:
+                location: the location
+                celsius: the temperature type
+            """
+            return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
+            max_new_tokens=100,
+            device_map="auto",
+            do_sample=False,
+        )
+        agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
+        agent.run("What's the weather in Paris?")
+        assert agent.logs[2].tool_calls[0].name == "get_weather"
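The test exercises the new path end to end: SmolLM2-360M-Instruct is prompted for a tool call, the name/arguments parsing above turns its output into a ToolCall, and the agent dispatches it to get_weather. To run just this test locally (the tests/test_agents.py path is an assumption; the extraction dropped the filenames):

    python -m pytest tests/test_agents.py -k test_transformers_toolcalling_agent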


@@ -57,7 +57,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=True,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         output = model(messages, stop_sequences=["great"]).content

@@ -72,7 +71,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=False,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}, {"type": "image", "image": img}]}]
         output = model(messages, stop_sequences=["great"]).content
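Net effect for users: `flatten_messages_as_text` disappears from both test setups because the constructor now infers it. A hedged usage sketch (model ID is an example; weights download on first use):

    from smolagents import TransformersModel  # top-level export; adjust if your version differs

    model = TransformersModel(
        model_id="HuggingFaceTB/SmolLM2-360M-Instruct",  # example text-only checkpoint
        max_new_tokens=5,
        device_map="auto",
        do_sample=False,
    )
    messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    print(model(messages, stop_sequences=["great"]).content)
    # With a VLM checkpoint instead, the same call works unchanged: add
    # {"type": "image", "image": <PIL.Image>} entries to the content list and the
    # auto-detection keeps them structured rather than flattening to text.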