Add trust_remote_code arg to TransformersModel (#240)

This commit is contained in:
Aymeric Roucher 2025-01-17 11:55:36 +01:00 committed by GitHub
parent c255c1ff84
commit 11a738e53a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 1 deletions

View File

@ -335,6 +335,8 @@ class TransformersModel(Model):
The device_map to initialize your model with. The device_map to initialize your model with.
torch_dtype (`str`, *optional*): torch_dtype (`str`, *optional*):
The torch_dtype to initialize your model with. The torch_dtype to initialize your model with.
trust_remote_code (bool):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
kwargs (dict, *optional*): kwargs (dict, *optional*):
Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`. Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
Raises: Raises:
@ -360,6 +362,7 @@ class TransformersModel(Model):
model_id: Optional[str] = None, model_id: Optional[str] = None,
device_map: Optional[str] = None, device_map: Optional[str] = None,
torch_dtype: Optional[str] = None, torch_dtype: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
@ -381,7 +384,10 @@ class TransformersModel(Model):
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=device_map, torch_dtype=torch_dtype model_id,
device_map=device_map,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(