Fix for MLX Model when stop sequence is followed by any chars (#623)
parent 2347631a55
commit cb2218a86f
@@ -525,7 +525,7 @@ class MLXModel(Model):
     def _to_message(self, text, tools_to_call_from):
         if tools_to_call_from:
-            # tmp solution for extracting tool JSON without assuming a specific model output format
+            # solution for extracting tool JSON without assuming a specific model output format
             maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
             parsed_text = json.loads(maybe_json)
             tool_name = parsed_text.get(self.tool_name_key, None)
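For reference, the one-liner above keeps everything from the first "{" to the last "}" of the raw model output. A minimal standalone sketch of that behavior (the sample output string is invented for illustration):

import json

# Invented example of a model reply that wraps a tool call in extra prose.
text = 'Calling a tool now: {"name": "get_weather", "arguments": {"city": "Paris"}} Done.'

# Same expression as in _to_message: slice from the first "{" to the last "}".
maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"

parsed_text = json.loads(maybe_json)
assert parsed_text["name"] == "get_weather"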
@@ -579,8 +579,9 @@ class MLXModel(Model):
             self.last_output_token_count += 1
             text += _.text
             for stop_sequence in prepared_stop_sequences:
-                if text.strip().endswith(stop_sequence):
-                    text = text[: -len(stop_sequence)]
+                stop_sequence_start = text.rfind(stop_sequence)
+                if stop_sequence_start != -1:
+                    text = text[:stop_sequence_start]
                     return self._to_message(text, tools_to_call_from)
 
         return self._to_message(text, tools_to_call_from)
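The behavioral difference is easiest to see on a concrete string. A minimal sketch (the text value is invented, modeled on the new test below): the old endswith check misses a stop sequence that a multi-character token has already pushed past the end of the text, while rfind still locates it and truncation happens from its start.

stop_sequence = " print '>"

# A single decoded token such as ">'" can append the stop sequence plus a
# trailing character, so the accumulated text never ends with it exactly.
text = "I'm ready to help you print '>'"

# Old check: fails to trigger because of the trailing "'".
assert not text.strip().endswith(stop_sequence)

# New check: rfind finds the stop sequence anywhere in the text, and the
# output is cut off at its start.
stop_sequence_start = text.rfind(stop_sequence)
assert stop_sequence_start != -1
assert text[:stop_sequence_start] == "I'm ready to help you"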
@@ -69,6 +69,19 @@ class ModelTests(unittest.TestCase):
         output = model(messages, stop_sequences=["great"]).content
         assert output.startswith("Hello")
 
+    @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
+    def test_get_mlx_message_tricky_stop_sequence(self):
+        # In this test HuggingFaceTB/SmolLM2-135M-Instruct generates the token ">'"
+        # which is required to test capturing stop_sequences that have extra chars at the end.
+        model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=100)
+        stop_sequence = " print '>"
+        messages = [{"role": "user", "content": [{"type": "text", "text": f"Please{stop_sequence}'"}]}]
+        # check our assumption that ">" is followed by "'"
+        assert model.tokenizer.vocab[">'"]
+        assert model(messages, stop_sequences=[]).content == f"I'm ready to help you{stop_sequence}'"
+        # check stop_sequence capture when output has trailing chars
+        assert model(messages, stop_sequences=[stop_sequence]).content == "I'm ready to help you"
+
     def test_transformers_message_no_tool(self):
         model = TransformersModel(
             model_id="HuggingFaceTB/SmolLM2-135M-Instruct",