Add support for OpenTelemetry instrumentation 📊 (#200)
This commit is contained in:
parent
ce1cd6d906
commit
450934ce79
|
@ -14,7 +14,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, asdict
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -54,12 +54,29 @@ if _is_package_available("litellm"):
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
def get_dict_from_nested_dataclasses(obj):
|
||||||
|
def convert(obj):
|
||||||
|
if hasattr(obj, "__dataclass_fields__"):
|
||||||
|
return {k: convert(v) for k, v in asdict(obj).items()}
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return convert(obj)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessageToolCallDefinition:
|
class ChatMessageToolCallDefinition:
|
||||||
arguments: Any
|
arguments: Any
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition":
|
||||||
|
return cls(
|
||||||
|
arguments=tool_call_definition.arguments,
|
||||||
|
name=tool_call_definition.name,
|
||||||
|
description=tool_call_definition.description,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessageToolCall:
|
class ChatMessageToolCall:
|
||||||
|
@ -67,6 +84,14 @@ class ChatMessageToolCall:
|
||||||
id: str
|
id: str
|
||||||
type: str
|
type: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
|
||||||
|
return cls(
|
||||||
|
function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function),
|
||||||
|
id=tool_call.id,
|
||||||
|
type=tool_call.type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessage:
|
class ChatMessage:
|
||||||
|
@ -74,6 +99,19 @@ class ChatMessage:
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
tool_calls: Optional[List[ChatMessageToolCall]] = None
|
tool_calls: Optional[List[ChatMessageToolCall]] = None
|
||||||
|
|
||||||
|
def model_dump_json(self):
|
||||||
|
return json.dumps(get_dict_from_nested_dataclasses(self))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hf_api(cls, message) -> "ChatMessage":
|
||||||
|
tool_calls = None
|
||||||
|
if getattr(message, "tool_calls", None) is not None:
|
||||||
|
tool_calls = [
|
||||||
|
ChatMessageToolCall.from_hf_api(tool_call)
|
||||||
|
for tool_call in message.tool_calls
|
||||||
|
]
|
||||||
|
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
||||||
|
|
||||||
|
|
||||||
class MessageRole(str, Enum):
|
class MessageRole(str, Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
@ -283,7 +321,7 @@ class HfApiModel(Model):
|
||||||
)
|
)
|
||||||
self.last_input_token_count = response.usage.prompt_tokens
|
self.last_input_token_count = response.usage.prompt_tokens
|
||||||
self.last_output_token_count = response.usage.completion_tokens
|
self.last_output_token_count = response.usage.completion_tokens
|
||||||
return response.choices[0].message
|
return ChatMessage.from_hf_api(response.choices[0].message)
|
||||||
|
|
||||||
|
|
||||||
class TransformersModel(Model):
|
class TransformersModel(Model):
|
||||||
|
@ -315,14 +353,18 @@ class TransformersModel(Model):
|
||||||
logger.info(f"Using device: {self.device}")
|
logger.info(f"Using device: {self.device}")
|
||||||
try:
|
try:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id, device_map=self.device
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
|
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
|
||||||
)
|
)
|
||||||
self.model_id = default_model_id
|
self.model_id = default_model_id
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
|
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id, device_map=self.device
|
||||||
|
)
|
||||||
|
|
||||||
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
|
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
|
||||||
class StopOnStrings(StoppingCriteria):
|
class StopOnStrings(StoppingCriteria):
|
||||||
|
@ -551,4 +593,5 @@ __all__ = [
|
||||||
"HfApiModel",
|
"HfApiModel",
|
||||||
"LiteLLMModel",
|
"LiteLLMModel",
|
||||||
"OpenAIServerModel",
|
"OpenAIServerModel",
|
||||||
|
"ChatMessage",
|
||||||
]
|
]
|
||||||
|
|
|
@ -13,9 +13,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from smolagents import models, tool
|
from smolagents import models, tool, ChatMessage, HfApiModel
|
||||||
|
|
||||||
|
|
||||||
class ModelTests(unittest.TestCase):
|
class ModelTests(unittest.TestCase):
|
||||||
|
@ -38,3 +39,13 @@ class ModelTests(unittest.TestCase):
|
||||||
"properties"
|
"properties"
|
||||||
]["celsius"]
|
]["celsius"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_chatmessage_has_model_dumps_json(self):
|
||||||
|
message = ChatMessage("user", "Hello!")
|
||||||
|
data = json.loads(message.model_dump_json())
|
||||||
|
assert data["content"] == "Hello!"
|
||||||
|
|
||||||
|
def test_get_hfapi_message_no_tool(self):
|
||||||
|
model = HfApiModel()
|
||||||
|
messages = [{"role": "user", "content": "Hello!"}]
|
||||||
|
model(messages, stop_sequences=["great"])
|
||||||
|
|
Loading…
Reference in New Issue