diff --git a/examples/tool_calling_agent_from_any_llm.py b/examples/tool_calling_agent_from_any_llm.py
index 7d906a8..48fee5a 100644
--- a/examples/tool_calling_agent_from_any_llm.py
+++ b/examples/tool_calling_agent_from_any_llm.py
@@ -6,7 +6,7 @@ from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, Transfo
 # llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
 # llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
 # llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
-llm_engine = LiteLLMEngine()
+llm_engine = LiteLLMEngine("gpt-4o")
 
 @tool
 def get_weather(location: str) -> str:
diff --git a/src/smolagents/llm_engines.py b/src/smolagents/llm_engines.py
index 45658a6..7fa88d9 100644
--- a/src/smolagents/llm_engines.py
+++ b/src/smolagents/llm_engines.py
@@ -19,6 +19,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Tuple
 
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
+import litellm
 import logging
 import os
 import random
@@ -566,11 +567,32 @@ class AnthropicEngine:
 class LiteLLMEngine():
     def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
         self.model_id = model_id
-        import os, litellm
-
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
         litellm.add_function_to_prompt = True
-
+        self.last_input_token_count = 0
+        self.last_output_token_count = 0
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        response = litellm.completion(
+            model=self.model_id,
+            messages=messages,
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+        )
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return response.choices[0].message.content
+
     def get_tool_call(
         self,
         messages: List[Dict[str, str]],
@@ -578,19 +600,20 @@ class LiteLLMEngine():
         stop_sequences: Optional[List[str]] = None,
         max_tokens: int = 1500,
     ):
-        from litellm import completion
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-        response = completion(
+        response = litellm.completion(
             model=self.model_id,
             messages=messages,
             tools=[get_json_schema(tool) for tool in available_tools],
             tool_choice="required",
-            max_tokens=max_tokens,
             stop=stop_sequences,
+            max_tokens=max_tokens,
        )
         tool_calls = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
         return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id
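
Quick smoke test for the new __call__ path (a minimal sketch, not part of the PR: it assumes litellm is installed, the provider API key such as OPENAI_API_KEY is set in the environment, and that LiteLLMEngine is importable from the package root like the other engines in the example script):

    from smolagents import LiteLLMEngine

    # Hypothetical model choice; any "provider/model" id accepted by LiteLLM works.
    engine = LiteLLMEngine("gpt-4o")
    messages = [{"role": "user", "content": "Reply with one word: ready?"}]

    # Plain chat completion, no tools; stop_sequences and max_tokens are
    # forwarded to litellm.completion as in the diff above.
    output = engine(messages, stop_sequences=["Observation:"], max_tokens=64)
    print(output)

    # Both code paths now record token usage on the engine after each call.
    print(engine.last_input_token_count, engine.last_output_token_count)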