Fix for MLX Model when stop sequence is followed by any chars (#623)

This commit is contained in:
Joe 2025-02-14 03:17:56 -08:00 committed by GitHub
parent 2347631a55
commit cb2218a86f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 3 deletions

View File

@ -525,7 +525,7 @@ class MLXModel(Model):
def _to_message(self, text, tools_to_call_from): def _to_message(self, text, tools_to_call_from):
if tools_to_call_from: if tools_to_call_from:
# tmp solution for extracting tool JSON without assuming a specific model output format # solution for extracting tool JSON without assuming a specific model output format
maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}" maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
parsed_text = json.loads(maybe_json) parsed_text = json.loads(maybe_json)
tool_name = parsed_text.get(self.tool_name_key, None) tool_name = parsed_text.get(self.tool_name_key, None)
@ -579,8 +579,9 @@ class MLXModel(Model):
self.last_output_token_count += 1 self.last_output_token_count += 1
text += _.text text += _.text
for stop_sequence in prepared_stop_sequences: for stop_sequence in prepared_stop_sequences:
if text.strip().endswith(stop_sequence): stop_sequence_start = text.rfind(stop_sequence)
text = text[: -len(stop_sequence)] if stop_sequence_start != -1:
text = text[:stop_sequence_start]
return self._to_message(text, tools_to_call_from) return self._to_message(text, tools_to_call_from)
return self._to_message(text, tools_to_call_from) return self._to_message(text, tools_to_call_from)

View File

@ -69,6 +69,19 @@ class ModelTests(unittest.TestCase):
output = model(messages, stop_sequences=["great"]).content output = model(messages, stop_sequences=["great"]).content
assert output.startswith("Hello") assert output.startswith("Hello")
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
def test_get_mlx_message_tricky_stop_sequence(self):
# In this test HuggingFaceTB/SmolLM2-135M-Instruct generates the token ">'"
# which is required to test capturing stop_sequences that have extra chars at the end.
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=100)
stop_sequence = " print '>"
messages = [{"role": "user", "content": [{"type": "text", "text": f"Please{stop_sequence}'"}]}]
# check our assumption that that ">" is followed by "'"
assert model.tokenizer.vocab[">'"]
assert model(messages, stop_sequences=[]).content == f"I'm ready to help you{stop_sequence}'"
# check stop_sequence capture when output has trailing chars
assert model(messages, stop_sequences=[stop_sequence]).content == "I'm ready to help you"
def test_transformers_message_no_tool(self): def test_transformers_message_no_tool(self):
model = TransformersModel( model = TransformersModel(
model_id="HuggingFaceTB/SmolLM2-135M-Instruct", model_id="HuggingFaceTB/SmolLM2-135M-Instruct",