MLX model support (#300)
Commit 9b96199d00 (parent bca3a9bc13)
@@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
 - [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
 - [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
 - [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+- [`MLXModel`] creates a [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.
 
 - `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
 
-Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), or [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).
 
 <hfoptions id="Pick a LLM">
 <hfoption id="HF Inference API">
@@ -148,6 +149,19 @@ agent.run(
 )
 ```
 
 </hfoption>
+<hfoption id="mlx-lm">
+
+```python
+# !pip install smolagents[mlx-lm]
+from smolagents import CodeAgent, MLXModel
+
+mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
+agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)
+
+agent.run("Could you give me the 118th number in the Fibonacci sequence?")
+```
+
+</hfoption>
 </hfoptions>
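The new tab follows the same pattern as the other providers. Generation can also be tuned from this entry point: the `MLXModel` constructor (added later in this commit) forwards extra keyword arguments to mlx-lm generation. A minimal sketch, assuming an Apple-silicon machine with the `mlx-lm` extra installed:

```python
# Sketch only: max_tokens is forwarded by MLXModel to mlx-lm's generation
# call, as shown in the MLXModel docstring added later in this commit.
from smolagents import CodeAgent, MLXModel

model = MLXModel(
    "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
    max_tokens=10000,
)
agent = CodeAgent(model=model, tools=[], add_base_tools=True)
agent.run("Could you give me the 118th number in the Fibonacci sequence?")
```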
@@ -148,3 +148,22 @@ model = AzureOpenAIServerModel(
 ```
 
 [[autodoc]] AzureOpenAIServerModel
+
+### MLXModel
+
+
+```python
+from smolagents import MLXModel
+
+model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
+
+print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
+```
+```text
+>>> What a
+```
+
+> [!TIP]
+> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it is not already installed.
+
+[[autodoc]] MLXModel
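A note on the expected output shown above: the completion presumably continues "What a great…", and since `MLXModel` trims a matched stop sequence from the end of the generated text (see the implementation below), `stop_sequences=["great"]` leaves just `What a`.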
@@ -46,6 +46,9 @@ mcp = [
     "mcpadapt>=0.0.6",
     "mcp",
 ]
+mlx-lm = [
+    "mlx-lm"
+]
 openai = [
     "openai>=1.58.1"
 ]
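With this extra defined in `pyproject.toml`, `pip install 'smolagents[mlx-lm]'` pulls in the `mlx-lm` dependency; it is the same command that the docs tip and the `MLXModel` import error below point users to.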
@@ -18,6 +18,7 @@ import json
 import logging
 import os
 import random
+import uuid
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
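The new `uuid` import is used by the `MLXModel` class below to assign an `id` to each `ChatMessageToolCall`.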
@@ -415,6 +416,128 @@ class HfApiModel(Model):
         return message
 
 
+class MLXModel(Model):
+    """A class to interact with models loaded using MLX on Apple silicon.
+
+    > [!TIP]
+    > You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it is not already installed.
+
+    Parameters:
+        model_id (str):
+            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+        tool_name_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving a tool name.
+        tool_arguments_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving tool arguments.
+        trust_remote_code (bool):
+            Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to pass to `model.generate()`, for instance `max_tokens`.
+
+    Example:
+    ```python
+    >>> engine = MLXModel(
+    ...     model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
+    ...     max_tokens=10000,
+    ... )
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": [
+    ...             {"type": "text", "text": "Explain quantum mechanics in simple terms."}
+    ...         ]
+    ...     }
+    ... ]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        tool_name_key: str = "name",
+        tool_arguments_key: str = "arguments",
+        trust_remote_code: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not _is_package_available("mlx_lm"):
+            raise ModuleNotFoundError(
+                "Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
+            )
+        import mlx_lm
+
+        self.model_id = model_id
+        self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
+        self.stream_generate = mlx_lm.stream_generate
+        self.tool_name_key = tool_name_key
+        self.tool_arguments_key = tool_arguments_key
+
+    def _to_message(self, text, tools_to_call_from):
+        if tools_to_call_from:
+            # tmp solution for extracting tool JSON without assuming a specific model output format
+            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
+            parsed_text = json.loads(maybe_json)
+            tool_name = parsed_text.get(self.tool_name_key, None)
+            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
+            if tool_name:
+                return ChatMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatMessageToolCall(
+                            id=uuid.uuid4(),
+                            type="function",
+                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+                        )
+                    ],
+                )
+        return ChatMessage(role="assistant", content=text)
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
+    ) -> ChatMessage:
+        completion_kwargs = self._prepare_completion_kwargs(
+            flatten_messages_as_text=True,  # mlx-lm doesn't support vision models
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+        messages = completion_kwargs.pop("messages")
+        prepared_stop_sequences = completion_kwargs.pop("stop", [])
+        tools = completion_kwargs.pop("tools", None)
+        completion_kwargs.pop("tool_choice", None)
+
+        prompt_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tools=tools,
+            add_generation_prompt=True,
+        )
+
+        self.last_input_token_count = len(prompt_ids)
+        self.last_output_token_count = 0
+        text = ""
+
+        for _ in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
+            self.last_output_token_count += 1
+            text += _.text
+            for stop_sequence in prepared_stop_sequences:
+                if text.strip().endswith(stop_sequence):
+                    text = text[: -len(stop_sequence)]
+                    return self._to_message(text, tools_to_call_from)
+
+        return self._to_message(text, tools_to_call_from)
+
+
 class TransformersModel(Model):
     """A class that uses Hugging Face's Transformers library for language model interaction.
@@ -837,6 +960,7 @@ __all__ = [
     "tool_role_conversions",
     "get_clean_message_list",
     "Model",
+    "MLXModel",
     "TransformersModel",
     "HfApiModel",
     "LiteLLMModel",
@@ -14,6 +14,7 @@
 # limitations under the License.
 import json
 import os
+import sys
 import unittest
 from pathlib import Path
 from typing import Optional
@@ -22,7 +23,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 from transformers.testing_utils import get_tests_dir
 
-from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
+from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
 from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
 
 
@@ -61,6 +62,13 @@ class ModelTests(unittest.TestCase):
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         model(messages, stop_sequences=["great"])
 
+    @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
+    def test_get_mlx_message_no_tool(self):
+        model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output.startswith("Hello")
+
     def test_transformers_message_no_tool(self):
         model = TransformersModel(
             model_id="HuggingFaceTB/SmolLM2-135M-Instruct",