Add support for OpenTelemetry instrumentation 📊 (#200)

Aymeric Roucher 2025-01-15 12:10:52 +01:00 committed by GitHub
parent ce1cd6d906
commit 450934ce79
2 changed files with 59 additions and 5 deletions


@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, asdict
import json
import logging
import os
@@ -54,12 +54,29 @@ if _is_package_available("litellm"):
import litellm
def get_dict_from_nested_dataclasses(obj):
def convert(obj):
if hasattr(obj, "__dataclass_fields__"):
return {k: convert(v) for k, v in asdict(obj).items()}
return obj
return convert(obj)
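A minimal standalone sketch (not part of the diff) of what this helper does: asdict() already recurses into nested dataclasses, and convert() passes non-dataclass values through unchanged. The Inner/Outer dataclasses below are invented purely for illustration.

from dataclasses import asdict, dataclass
from typing import Optional


def get_dict_from_nested_dataclasses(obj):
    # Same helper as above: turn a (possibly nested) dataclass into plain dicts.
    def convert(obj):
        if hasattr(obj, "__dataclass_fields__"):
            return {k: convert(v) for k, v in asdict(obj).items()}
        return obj

    return convert(obj)


@dataclass
class Inner:
    name: str
    value: Optional[int] = None


@dataclass
class Outer:
    label: str
    inner: Inner


print(get_dict_from_nested_dataclasses(Outer("demo", Inner("x", 1))))
# -> {'label': 'demo', 'inner': {'name': 'x', 'value': 1}}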
@dataclass
class ChatMessageToolCallDefinition:
arguments: Any
name: str
description: Optional[str] = None
@classmethod
def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition":
return cls(
arguments=tool_call_definition.arguments,
name=tool_call_definition.name,
description=tool_call_definition.description,
)
@dataclass
class ChatMessageToolCall:
@@ -67,6 +84,14 @@ class ChatMessageToolCall:
id: str
type: str
@classmethod
def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
return cls(
function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function),
id=tool_call.id,
type=tool_call.type,
)
@dataclass
class ChatMessage:
@@ -74,6 +99,19 @@
content: Optional[str] = None
tool_calls: Optional[List[ChatMessageToolCall]] = None
def model_dump_json(self):
return json.dumps(get_dict_from_nested_dataclasses(self))
@classmethod
def from_hf_api(cls, message) -> "ChatMessage":
tool_calls = None
if getattr(message, "tool_calls", None) is not None:
tool_calls = [
ChatMessageToolCall.from_hf_api(tool_call)
for tool_call in message.tool_calls
]
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
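A rough usage sketch (not in the commit) of the conversion path these classmethods add. It assumes ChatMessage is importable from smolagents, as the __all__ change below makes explicit; the SimpleNamespace objects stand in for the message returned by the Hugging Face chat-completion API, and the tool-call values are invented.

from types import SimpleNamespace

from smolagents import ChatMessage

# Stand-in for response.choices[0].message as returned by the Inference API.
api_message = SimpleNamespace(
    role="assistant",
    content=None,
    tool_calls=[
        SimpleNamespace(
            id="call_0",
            type="function",
            function=SimpleNamespace(
                name="get_weather",
                arguments={"celsius": True},
                description=None,
            ),
        )
    ],
)

chat_message = ChatMessage.from_hf_api(api_message)
print(chat_message.model_dump_json())  # serializes the nested dataclasses to JSON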
class MessageRole(str, Enum):
USER = "user"
@@ -283,7 +321,7 @@ class HfApiModel(Model):
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return response.choices[0].message
return ChatMessage.from_hf_api(response.choices[0].message)
class TransformersModel(Model):
@@ -315,14 +353,18 @@ class TransformersModel(Model):
logger.info(f"Using device: {self.device}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=self.device
)
except Exception as e:
logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
)
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=self.device
)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):
@@ -551,4 +593,5 @@ __all__ = [
"HfApiModel",
"LiteLLMModel",
"OpenAIServerModel",
"ChatMessage",
]


@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import json
from typing import Optional
from smolagents import models, tool
from smolagents import models, tool, ChatMessage, HfApiModel
class ModelTests(unittest.TestCase):
@@ -38,3 +39,13 @@ class ModelTests(unittest.TestCase):
"properties"
]["celsius"]
)
def test_chatmessage_has_model_dump_json(self):
message = ChatMessage("user", "Hello!")
data = json.loads(message.model_dump_json())
assert data["content"] == "Hello!"
def test_get_hfapi_message_no_tool(self):
model = HfApiModel()
messages = [{"role": "user", "content": "Hello!"}]
model(messages, stop_sequences=["great"])
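As a sketch only (not part of the commit): the live-API test above discards the model output. Assuming network access and a configured Hugging Face token, the same call could additionally assert the new return type introduced by this change, along these lines:

output = model(messages, stop_sequences=["great"])
# With the HfApiModel change above, the result is a smolagents ChatMessage
# rather than the raw API message object, and its content is plain text here.
assert isinstance(output, ChatMessage)
assert isinstance(output.content, str)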