Add a default value for `max_new_tokens` in TransformersModel when neither `max_new_tokens` nor `max_tokens` is provided (#604)
This commit is contained in:
parent
a427c84c1c
commit
f3ee6052db
|
@ -599,7 +599,16 @@ class TransformersModel(Model):
|
|||
model_id = default_model_id
|
||||
logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'")
|
||||
self.model_id = model_id
|
||||
|
||||
default_max_tokens = 5000
|
||||
max_new_tokens = kwargs.get("max_new_tokens") or kwargs.get("max_tokens")
|
||||
if not max_new_tokens:
|
||||
kwargs["max_new_tokens"] = default_max_tokens
|
||||
logger.warning(
|
||||
f"`max_new_tokens` not provided, using this default value for `max_new_tokens`: {default_max_tokens}"
|
||||
)
|
||||
self.kwargs = kwargs
|
||||
|
||||
if device_map is None:
|
||||
device_map = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {device_map}")
|
||||
|
|
Loading…
Reference in New Issue