add device parameter to TransformersModel
commit 12ee33a878
parent 81388b14f7
@@ -29,6 +29,7 @@ import litellm
 import logging
 import os
 import random
+import torch
 
 from huggingface_hub import InferenceClient
 
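The torch import added above most plausibly backs a default for the new device parameter. A minimal sketch of that kind of fallback, assuming the constructor picks CUDA when available; the helper name and logic are illustrative, not taken from this commit:

import torch

def resolve_device(device: str | None) -> str:
    # Hypothetical helper: honor an explicit device, otherwise
    # fall back to CUDA when available, else CPU.
    if device is not None:
        return device
    return "cuda" if torch.cuda.is_available() else "cpu"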
@@ -304,7 +305,7 @@ class TransformersModel(Model):
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
             )
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
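The change itself: instead of loading weights and then moving the whole module with .to(self.device), the model is placed at load time via from_pretrained's device_map argument, which accepts a device string, an explicit module-to-device mapping, or "auto" for accelerate-backed sharding. A standalone sketch of the two behaviors; the model id is illustrative:

import torch
from transformers import AutoModelForCausalLM

model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # illustrative model id
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: weights are materialized on CPU first, then copied to the device.
model_old = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# After: from_pretrained places weights directly on the target device
# (requires accelerate); device_map="auto" would instead shard across
# all available devices.
model_new = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)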