Halve import time by removing torch dependency (#147)

* Halve import time by removing torch dependency
This commit is contained in:
Aymeric Roucher 2025-01-10 15:00:28 +01:00 committed by GitHub
parent d8a4b831bb
commit eca83800e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 64 additions and 54 deletions

View File

@ -13,7 +13,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: "3.10" python-version: "3.12"
# Setup venv # Setup venv
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.

View File

@ -48,10 +48,10 @@ Run the line below to install the required dependencies:
Let's login in order to call the HF Inference API: Let's login in order to call the HF Inference API:
```py ```
from huggingface_hub import notebook_login from huggingface_hub import login
notebook_login() login()
``` ```
⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model. ⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model.

View File

@ -177,7 +177,7 @@ agent.run("How many more blocks (also denoted as layers) are in BERT base encode
### Manage your agent's toolbox ### Manage your agent's toolbox
You can manage an agent's toolbox by adding or replacing a tool. You can manage an agent's toolbox by adding or replacing a tool in attribute `agent.tools`, since it is a standard dictionary.
Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox. Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
@ -187,7 +187,7 @@ from smolagents import HfApiModel
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct") model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[], model=model, add_base_tools=True) agent = CodeAgent(tools=[], model=model, add_base_tools=True)
agent.tools.append(model_download_tool) agent.tools[model_download_tool.name] = model_download_tool
``` ```
Now we can leverage the new tool: Now we can leverage the new tool:

View File

@ -12,9 +12,6 @@ authors = [
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
"torch",
"torchaudio",
"torchvision",
"transformers>=4.0.0", "transformers>=4.0.0",
"requests>=2.32.3", "requests>=2.32.3",
"rich>=13.9.4", "rich>=13.9.4",
@ -30,10 +27,22 @@ dependencies = [
] ]
[tool.ruff] [tool.ruff]
ignore = ["F403"] lint.ignore = ["F403"]
[project.optional-dependencies] [project.optional-dependencies]
dev = [
"torch",
"torchaudio",
"torchvision",
"sqlalchemy",
"accelerate",
"soundfile",
"litellm>=1.55.10",
]
test = [ test = [
"torch",
"torchaudio",
"torchvision",
"pytest>=8.1.0", "pytest>=8.1.0",
"sqlalchemy", "sqlalchemy",
"ruff>=0.5.0", "ruff>=0.5.0",

View File

@ -20,11 +20,9 @@ from dataclasses import dataclass
from typing import Dict, Optional from typing import Dict, Optional
from huggingface_hub import hf_hub_download, list_spaces from huggingface_hub import hf_hub_download, list_spaces
from transformers.models.whisper import (
WhisperForConditionalGeneration,
WhisperProcessor, from transformers.utils import is_offline_mode, is_torch_available
)
from transformers.utils import is_offline_mode
from .local_python_executor import ( from .local_python_executor import (
BASE_BUILTIN_MODULES, BASE_BUILTIN_MODULES,
@ -34,6 +32,15 @@ from .local_python_executor import (
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio from .types import AgentAudio
if is_torch_available():
from transformers.models.whisper import (
WhisperForConditionalGeneration,
WhisperProcessor,
)
else:
WhisperForConditionalGeneration = object
WhisperProcessor = object
@dataclass @dataclass
class PreTool: class PreTool:

View File

@ -22,7 +22,6 @@ from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Dict, List, Optional from typing import Dict, List, Optional
import torch
from huggingface_hub import ( from huggingface_hub import (
InferenceClient, InferenceClient,
ChatCompletionOutputMessage, ChatCompletionOutputMessage,
@ -35,6 +34,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
StoppingCriteria, StoppingCriteria,
StoppingCriteriaList, StoppingCriteriaList,
is_torch_available,
) )
import openai import openai
@ -147,29 +147,12 @@ class Model:
self.last_input_token_count = None self.last_input_token_count = None
self.last_output_token_count = None self.last_output_token_count = None
def get_token_counts(self): def get_token_counts(self) -> Dict[str, int]:
return { return {
"input_token_count": self.last_input_token_count, "input_token_count": self.last_input_token_count,
"output_token_count": self.last_output_token_count, "output_token_count": self.last_output_token_count,
} }
def generate(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
):
raise NotImplementedError
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences,
):
raise NotImplementedError
def __call__( def __call__(
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
@ -256,6 +239,10 @@ class HfApiModel(Model):
max_tokens: int = 1500, max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None, tools_to_call_from: Optional[List[Tool]] = None,
) -> str: ) -> str:
"""
Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
"""
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions messages, role_conversions=tool_role_conversions
) )
@ -293,6 +280,10 @@ class TransformersModel(Model):
def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None): def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
super().__init__() super().__init__()
if not is_torch_available():
raise ImportError("Please install torch in order to use TransformersModel.")
import torch
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
if model_id is None: if model_id is None:
model_id = default_model_id model_id = default_model_id

View File

@ -27,7 +27,6 @@ from functools import lru_cache, wraps
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Optional, Union, get_type_hints from typing import Callable, Dict, Optional, Union, get_type_hints
import torch
from huggingface_hub import ( from huggingface_hub import (
create_repo, create_repo,
get_collection, get_collection,
@ -37,7 +36,6 @@ from huggingface_hub import (
) )
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from packaging import version from packaging import version
from transformers import AutoProcessor
from transformers.dynamic_module_utils import get_imports from transformers.dynamic_module_utils import get_imports
from transformers.utils import ( from transformers.utils import (
TypeHintParsingException, TypeHintParsingException,
@ -54,13 +52,14 @@ from .utils import instance_to_source
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import send_to_device
if is_torch_available(): if is_torch_available():
pass from transformers import AutoProcessor
else:
if is_accelerate_available(): AutoProcessor = object
pass
TOOL_CONFIG_FILE = "tool_config.json" TOOL_CONFIG_FILE = "tool_config.json"
@ -1026,8 +1025,6 @@ class PipelineTool(Tool):
""" """
Instantiates the `pre_processor`, `model` and `post_processor` if necessary. Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
""" """
from accelerate import PartialState
if isinstance(self.pre_processor, str): if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained( self.pre_processor = self.pre_processor_class.from_pretrained(
self.pre_processor, **self.hub_kwargs self.pre_processor, **self.hub_kwargs
@ -1066,6 +1063,8 @@ class PipelineTool(Tool):
""" """
Sends the inputs through the `model`. Sends the inputs through the `model`.
""" """
import torch
with torch.no_grad(): with torch.no_grad():
return self.model(**inputs) return self.model(**inputs)
@ -1076,6 +1075,8 @@ class PipelineTool(Tool):
return self.post_processor(outputs) return self.post_processor(outputs)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
import torch
args, kwargs = handle_agent_input_types(*args, **kwargs) args, kwargs = handle_agent_input_types(*args, **kwargs)
if not self.is_initialized: if not self.is_initialized:
@ -1083,9 +1084,6 @@ class PipelineTool(Tool):
encoded_inputs = self.encode(*args, **kwargs) encoded_inputs = self.encode(*args, **kwargs)
import torch
from accelerate.utils import send_to_device
tensor_inputs = { tensor_inputs = {
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor) k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
} }

View File

@ -22,10 +22,10 @@ from io import BytesIO
import numpy as np import numpy as np
import requests import requests
from transformers.utils import ( from transformers.utils import (
is_soundfile_availble,
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
) )
from transformers.utils.import_utils import _is_package_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,7 +41,7 @@ if is_torch_available():
else: else:
Tensor = object Tensor = object
if is_soundfile_availble(): if _is_package_available("soundfile"):
import soundfile as sf import soundfile as sf
@ -189,7 +189,7 @@ class AgentAudio(AgentType, str):
def __init__(self, value, samplerate=16_000): def __init__(self, value, samplerate=16_000):
super().__init__(value) super().__init__(value)
if not is_soundfile_availble(): if not _is_package_available("soundfile"):
raise ImportError("soundfile must be installed in order to handle audio.") raise ImportError("soundfile must be installed in order to handle audio.")
self._path = None self._path = None
@ -253,7 +253,7 @@ AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAu
INSTANCE_TYPE_MAPPING = { INSTANCE_TYPE_MAPPING = {
str: AgentText, str: AgentText,
ImageType: AgentImage, ImageType: AgentImage,
torch.Tensor: AgentAudio, Tensor: AgentAudio,
} }
if is_torch_available(): if is_torch_available():

View File

@ -18,20 +18,19 @@ import unittest
import uuid import uuid
from pathlib import Path from pathlib import Path
import torch
from PIL import Image from PIL import Image
from transformers.testing_utils import ( from transformers.testing_utils import (
require_soundfile, require_soundfile,
require_torch, require_torch,
require_vision, require_vision,
) )
from transformers.utils import ( from transformers.utils.import_utils import (
is_soundfile_availble, _is_package_available,
) )
from smolagents.types import AgentAudio, AgentImage, AgentText from smolagents.types import AgentAudio, AgentImage, AgentText
if is_soundfile_availble(): if _is_package_available("soundfile"):
import soundfile as sf import soundfile as sf
@ -44,6 +43,8 @@ def get_new_path(suffix="") -> str:
@require_torch @require_torch
class AgentAudioTests(unittest.TestCase): class AgentAudioTests(unittest.TestCase):
def test_from_tensor(self): def test_from_tensor(self):
import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5 tensor = torch.rand(12, dtype=torch.float64) - 0.5
agent_type = AgentAudio(tensor) agent_type = AgentAudio(tensor)
path = str(agent_type.to_string()) path = str(agent_type.to_string())
@ -61,6 +62,8 @@ class AgentAudioTests(unittest.TestCase):
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
def test_from_string(self): def test_from_string(self):
import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5 tensor = torch.rand(12, dtype=torch.float64) - 0.5
path = get_new_path(suffix=".wav") path = get_new_path(suffix=".wav")
sf.write(path, tensor, 16000) sf.write(path, tensor, 16000)
@ -75,6 +78,8 @@ class AgentAudioTests(unittest.TestCase):
@require_torch @require_torch
class AgentImageTests(unittest.TestCase): class AgentImageTests(unittest.TestCase):
def test_from_tensor(self): def test_from_tensor(self):
import torch
tensor = torch.randint(0, 256, (64, 64, 3)) tensor = torch.randint(0, 256, (64, 64, 3))
agent_type = AgentImage(tensor) agent_type = AgentImage(tensor)
path = str(agent_type.to_string()) path = str(agent_type.to_string())