Add standard call to LiteLLM engine
This commit is contained in:
parent
1e357cee7f
commit
162d4dc362
|
@ -6,7 +6,7 @@ from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, Transfo
|
|||
# llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
|
||||
# llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
|
||||
# llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
|
||||
llm_engine = LiteLLMEngine()
|
||||
llm_engine = LiteLLMEngine("gpt-4o")
|
||||
|
||||
@tool
|
||||
def get_weather(location: str) -> str:
|
||||
|
|
|
@ -19,6 +19,7 @@ from enum import Enum
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
import litellm
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
@ -566,11 +567,32 @@ class AnthropicEngine:
|
|||
class LiteLLMEngine():
|
||||
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
|
||||
self.model_id = model_id
|
||||
import os, litellm
|
||||
|
||||
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
|
||||
litellm.add_function_to_prompt = True
|
||||
|
||||
self.last_input_token_count = 0
|
||||
self.last_output_token_count = 0
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
) -> str:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
response = litellm.completion(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return response.choices[0].message.content
|
||||
|
||||
def get_tool_call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
@ -578,19 +600,20 @@ class LiteLLMEngine():
|
|||
stop_sequences: Optional[List[str]] = None,
|
||||
max_tokens: int = 1500,
|
||||
):
|
||||
from litellm import completion
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
response = completion(
|
||||
response = litellm.completion(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=[get_json_schema(tool) for tool in available_tools],
|
||||
tool_choice="required",
|
||||
max_tokens=max_tokens,
|
||||
stop=stop_sequences,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls[0]
|
||||
self.last_input_token_count = response.usage.prompt_tokens
|
||||
self.last_output_token_count = response.usage.completion_tokens
|
||||
return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue