Add standard call to LiteLLM engine

Aymeric 2024-12-24 19:55:34 +01:00
parent 1e357cee7f
commit 162d4dc362
2 changed files with 30 additions and 7 deletions


@@ -6,7 +6,7 @@ from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, Transfo
 # llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
 # llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
 # llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
-llm_engine = LiteLLMEngine()
+llm_engine = LiteLLMEngine("gpt-4o")
 
 @tool
 def get_weather(location: str) -> str:


@@ -19,6 +19,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Tuple
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
+import litellm
 import logging
 import os
 import random
@@ -566,11 +567,32 @@ class AnthropicEngine:
 class LiteLLMEngine():
     def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
         self.model_id = model_id
         import os, litellm
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
         litellm.add_function_to_prompt = True
+        self.last_input_token_count = 0
+        self.last_output_token_count = 0
 
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+        response = litellm.completion(
+            model=self.model_id,
+            messages=messages,
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+        )
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return response.choices[0].message.content
+
     def get_tool_call(
         self,
         messages: List[Dict[str, str]],
@@ -578,19 +600,20 @@ class LiteLLMEngine():
         stop_sequences: Optional[List[str]] = None,
         max_tokens: int = 1500,
     ):
-        from litellm import completion
         messages = get_clean_message_list(
            messages, role_conversions=tool_role_conversions
         )
-        response = completion(
+        response = litellm.completion(
             model=self.model_id,
             messages=messages,
             tools=[get_json_schema(tool) for tool in available_tools],
             tool_choice="required",
-            max_tokens=max_tokens,
             stop=stop_sequences,
+            max_tokens=max_tokens,
         )
         tool_calls = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
         return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id
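
For reference, a minimal usage sketch of the standard call added in this commit (not part of the diff; it assumes LiteLLMEngine is importable from the smolagents package, as the example script suggests, and that an OPENAI_API_KEY is set in the environment so litellm can reach the "gpt-4o" model):

    # Hypothetical usage of the new LiteLLMEngine.__call__ method.
    from smolagents import LiteLLMEngine  # assumption: exported at package level

    llm_engine = LiteLLMEngine("gpt-4o")

    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello in one word."},
    ]

    # __call__ cleans the message list, forwards it to litellm.completion,
    # records token usage on the engine, and returns the completion text.
    answer = llm_engine(messages, stop_sequences=["<end_action>"], max_tokens=100)
    print(answer)
    print(llm_engine.last_input_token_count, llm_engine.last_output_token_count)

This mirrors the calling convention of the other engines shown in the example script (HfApiEngine, AnthropicEngine, TransformersEngine), so the engines can be swapped by changing a single line.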