Add device parameter to TransformersModel

This commit is contained in:
Izaak Curry 2025-01-01 22:33:07 -08:00 committed by GitHub
parent 5991206ae5
commit 81388b14f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 6 additions and 3 deletions

View File

@ -284,7 +284,7 @@ class HfApiModel(Model):
class TransformersModel(Model):
"""This engine initializes a model and tokenizer from the given `model_id`."""
def __init__(self, model_id: Optional[str] = None):
def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
super().__init__()
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
if model_id is None:
@ -293,15 +293,18 @@ class TransformersModel(Model):
f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
)
self.model_id = model_id
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
except Exception as e:
logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
)
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):