From 450934ce796d183993fc23f5e613f4e11c214276 Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Wed, 15 Jan 2025 12:10:52 +0100
Subject: [PATCH] =?UTF-8?q?Add=20support=20for=20OpenTelemetry=20instrumen?=
 =?UTF-8?q?tation=20=F0=9F=93=8A=20(#200)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/smolagents/models.py | 51 ++++++++++++++++++++++++++++++++++++----
 tests/test_models.py     | 13 +++++++++-
 2 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index f25ced9..6a1ae77 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 import json
 import logging
 import os
@@ -54,12 +54,29 @@
 if _is_package_available("litellm"):
     import litellm
 
+
+def get_dict_from_nested_dataclasses(obj):
+    def convert(obj):
+        if hasattr(obj, "__dataclass_fields__"):
+            return {k: convert(v) for k, v in asdict(obj).items()}
+        return obj
+
+    return convert(obj)
+
 
 @dataclass
 class ChatMessageToolCallDefinition:
     arguments: Any
     name: str
     description: Optional[str] = None
 
+    @classmethod
+    def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition":
+        return cls(
+            arguments=tool_call_definition.arguments,
+            name=tool_call_definition.name,
+            description=tool_call_definition.description,
+        )
+
 
 @dataclass
 class ChatMessageToolCall:
@@ -67,6 +84,14 @@ class ChatMessageToolCall:
     id: str
     type: str
 
+    @classmethod
+    def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
+        return cls(
+            function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function),
+            id=tool_call.id,
+            type=tool_call.type,
+        )
+
 
 @dataclass
 class ChatMessage:
@@ -74,6 +99,19 @@ class ChatMessage:
     content: Optional[str] = None
     tool_calls: Optional[List[ChatMessageToolCall]] = None
 
+    def model_dump_json(self):
+        return json.dumps(get_dict_from_nested_dataclasses(self))
+
+    @classmethod
+    def from_hf_api(cls, message) -> "ChatMessage":
+        tool_calls = None
+        if getattr(message, "tool_calls", None) is not None:
+            tool_calls = [
+                ChatMessageToolCall.from_hf_api(tool_call)
+                for tool_call in message.tool_calls
+            ]
+        return cls(role=message.role, content=message.content, tool_calls=tool_calls)
+
 
 class MessageRole(str, Enum):
     USER = "user"
@@ -283,7 +321,7 @@
         )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message
+        return ChatMessage.from_hf_api(response.choices[0].message)
 
 
 class TransformersModel(Model):
@@ -315,14 +353,18 @@ class TransformersModel(Model):
         logger.info(f"Using device: {self.device}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id, device_map=self.device
+            )
         except Exception as e:
             logger.warning(
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
             )
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id, device_map=self.device
+            )
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
@@ -551,4 +593,5 @@ __all__ = [
     "HfApiModel",
     "LiteLLMModel",
     "OpenAIServerModel",
+    "ChatMessage",
 ]
diff --git a/tests/test_models.py b/tests/test_models.py
index dbd93ce..5a3a821 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
+import json
 from typing import Optional
 
-from smolagents import models, tool
+from smolagents import models, tool, ChatMessage, HfApiModel
 
 
 class ModelTests(unittest.TestCase):
@@ -38,3 +39,13 @@ class ModelTests(unittest.TestCase):
                 "properties"
             ]["celsius"]
         )
+
+    def test_chatmessage_has_model_dumps_json(self):
+        message = ChatMessage("user", "Hello!")
+        data = json.loads(message.model_dump_json())
+        assert data["content"] == "Hello!"
+
+    def test_get_hfapi_message_no_tool(self):
+        model = HfApiModel()
+        messages = [{"role": "user", "content": "Hello!"}]
+        model(messages, stop_sequences=["great"])
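
Note (not part of the patch): a minimal usage sketch of the dataclasses introduced above, showing how ChatMessage.model_dump_json() serializes nested tool calls to plain JSON via get_dict_from_nested_dataclasses. The example values (tool name, arguments, call id) are invented for illustration.

    # Illustrative example only; the tool call values below are made up.
    from smolagents.models import (
        ChatMessage,
        ChatMessageToolCall,
        ChatMessageToolCallDefinition,
    )

    tool_call = ChatMessageToolCall(
        function=ChatMessageToolCallDefinition(
            arguments='{"location": "Paris"}', name="get_weather"
        ),
        id="call_0",
        type="function",
    )
    message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])

    # Dumps the whole nested structure as one JSON string, roughly:
    # {"role": "assistant", "content": null, "tool_calls": [{"function": {...}, "id": "call_0", "type": "function"}]}
    print(message.model_dump_json())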