TransformersModel auto-detects VLMs (#378)

* TransformersModel auto-detects VLMs

parent a5290590c8 · commit 4579a6f7cc
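What the change does: the constructor now tries AutoModelForCausalLM first and, when the checkpoint's config is not a causal-LM config, falls back to AutoModelForImageTextToText, recording the outcome in a private _is_vlm flag. Message flattening is then derived from that flag, so the experimental flatten_messages_as_text constructor argument is removed. A minimal usage sketch; the VLM checkpoint ID below is illustrative, not taken from this diff:

    from smolagents import TransformersModel

    # Text-only checkpoint: loads via AutoModelForCausalLM, and chat messages
    # are flattened to plain strings before templating.
    llm = TransformersModel(model_id="HuggingFaceTB/SmolLM2-360M-Instruct", max_new_tokens=32)

    # Vision-language checkpoint (illustrative ID): AutoModelForCausalLM raises
    # "Unrecognized configuration class", the fallback loads the checkpoint with
    # AutoModelForImageTextToText, and image segments in messages are preserved.
    vlm = TransformersModel(model_id="HuggingFaceM4/idefics2-8b", max_new_tokens=32)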
@@ -9,7 +9,7 @@ from smolagents.agents import CodeAgent, ToolCallingAgent
 available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
 chosen_inference = "transformers"

-print(f"Chose model {chosen_inference}")
+print(f"Chose model: '{chosen_inference}'")

 if chosen_inference == "hf_api":
     model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
@@ -798,7 +798,7 @@ class ToolCallingAgent(MultiStepAgent):
             tool_arguments = tool_call.function.arguments

         except Exception as e:
-            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
+            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e

         log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]

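The only functional change in this hunk is explicit exception chaining. A self-contained illustration of what `from e` buys (toy errors, not smolagents code):

    try:
        try:
            {}["missing"]  # stand-in for a failing model call
        except Exception as e:
            raise RuntimeError("Error in generating tool call with model") from e
    except RuntimeError as err:
        # Explicit chaining stores the root failure on __cause__, so the traceback
        # reads "The above exception was the direct cause of the following exception"
        # instead of the vaguer implicit-context message.
        assert isinstance(err.__cause__, KeyError)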
@@ -418,9 +418,6 @@ class TransformersModel(Model):
             The torch_dtype to initialize your model with.
         trust_remote_code (bool, default `False`):
             Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
-        flatten_messages_as_text (`bool`, default `True`):
-            Whether to flatten messages as text: this must be sent to False to use VLMs (as opposed to LLMs for which this flag can be ignored).
-            Caution: this parameter is experimental and will be removed in an upcoming PR as we auto-detect VLMs.
         kwargs (dict, *optional*):
             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
         **kwargs:
@@ -449,7 +446,6 @@ class TransformersModel(Model):
         device_map: Optional[str] = None,
         torch_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
-        flatten_messages_as_text: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -469,6 +465,7 @@ class TransformersModel(Model):
         if device_map is None:
             device_map = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device_map}")
+        self._is_vlm = False
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
@@ -481,6 +478,7 @@ class TransformersModel(Model):
             if "Unrecognized configuration class" in str(e):
                 self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
                 self.processor = AutoProcessor.from_pretrained(model_id)
+                self._is_vlm = True
             else:
                 raise e
         except Exception as e:
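These two one-line insertions are the whole auto-detection mechanism. Pulled out as a standalone sketch; the ValueError type and the "Unrecognized configuration class" message are what transformers raises for a mismatched auto class, but treat both as assumptions rather than a guaranteed contract:

    from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

    def load_model(model_id: str, device_map: str = "cpu"):
        """Return (model, is_vlm), mirroring the fallback logic in the diff."""
        try:
            return AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map), False
        except ValueError as e:
            # AutoModelForCausalLM rejects image-text-to-text configs with this message.
            if "Unrecognized configuration class" in str(e):
                return AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map), True
            raise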
@@ -490,7 +488,6 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
-        self.flatten_messages_as_text = flatten_messages_as_text

     def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
         from transformers import StoppingCriteria, StoppingCriteriaList
@@ -526,8 +523,7 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
-            flatten_messages_as_text=self.flatten_messages_as_text,
+            flatten_messages_as_text=(not self._is_vlm),
             **kwargs,
         )

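With the constructor argument gone, flattening is now decided per call: text-only models get flattened messages, VLMs keep structured content. "Flattening" here means collapsing each message's list of typed segments into one plain string, roughly like this simplified sketch (not smolagents' actual helper):

    def flatten_as_text(messages: list) -> list:
        # Keep only the text segments and join them into a single string per message.
        return [
            {
                "role": message["role"],
                "content": "".join(
                    segment["text"] for segment in message["content"] if segment["type"] == "text"
                ),
            }
            for message in messages
        ]

    messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    assert flatten_as_text(messages) == [{"role": "user", "content": "Hello!"}]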
@@ -595,9 +591,19 @@ class TransformersModel(Model):
         else:
             if "Action:" in output:
                 output = output.split("Action:", 1)[1].strip()
-            parsed_output = json.loads(output)
-            tool_name = parsed_output.get("tool_name")
-            tool_arguments = parsed_output.get("tool_arguments")
+            try:
+                start_index = output.index("{")
+                end_index = output.rindex("}")
+                output = output[start_index : end_index + 1]
+            except Exception as e:
+                raise Exception("No json blob found in output!") from e
+
+            try:
+                parsed_output = json.loads(output)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
+            tool_name = parsed_output.get("name")
+            tool_arguments = parsed_output.get("arguments")
             return ChatMessage(
                 role="assistant",
                 content="",
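The parsing rewrite does two things: it trims the raw model output down to the outermost {...} blob before calling json.loads, and it reads the "name"/"arguments" keys rather than "tool_name"/"tool_arguments". A worked example on a made-up model output:

    import json

    output = 'Action:\nI will call the tool. {"name": "get_weather", "arguments": {"location": "Paris"}} Done.'
    output = output.split("Action:", 1)[1].strip()
    # index/rindex bound the outermost JSON object, tolerating chatter around it
    # and nested braces inside it.
    blob = output[output.index("{") : output.rindex("}") + 1]
    parsed = json.loads(blob)
    assert parsed["name"] == "get_weather"
    assert parsed["arguments"]["location"] == "Paris"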
@@ -28,11 +28,7 @@ from smolagents.agents import (
     ToolCallingAgent,
 )
 from smolagents.default_tools import PythonInterpreterTool
-from smolagents.models import (
-    ChatMessage,
-    ChatMessageToolCall,
-    ChatMessageToolCallDefinition,
-)
+from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
 from smolagents.utils import BASE_BUILTIN_MODULES
@@ -620,3 +616,26 @@ nested_answer()
         output = agent.run("Count to 3")
         assert output == "Correct!"

+    def test_transformers_toolcalling_agent(self):
+        @tool
+        def get_weather(location: str, celsius: bool = False) -> str:
+            """
+            Get weather in the next days at given location.
+            Secretly this tool does not care about the location, it hates the weather everywhere.
+
+            Args:
+                location: the location
+                celsius: the temperature type
+            """
+            return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
+            max_new_tokens=100,
+            device_map="auto",
+            do_sample=False,
+        )
+        agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
+        agent.run("What's the weather in Paris?")
+        assert agent.logs[2].tool_calls[0].name == "get_weather"
@@ -57,7 +57,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=True,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         output = model(messages, stop_sequences=["great"]).content
@@ -72,7 +71,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=False,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}, {"type": "image", "image": img}]}]
         output = model(messages, stop_sequences=["great"]).content