diff --git a/src/smolagents/models.py b/src/smolagents/models.py index b576846..11be825 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -525,7 +525,7 @@ class MLXModel(Model): def _to_message(self, text, tools_to_call_from): if tools_to_call_from: - # tmp solution for extracting tool JSON without assuming a specific model output format + # solution for extracting tool JSON without assuming a specific model output format maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}" parsed_text = json.loads(maybe_json) tool_name = parsed_text.get(self.tool_name_key, None) @@ -579,8 +579,9 @@ class MLXModel(Model): self.last_output_token_count += 1 text += _.text for stop_sequence in prepared_stop_sequences: - if text.strip().endswith(stop_sequence): - text = text[: -len(stop_sequence)] + stop_sequence_start = text.rfind(stop_sequence) + if stop_sequence_start != -1: + text = text[:stop_sequence_start] return self._to_message(text, tools_to_call_from) return self._to_message(text, tools_to_call_from) diff --git a/tests/test_models.py b/tests/test_models.py index 0e7a42c..ea70983 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -69,6 +69,19 @@ class ModelTests(unittest.TestCase): output = model(messages, stop_sequences=["great"]).content assert output.startswith("Hello") + @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS") + def test_get_mlx_message_tricky_stop_sequence(self): + # In this test HuggingFaceTB/SmolLM2-135M-Instruct generates the token ">'" + # which is required to test capturing stop_sequences that have extra chars at the end. + model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=100) + stop_sequence = " print '>" + messages = [{"role": "user", "content": [{"type": "text", "text": f"Please{stop_sequence}'"}]}] + # check our assumption that that ">" is followed by "'" + assert model.tokenizer.vocab[">'"] + assert model(messages, stop_sequences=[]).content == f"I'm ready to help you{stop_sequence}'" + # check stop_sequence capture when output has trailing chars + assert model(messages, stop_sequences=[stop_sequence]).content == "I'm ready to help you" + def test_transformers_message_no_tool(self): model = TransformersModel( model_id="HuggingFaceTB/SmolLM2-135M-Instruct",