MLX model support (#300)

This commit is contained in:
Joe 2025-02-12 09:30:25 -08:00 committed by GitHub
parent bca3a9bc13
commit 9b96199d00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 171 additions and 3 deletions

View File

@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
- [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub. - [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
- [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)! - [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
- [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service). - [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
- [`MLXModel`] creates a [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.
- `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`. - `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), or [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service). Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).
<hfoptions id="Pick a LLM"> <hfoptions id="Pick a LLM">
<hfoption id="HF Inference API"> <hfoption id="HF Inference API">
@ -148,6 +149,19 @@ agent.run(
) )
``` ```
</hfoption>
<hfoption id="mlx-lm">
```python
# !pip install smolagents[mlx-lm]
from smolagents import CodeAgent, MLXModel
mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)
agent.run("Could you give me the 118th number in the Fibonacci sequence?")
```
</hfoption> </hfoption>
</hfoptions> </hfoptions>

View File

@ -147,4 +147,23 @@ model = AzureOpenAIServerModel(
) )
``` ```
[[autodoc]] AzureOpenAIServerModel [[autodoc]] AzureOpenAIServerModel
### MLXModel
```python
from smolagents import MLXModel
model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
```
```text
>>> What a
```
> [!TIP]
> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.
[[autodoc]] MLXModel

View File

@ -46,6 +46,9 @@ mcp = [
"mcpadapt>=0.0.6", "mcpadapt>=0.0.6",
"mcp", "mcp",
] ]
mlx-lm = [
"mlx-lm"
]
openai = [ openai = [
"openai>=1.58.1" "openai>=1.58.1"
] ]

View File

@ -18,6 +18,7 @@ import json
import logging import logging
import os import os
import random import random
import uuid
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
@ -415,6 +416,128 @@ class HfApiModel(Model):
return message return message
class MLXModel(Model):
"""A class to interact with models loaded using MLX on Apple silicon.
> [!TIP]
> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.
Parameters:
model_id (str):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
tool_name_key (str):
The key, which can usually be found in the model's chat template, for retrieving a tool name.
tool_arguments_key (str):
The key, which can usually be found in the model's chat template, for retrieving tool arguments.
trust_remote_code (bool):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
kwargs (dict, *optional*):
Any additional keyword arguments that you want to use in model.generate(), for instance `max_tokens`.
Example:
```python
>>> engine = MLXModel(
... model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
... max_tokens=10000,
... )
>>> messages = [
... {
... "role": "user",
... "content": [
... {"type": "text", "text": "Explain quantum mechanics in simple terms."}
... ]
... }
... ]
>>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
"""
def __init__(
self,
model_id: str,
tool_name_key: str = "name",
tool_arguments_key: str = "arguments",
trust_remote_code: bool = False,
**kwargs,
):
super().__init__(**kwargs)
if not _is_package_available("mlx_lm"):
raise ModuleNotFoundError(
"Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
)
import mlx_lm
self.model_id = model_id
self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
self.stream_generate = mlx_lm.stream_generate
self.tool_name_key = tool_name_key
self.tool_arguments_key = tool_arguments_key
def _to_message(self, text, tools_to_call_from):
if tools_to_call_from:
# tmp solution for extracting tool JSON without assuming a specific model output format
maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
parsed_text = json.loads(maybe_json)
tool_name = parsed_text.get(self.tool_name_key, None)
tool_arguments = parsed_text.get(self.tool_arguments_key, None)
if tool_name:
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id=uuid.uuid4(),
type="function",
function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
)
],
)
return ChatMessage(role="assistant", content=text)
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
**kwargs,
) -> ChatMessage:
completion_kwargs = self._prepare_completion_kwargs(
flatten_messages_as_text=True, # mlx-lm doesn't support vision models
messages=messages,
stop_sequences=stop_sequences,
grammar=grammar,
tools_to_call_from=tools_to_call_from,
**kwargs,
)
messages = completion_kwargs.pop("messages")
prepared_stop_sequences = completion_kwargs.pop("stop", [])
tools = completion_kwargs.pop("tools", None)
completion_kwargs.pop("tool_choice", None)
prompt_ids = self.tokenizer.apply_chat_template(
messages,
tools=tools,
add_generation_prompt=True,
)
self.last_input_token_count = len(prompt_ids)
self.last_output_token_count = 0
text = ""
for _ in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
self.last_output_token_count += 1
text += _.text
for stop_sequence in prepared_stop_sequences:
if text.strip().endswith(stop_sequence):
text = text[: -len(stop_sequence)]
return self._to_message(text, tools_to_call_from)
return self._to_message(text, tools_to_call_from)
class TransformersModel(Model): class TransformersModel(Model):
"""A class that uses Hugging Face's Transformers library for language model interaction. """A class that uses Hugging Face's Transformers library for language model interaction.
@ -837,6 +960,7 @@ __all__ = [
"tool_role_conversions", "tool_role_conversions",
"get_clean_message_list", "get_clean_message_list",
"Model", "Model",
"MLXModel",
"TransformersModel", "TransformersModel",
"HfApiModel", "HfApiModel",
"LiteLLMModel", "LiteLLMModel",

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import json import json
import os import os
import sys
import unittest import unittest
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -22,7 +23,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
@ -61,6 +62,13 @@ class ModelTests(unittest.TestCase):
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}] messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
model(messages, stop_sequences=["great"]) model(messages, stop_sequences=["great"])
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
def test_get_mlx_message_no_tool(self):
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
output = model(messages, stop_sequences=["great"]).content
assert output.startswith("Hello")
def test_transformers_message_no_tool(self): def test_transformers_message_no_tool(self):
model = TransformersModel( model = TransformersModel(
model_id="HuggingFaceTB/SmolLM2-135M-Instruct", model_id="HuggingFaceTB/SmolLM2-135M-Instruct",