From 11a738e53a4ee1e9bda7bdf7c9b94318bdcd4d14 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:55:36 +0100 Subject: [PATCH] Add trust_remote_code arg to TransformersModel (#240) --- src/smolagents/models.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 111fa0c..f07d799 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -335,6 +335,8 @@ class TransformersModel(Model): The device_map to initialize your model with. torch_dtype (`str`, *optional*): The torch_dtype to initialize your model with. + trust_remote_code (bool): + Some models on the Hub require running remote code: for this model, you would have to set this flag to True. kwargs (dict, *optional*): Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`. Raises: @@ -360,6 +362,7 @@ class TransformersModel(Model): model_id: Optional[str] = None, device_map: Optional[str] = None, torch_dtype: Optional[str] = None, + trust_remote_code: bool = False, **kwargs, ): super().__init__() @@ -381,7 +384,10 @@ class TransformersModel(Model): try: self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.model = AutoModelForCausalLM.from_pretrained( - model_id, device_map=device_map, torch_dtype=torch_dtype + model_id, + device_map=device_map, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, ) except Exception as e: logger.warning(