MLX model support (#300)

This commit is contained in:
Joe 2025-02-12 09:30:25 -08:00 committed by GitHub
parent bca3a9bc13
commit 9b96199d00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 171 additions and 3 deletions

View File

@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
- [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
- [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
- [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
- [`MLXModel`] creates a [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.
- `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), or [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).
<hfoptions id="Pick a LLM">
<hfoption id="HF Inference API">
@ -148,6 +149,19 @@ agent.run(
)
```
</hfoption>
<hfoption id="mlx-lm">
```python
# !pip install smolagents[mlx-lm]
from smolagents import CodeAgent, MLXModel
mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)
agent.run("Could you give me the 118th number in the Fibonacci sequence?")
```
</hfoption>
</hfoptions>

View File

@ -147,4 +147,23 @@ model = AzureOpenAIServerModel(
)
```
[[autodoc]] AzureOpenAIServerModel
[[autodoc]] AzureOpenAIServerModel
### MLXModel
```python
from smolagents import MLXModel
model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
```
```text
>>> What a
```
> [!TIP]
> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.
[[autodoc]] MLXModel

View File

@ -46,6 +46,9 @@ mcp = [
"mcpadapt>=0.0.6",
"mcp",
]
mlx-lm = [
"mlx-lm"
]
openai = [
"openai>=1.58.1"
]

View File

@ -18,6 +18,7 @@ import json
import logging
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
@ -415,6 +416,128 @@ class HfApiModel(Model):
return message
class MLXModel(Model):
"""A class to interact with models loaded using MLX on Apple silicon.
> [!TIP]
> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.
Parameters:
model_id (str):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
tool_name_key (str):
The key, which can usually be found in the model's chat template, for retrieving a tool name.
tool_arguments_key (str):
The key, which can usually be found in the model's chat template, for retrieving tool arguments.
trust_remote_code (bool):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
kwargs (dict, *optional*):
Any additional keyword arguments that you want to use in model.generate(), for instance `max_tokens`.
Example:
```python
>>> engine = MLXModel(
... model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
... max_tokens=10000,
... )
>>> messages = [
... {
... "role": "user",
... "content": [
... {"type": "text", "text": "Explain quantum mechanics in simple terms."}
... ]
... }
... ]
>>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
"""
def __init__(
self,
model_id: str,
tool_name_key: str = "name",
tool_arguments_key: str = "arguments",
trust_remote_code: bool = False,
**kwargs,
):
super().__init__(**kwargs)
if not _is_package_available("mlx_lm"):
raise ModuleNotFoundError(
"Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
)
import mlx_lm
self.model_id = model_id
self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
self.stream_generate = mlx_lm.stream_generate
self.tool_name_key = tool_name_key
self.tool_arguments_key = tool_arguments_key
def _to_message(self, text, tools_to_call_from):
if tools_to_call_from:
# tmp solution for extracting tool JSON without assuming a specific model output format
maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
parsed_text = json.loads(maybe_json)
tool_name = parsed_text.get(self.tool_name_key, None)
tool_arguments = parsed_text.get(self.tool_arguments_key, None)
if tool_name:
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id=uuid.uuid4(),
type="function",
function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
)
],
)
return ChatMessage(role="assistant", content=text)
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
**kwargs,
) -> ChatMessage:
completion_kwargs = self._prepare_completion_kwargs(
flatten_messages_as_text=True, # mlx-lm doesn't support vision models
messages=messages,
stop_sequences=stop_sequences,
grammar=grammar,
tools_to_call_from=tools_to_call_from,
**kwargs,
)
messages = completion_kwargs.pop("messages")
prepared_stop_sequences = completion_kwargs.pop("stop", [])
tools = completion_kwargs.pop("tools", None)
completion_kwargs.pop("tool_choice", None)
prompt_ids = self.tokenizer.apply_chat_template(
messages,
tools=tools,
add_generation_prompt=True,
)
self.last_input_token_count = len(prompt_ids)
self.last_output_token_count = 0
text = ""
for _ in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
self.last_output_token_count += 1
text += _.text
for stop_sequence in prepared_stop_sequences:
if text.strip().endswith(stop_sequence):
text = text[: -len(stop_sequence)]
return self._to_message(text, tools_to_call_from)
return self._to_message(text, tools_to_call_from)
class TransformersModel(Model):
"""A class that uses Hugging Face's Transformers library for language model interaction.
@ -837,6 +960,7 @@ __all__ = [
"tool_role_conversions",
"get_clean_message_list",
"Model",
"MLXModel",
"TransformersModel",
"HfApiModel",
"LiteLLMModel",

View File

@ -14,6 +14,7 @@
# limitations under the License.
import json
import os
import sys
import unittest
from pathlib import Path
from typing import Optional
@ -22,7 +23,7 @@ from unittest.mock import MagicMock, patch
import pytest
from transformers.testing_utils import get_tests_dir
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
@ -61,6 +62,13 @@ class ModelTests(unittest.TestCase):
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
model(messages, stop_sequences=["great"])
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
def test_get_mlx_message_no_tool(self):
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
output = model(messages, stop_sequences=["great"]).content
assert output.startswith("Hello")
def test_transformers_message_no_tool(self):
model = TransformersModel(
model_id="HuggingFaceTB/SmolLM2-135M-Instruct",