TransformersModel auto-detects VLMs (#378)
* TransformersModel auto-detects VLMs
This commit is contained in:
parent
a5290590c8
commit
4579a6f7cc
|
@ -9,7 +9,7 @@ from smolagents.agents import CodeAgent, ToolCallingAgent
|
|||
available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
|
||||
chosen_inference = "transformers"
|
||||
|
||||
print(f"Chose model {chosen_inference}")
|
||||
print(f"Chose model: '{chosen_inference}'")
|
||||
|
||||
if chosen_inference == "hf_api":
|
||||
model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
|
||||
|
|
|
@ -798,7 +798,7 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
tool_arguments = tool_call.function.arguments
|
||||
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
|
||||
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
|
||||
|
||||
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
|
||||
|
||||
|
|
|
@ -418,9 +418,6 @@ class TransformersModel(Model):
|
|||
The torch_dtype to initialize your model with.
|
||||
trust_remote_code (bool, default `False`):
|
||||
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
|
||||
flatten_messages_as_text (`bool`, default `True`):
|
||||
Whether to flatten messages as text: this must be sent to False to use VLMs (as opposed to LLMs for which this flag can be ignored).
|
||||
Caution: this parameter is experimental and will be removed in an upcoming PR as we auto-detect VLMs.
|
||||
kwargs (dict, *optional*):
|
||||
Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
|
||||
**kwargs:
|
||||
|
@ -449,7 +446,6 @@ class TransformersModel(Model):
|
|||
device_map: Optional[str] = None,
|
||||
torch_dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
flatten_messages_as_text: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
@ -469,6 +465,7 @@ class TransformersModel(Model):
|
|||
if device_map is None:
|
||||
device_map = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {device_map}")
|
||||
self._is_vlm = False
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
|
@ -481,6 +478,7 @@ class TransformersModel(Model):
|
|||
if "Unrecognized configuration class" in str(e):
|
||||
self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
self._is_vlm = True
|
||||
else:
|
||||
raise e
|
||||
except Exception as e:
|
||||
|
@ -490,7 +488,6 @@ class TransformersModel(Model):
|
|||
self.model_id = default_model_id
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
|
||||
self.flatten_messages_as_text = flatten_messages_as_text
|
||||
|
||||
def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
@ -526,8 +523,7 @@ class TransformersModel(Model):
|
|||
messages=messages,
|
||||
stop_sequences=stop_sequences,
|
||||
grammar=grammar,
|
||||
tools_to_call_from=tools_to_call_from,
|
||||
flatten_messages_as_text=self.flatten_messages_as_text,
|
||||
flatten_messages_as_text=(not self._is_vlm),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -595,9 +591,19 @@ class TransformersModel(Model):
|
|||
else:
|
||||
if "Action:" in output:
|
||||
output = output.split("Action:", 1)[1].strip()
|
||||
parsed_output = json.loads(output)
|
||||
tool_name = parsed_output.get("tool_name")
|
||||
tool_arguments = parsed_output.get("tool_arguments")
|
||||
try:
|
||||
start_index = output.index("{")
|
||||
end_index = output.rindex("}")
|
||||
output = output[start_index : end_index + 1]
|
||||
except Exception as e:
|
||||
raise Exception("No json blob found in output!") from e
|
||||
|
||||
try:
|
||||
parsed_output = json.loads(output)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
|
||||
tool_name = parsed_output.get("name")
|
||||
tool_arguments = parsed_output.get("arguments")
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
|
|
|
@ -28,11 +28,7 @@ from smolagents.agents import (
|
|||
ToolCallingAgent,
|
||||
)
|
||||
from smolagents.default_tools import PythonInterpreterTool
|
||||
from smolagents.models import (
|
||||
ChatMessage,
|
||||
ChatMessageToolCall,
|
||||
ChatMessageToolCallDefinition,
|
||||
)
|
||||
from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
|
||||
from smolagents.tools import tool
|
||||
from smolagents.types import AgentImage, AgentText
|
||||
from smolagents.utils import BASE_BUILTIN_MODULES
|
||||
|
@ -620,3 +616,26 @@ nested_answer()
|
|||
|
||||
output = agent.run("Count to 3")
|
||||
assert output == "Correct!"
|
||||
|
||||
def test_transformers_toolcalling_agent(self):
|
||||
@tool
|
||||
def get_weather(location: str, celsius: bool = False) -> str:
|
||||
"""
|
||||
Get weather in the next days at given location.
|
||||
Secretly this tool does not care about the location, it hates the weather everywhere.
|
||||
|
||||
Args:
|
||||
location: the location
|
||||
celsius: the temperature type
|
||||
"""
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
model = TransformersModel(
|
||||
model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
|
||||
max_new_tokens=100,
|
||||
device_map="auto",
|
||||
do_sample=False,
|
||||
)
|
||||
agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
|
||||
agent.run("What's the weather in Paris?")
|
||||
assert agent.logs[2].tool_calls[0].name == "get_weather"
|
||||
|
|
|
@ -57,7 +57,6 @@ class ModelTests(unittest.TestCase):
|
|||
max_new_tokens=5,
|
||||
device_map="auto",
|
||||
do_sample=False,
|
||||
flatten_messages_as_text=True,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
||||
output = model(messages, stop_sequences=["great"]).content
|
||||
|
@ -72,7 +71,6 @@ class ModelTests(unittest.TestCase):
|
|||
max_new_tokens=5,
|
||||
device_map="auto",
|
||||
do_sample=False,
|
||||
flatten_messages_as_text=False,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}, {"type": "image", "image": img}]}]
|
||||
output = model(messages, stop_sequences=["great"]).content
|
||||
|
|
Loading…
Reference in New Issue