TransformersModel: add an optional `device` argument to __init__.

When `device` is None, "cuda" is chosen if torch.cuda.is_available() returns
true and "cpu" otherwise. The choice is stored on self.device, and both the
requested model and the fallback model are moved to it with .to(self.device).

The fallback warning is also corrected so that it names default_model_id,
the model actually loaded after a failure, instead of repeating model_id.

Note: the post-image hash on the index line predates the warning fix.

diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 6fc8dbb..6ad0ce9 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -284,7 +284,7 @@ class HfApiModel(Model):
 class TransformersModel(Model):
     """This engine initializes a model and tokenizer from the given `model_id`."""
 
-    def __init__(self, model_id: Optional[str] = None):
+    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
         super().__init__()
         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
@@ -293,15 +293,18 @@ class TransformersModel(Model):
                 f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
         except Exception as e:
             logger.warning(
-                f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
+                f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
             )
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):