From eca83800e37881b101ee1d289fa4b2b597645ee1 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Fri, 10 Jan 2025 15:00:28 +0100 Subject: [PATCH] Halve import time by removing torch dependency (#147) * Halve import time by removing torch dependency --- .github/workflows/quality.yml | 2 +- docs/source/en/examples/multiagents.md | 6 +++--- docs/source/en/tutorials/tools.md | 4 ++-- pyproject.toml | 17 +++++++++++---- src/smolagents/default_tools.py | 17 ++++++++++----- src/smolagents/models.py | 29 +++++++++----------------- src/smolagents/tools.py | 22 +++++++++---------- src/smolagents/types.py | 8 +++---- tests/test_types.py | 13 ++++++++---- 9 files changed, 64 insertions(+), 54 deletions(-) diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 37749d7..2e4f5c6 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -13,7 +13,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.12" # Setup venv # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. diff --git a/docs/source/en/examples/multiagents.md b/docs/source/en/examples/multiagents.md index 4ea4e51..7901de2 100644 --- a/docs/source/en/examples/multiagents.md +++ b/docs/source/en/examples/multiagents.md @@ -48,10 +48,10 @@ Run the line below to install the required dependencies: Let's login in order to call the HF Inference API: -```py -from huggingface_hub import notebook_login +``` +from huggingface_hub import login -notebook_login() +login() ``` ⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model. diff --git a/docs/source/en/tutorials/tools.md b/docs/source/en/tutorials/tools.md index 014cd3b..bcaaa0f 100644 --- a/docs/source/en/tutorials/tools.md +++ b/docs/source/en/tutorials/tools.md @@ -177,7 +177,7 @@ agent.run("How many more blocks (also denoted as layers) are in BERT base encode ### Manage your agent's toolbox -You can manage an agent's toolbox by adding or replacing a tool. +You can manage an agent's toolbox by adding or replacing a tool in attribute `agent.tools`, since it is a standard dictionary. Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox. @@ -187,7 +187,7 @@ from smolagents import HfApiModel model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct") agent = CodeAgent(tools=[], model=model, add_base_tools=True) -agent.tools.append(model_download_tool) +agent.tools[model_download_tool.name] = model_download_tool ``` Now we can leverage the new tool: diff --git a/pyproject.toml b/pyproject.toml index 978c1fb..addfc0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,6 @@ authors = [ readme = "README.md" requires-python = ">=3.10" dependencies = [ - "torch", - "torchaudio", - "torchvision", "transformers>=4.0.0", "requests>=2.32.3", "rich>=13.9.4", @@ -30,10 +27,22 @@ dependencies = [ ] [tool.ruff] -ignore = ["F403"] +lint.ignore = ["F403"] [project.optional-dependencies] +dev = [ + "torch", + "torchaudio", + "torchvision", + "sqlalchemy", + "accelerate", + "soundfile", + "litellm>=1.55.10", +] test = [ + "torch", + "torchaudio", + "torchvision", "pytest>=8.1.0", "sqlalchemy", "ruff>=0.5.0", diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 79539fd..75fe8d0 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -20,11 +20,9 @@ from dataclasses import dataclass from typing import Dict, Optional from huggingface_hub import hf_hub_download, list_spaces -from transformers.models.whisper import ( - WhisperForConditionalGeneration, - WhisperProcessor, -) -from transformers.utils import is_offline_mode + + +from transformers.utils import is_offline_mode, is_torch_available from .local_python_executor import ( BASE_BUILTIN_MODULES, @@ -34,6 +32,15 @@ from .local_python_executor import ( from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool from .types import AgentAudio +if is_torch_available(): + from transformers.models.whisper import ( + WhisperForConditionalGeneration, + WhisperProcessor, + ) +else: + WhisperForConditionalGeneration = object + WhisperProcessor = object + @dataclass class PreTool: diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 403e9fa..fd68607 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -22,7 +22,6 @@ from copy import deepcopy from enum import Enum from typing import Dict, List, Optional -import torch from huggingface_hub import ( InferenceClient, ChatCompletionOutputMessage, @@ -35,6 +34,7 @@ from transformers import ( AutoTokenizer, StoppingCriteria, StoppingCriteriaList, + is_torch_available, ) import openai @@ -147,29 +147,12 @@ class Model: self.last_input_token_count = None self.last_output_token_count = None - def get_token_counts(self): + def get_token_counts(self) -> Dict[str, int]: return { "input_token_count": self.last_input_token_count, "output_token_count": self.last_output_token_count, } - def generate( - self, - messages: List[Dict[str, str]], - stop_sequences: Optional[List[str]] = None, - grammar: Optional[str] = None, - max_tokens: int = 1500, - ): - raise NotImplementedError - - def get_tool_call( - self, - messages: List[Dict[str, str]], - available_tools: List[Tool], - stop_sequences, - ): - raise NotImplementedError - def __call__( self, messages: List[Dict[str, str]], @@ -256,6 +239,10 @@ class HfApiModel(Model): max_tokens: int = 1500, tools_to_call_from: Optional[List[Tool]] = None, ) -> str: + """ + Gets an LLM output message for the given list of input messages. + If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. + """ messages = get_clean_message_list( messages, role_conversions=tool_role_conversions ) @@ -293,6 +280,10 @@ class TransformersModel(Model): def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None): super().__init__() + if not is_torch_available(): + raise ImportError("Please install torch in order to use TransformersModel.") + import torch + default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" if model_id is None: model_id = default_model_id diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 12d7d63..2638f54 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -27,7 +27,6 @@ from functools import lru_cache, wraps from pathlib import Path from typing import Callable, Dict, Optional, Union, get_type_hints -import torch from huggingface_hub import ( create_repo, get_collection, @@ -37,7 +36,6 @@ from huggingface_hub import ( ) from huggingface_hub.utils import RepositoryNotFoundError from packaging import version -from transformers import AutoProcessor from transformers.dynamic_module_utils import get_imports from transformers.utils import ( TypeHintParsingException, @@ -54,13 +52,14 @@ from .utils import instance_to_source logger = logging.getLogger(__name__) +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import send_to_device if is_torch_available(): - pass - -if is_accelerate_available(): - pass - + from transformers import AutoProcessor +else: + AutoProcessor = object TOOL_CONFIG_FILE = "tool_config.json" @@ -1026,8 +1025,6 @@ class PipelineTool(Tool): """ Instantiates the `pre_processor`, `model` and `post_processor` if necessary. """ - from accelerate import PartialState - if isinstance(self.pre_processor, str): self.pre_processor = self.pre_processor_class.from_pretrained( self.pre_processor, **self.hub_kwargs @@ -1066,6 +1063,8 @@ class PipelineTool(Tool): """ Sends the inputs through the `model`. """ + import torch + with torch.no_grad(): return self.model(**inputs) @@ -1076,6 +1075,8 @@ class PipelineTool(Tool): return self.post_processor(outputs) def __call__(self, *args, **kwargs): + import torch + args, kwargs = handle_agent_input_types(*args, **kwargs) if not self.is_initialized: @@ -1083,9 +1084,6 @@ class PipelineTool(Tool): encoded_inputs = self.encode(*args, **kwargs) - import torch - from accelerate.utils import send_to_device - tensor_inputs = { k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor) } diff --git a/src/smolagents/types.py b/src/smolagents/types.py index dbc5d5b..a9730c1 100644 --- a/src/smolagents/types.py +++ b/src/smolagents/types.py @@ -22,10 +22,10 @@ from io import BytesIO import numpy as np import requests from transformers.utils import ( - is_soundfile_availble, is_torch_available, is_vision_available, ) +from transformers.utils.import_utils import _is_package_available logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ if is_torch_available(): else: Tensor = object -if is_soundfile_availble(): +if _is_package_available("soundfile"): import soundfile as sf @@ -189,7 +189,7 @@ class AgentAudio(AgentType, str): def __init__(self, value, samplerate=16_000): super().__init__(value) - if not is_soundfile_availble(): + if not _is_package_available("soundfile"): raise ImportError("soundfile must be installed in order to handle audio.") self._path = None @@ -253,7 +253,7 @@ AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAu INSTANCE_TYPE_MAPPING = { str: AgentText, ImageType: AgentImage, - torch.Tensor: AgentAudio, + Tensor: AgentAudio, } if is_torch_available(): diff --git a/tests/test_types.py b/tests/test_types.py index e988e8b..aa58a8f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -18,20 +18,19 @@ import unittest import uuid from pathlib import Path -import torch from PIL import Image from transformers.testing_utils import ( require_soundfile, require_torch, require_vision, ) -from transformers.utils import ( - is_soundfile_availble, +from transformers.utils.import_utils import ( + _is_package_available, ) from smolagents.types import AgentAudio, AgentImage, AgentText -if is_soundfile_availble(): +if _is_package_available("soundfile"): import soundfile as sf @@ -44,6 +43,8 @@ def get_new_path(suffix="") -> str: @require_torch class AgentAudioTests(unittest.TestCase): def test_from_tensor(self): + import torch + tensor = torch.rand(12, dtype=torch.float64) - 0.5 agent_type = AgentAudio(tensor) path = str(agent_type.to_string()) @@ -61,6 +62,8 @@ class AgentAudioTests(unittest.TestCase): self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) def test_from_string(self): + import torch + tensor = torch.rand(12, dtype=torch.float64) - 0.5 path = get_new_path(suffix=".wav") sf.write(path, tensor, 16000) @@ -75,6 +78,8 @@ class AgentAudioTests(unittest.TestCase): @require_torch class AgentImageTests(unittest.TestCase): def test_from_tensor(self): + import torch + tensor = torch.randint(0, 256, (64, 64, 3)) agent_type = AgentImage(tensor) path = str(agent_type.to_string())