add device parameter to TransformersModel

This commit is contained in:
Izaak Curry 2025-01-02 20:54:32 -08:00
parent 81388b14f7
commit 12ee33a878
1 changed file with 2 additions and 1 deletion

View File

@ -29,6 +29,7 @@ import litellm
import logging
import os
import random
import torch
from huggingface_hub import InferenceClient
@ -304,7 +305,7 @@ class TransformersModel(Model):
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
)
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=device)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):