Add support for OpenTelemetry instrumentation 📊 (#200)

This commit is contained in:
Aymeric Roucher 2025-01-15 12:10:52 +01:00 committed by GitHub
parent ce1cd6d906
commit 450934ce79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 5 deletions

View File

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass, asdict
import json import json
import logging import logging
import os import os
@ -54,12 +54,29 @@ if _is_package_available("litellm"):
import litellm import litellm
def get_dict_from_nested_dataclasses(obj):
    """Recursively convert a (possibly nested) dataclass to plain containers.

    Args:
        obj: Any value. If it is a dataclass instance, it is converted to a
            dict; otherwise it is returned unchanged.

    Returns:
        A dict of the dataclass's fields with nested dataclasses (including
        those inside lists/tuples/dicts) also converted, or ``obj`` itself
        when it is not a dataclass instance.
    """
    # dataclasses.asdict already recurses into nested dataclasses and
    # containers, so a single call suffices — no second traversal needed.
    if hasattr(obj, "__dataclass_fields__"):
        return asdict(obj)
    return obj
@dataclass
class ChatMessageToolCallDefinition:
    """The payload of a tool call requested by a chat model.

    Attributes:
        arguments: Arguments the model wants to pass to the tool.
        name: Name of the tool to invoke.
        description: Optional human-readable description of the tool.
    """

    arguments: Any
    name: str
    description: Optional[str] = None

    @classmethod
    def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition":
        """Build an instance from a Hugging Face API tool-call definition object."""
        definition = cls(
            name=tool_call_definition.name,
            arguments=tool_call_definition.arguments,
            description=tool_call_definition.description,
        )
        return definition
@dataclass @dataclass
class ChatMessageToolCall: class ChatMessageToolCall:
@ -67,6 +84,14 @@ class ChatMessageToolCall:
id: str id: str
type: str type: str
@classmethod
def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
    """Build an instance from a Hugging Face API tool-call object.

    Converts the nested function definition through
    ``ChatMessageToolCallDefinition.from_hf_api`` and copies id/type as-is.
    """
    definition = ChatMessageToolCallDefinition.from_hf_api(tool_call.function)
    return cls(id=tool_call.id, type=tool_call.type, function=definition)
@dataclass @dataclass
class ChatMessage: class ChatMessage:
@ -74,6 +99,19 @@ class ChatMessage:
content: Optional[str] = None content: Optional[str] = None
tool_calls: Optional[List[ChatMessageToolCall]] = None tool_calls: Optional[List[ChatMessageToolCall]] = None
def model_dump_json(self):
    """Serialize this message (including nested tool calls) to a JSON string.

    NOTE(review): the name mirrors pydantic's ``model_dump_json`` —
    presumably so callers can treat this dataclass like a pydantic model;
    confirm against call sites.
    """
    as_plain_dict = get_dict_from_nested_dataclasses(self)
    return json.dumps(as_plain_dict)
@classmethod
def from_hf_api(cls, message) -> "ChatMessage":
    """Build a ChatMessage from a Hugging Face API message object.

    Tool calls are converted when the source message carries any;
    otherwise ``tool_calls`` is left as None.
    """
    raw_tool_calls = getattr(message, "tool_calls", None)
    converted = (
        None
        if raw_tool_calls is None
        else [ChatMessageToolCall.from_hf_api(tc) for tc in raw_tool_calls]
    )
    return cls(role=message.role, content=message.content, tool_calls=converted)
class MessageRole(str, Enum): class MessageRole(str, Enum):
USER = "user" USER = "user"
@ -283,7 +321,7 @@ class HfApiModel(Model):
) )
self.last_input_token_count = response.usage.prompt_tokens self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens self.last_output_token_count = response.usage.completion_tokens
return response.choices[0].message return ChatMessage.from_hf_api(response.choices[0].message)
class TransformersModel(Model): class TransformersModel(Model):
@ -315,14 +353,18 @@ class TransformersModel(Model):
logger.info(f"Using device: {self.device}") logger.info(f"Using device: {self.device}")
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=self.device
)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}." f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
) )
self.model_id = default_model_id self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=self.device
)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria): class StopOnStrings(StoppingCriteria):
@ -551,4 +593,5 @@ __all__ = [
"HfApiModel", "HfApiModel",
"LiteLLMModel", "LiteLLMModel",
"OpenAIServerModel", "OpenAIServerModel",
"ChatMessage",
] ]

View File

@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import unittest
import json
from typing import Optional from typing import Optional
from smolagents import models, tool from smolagents import models, tool, ChatMessage, HfApiModel
class ModelTests(unittest.TestCase): class ModelTests(unittest.TestCase):
@ -38,3 +39,13 @@ class ModelTests(unittest.TestCase):
"properties" "properties"
]["celsius"] ]["celsius"]
) )
def test_chatmessage_has_model_dumps_json(self):
    """ChatMessage.model_dump_json produces JSON that round-trips via json.loads."""
    message = ChatMessage("user", "Hello!")
    data = json.loads(message.model_dump_json())
    # unittest assertion methods give informative failure output,
    # unlike a bare assert; also check the role field is serialized.
    self.assertEqual(data["role"], "user")
    self.assertEqual(data["content"], "Hello!")
def test_get_hfapi_message_no_tool(self):
    """HfApiModel returns a ChatMessage wrapper, not the raw HF API message.

    NOTE(review): performs a real Hub inference call — requires network
    access and credentials; consider mocking for CI.
    """
    model = HfApiModel()
    messages = [{"role": "user", "content": "Hello!"}]
    output = model(messages, stop_sequences=["great"])
    # The original test asserted nothing, so it could not catch a
    # regression in the return type; pin the wrapper type and that
    # some content came back.
    self.assertIsInstance(output, ChatMessage)
    self.assertIsNotNone(output.content)