Halve import time by removing torch dependency (#147)
* Halve import time by removing torch dependency
This commit is contained in:
parent
d8a4b831bb
commit
eca83800e3
|
@ -13,7 +13,7 @@ jobs:
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.12"
|
||||||
|
|
||||||
# Setup venv
|
# Setup venv
|
||||||
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
||||||
|
|
|
@ -48,10 +48,10 @@ Run the line below to install the required dependencies:
|
||||||
|
|
||||||
Let's login in order to call the HF Inference API:
|
Let's login in order to call the HF Inference API:
|
||||||
|
|
||||||
```py
|
```
|
||||||
from huggingface_hub import notebook_login
|
from huggingface_hub import login
|
||||||
|
|
||||||
notebook_login()
|
login()
|
||||||
```
|
```
|
||||||
|
|
||||||
⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model.
|
⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model.
|
||||||
|
|
|
@ -177,7 +177,7 @@ agent.run("How many more blocks (also denoted as layers) are in BERT base encode
|
||||||
|
|
||||||
### Manage your agent's toolbox
|
### Manage your agent's toolbox
|
||||||
|
|
||||||
You can manage an agent's toolbox by adding or replacing a tool.
|
You can manage an agent's toolbox by adding or replacing a tool in attribute `agent.tools`, since it is a standard dictionary.
|
||||||
|
|
||||||
Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
|
Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
|
||||||
|
|
||||||
|
@ -187,7 +187,7 @@ from smolagents import HfApiModel
|
||||||
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
|
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
|
||||||
|
|
||||||
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
|
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
|
||||||
agent.tools.append(model_download_tool)
|
agent.tools[model_download_tool.name] = model_download_tool
|
||||||
```
|
```
|
||||||
Now we can leverage the new tool:
|
Now we can leverage the new tool:
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,6 @@ authors = [
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch",
|
|
||||||
"torchaudio",
|
|
||||||
"torchvision",
|
|
||||||
"transformers>=4.0.0",
|
"transformers>=4.0.0",
|
||||||
"requests>=2.32.3",
|
"requests>=2.32.3",
|
||||||
"rich>=13.9.4",
|
"rich>=13.9.4",
|
||||||
|
@ -30,10 +27,22 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
ignore = ["F403"]
|
lint.ignore = ["F403"]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"torch",
|
||||||
|
"torchaudio",
|
||||||
|
"torchvision",
|
||||||
|
"sqlalchemy",
|
||||||
|
"accelerate",
|
||||||
|
"soundfile",
|
||||||
|
"litellm>=1.55.10",
|
||||||
|
]
|
||||||
test = [
|
test = [
|
||||||
|
"torch",
|
||||||
|
"torchaudio",
|
||||||
|
"torchvision",
|
||||||
"pytest>=8.1.0",
|
"pytest>=8.1.0",
|
||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
"ruff>=0.5.0",
|
"ruff>=0.5.0",
|
||||||
|
|
|
@ -20,11 +20,9 @@ from dataclasses import dataclass
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, list_spaces
|
from huggingface_hub import hf_hub_download, list_spaces
|
||||||
from transformers.models.whisper import (
|
|
||||||
WhisperForConditionalGeneration,
|
|
||||||
WhisperProcessor,
|
from transformers.utils import is_offline_mode, is_torch_available
|
||||||
)
|
|
||||||
from transformers.utils import is_offline_mode
|
|
||||||
|
|
||||||
from .local_python_executor import (
|
from .local_python_executor import (
|
||||||
BASE_BUILTIN_MODULES,
|
BASE_BUILTIN_MODULES,
|
||||||
|
@ -34,6 +32,15 @@ from .local_python_executor import (
|
||||||
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
||||||
from .types import AgentAudio
|
from .types import AgentAudio
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.models.whisper import (
|
||||||
|
WhisperForConditionalGeneration,
|
||||||
|
WhisperProcessor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
WhisperForConditionalGeneration = object
|
||||||
|
WhisperProcessor = object
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PreTool:
|
class PreTool:
|
||||||
|
|
|
@ -22,7 +22,6 @@ from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
InferenceClient,
|
InferenceClient,
|
||||||
ChatCompletionOutputMessage,
|
ChatCompletionOutputMessage,
|
||||||
|
@ -35,6 +34,7 @@ from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
|
is_torch_available,
|
||||||
)
|
)
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
|
@ -147,29 +147,12 @@ class Model:
|
||||||
self.last_input_token_count = None
|
self.last_input_token_count = None
|
||||||
self.last_output_token_count = None
|
self.last_output_token_count = None
|
||||||
|
|
||||||
def get_token_counts(self):
|
def get_token_counts(self) -> Dict[str, int]:
|
||||||
return {
|
return {
|
||||||
"input_token_count": self.last_input_token_count,
|
"input_token_count": self.last_input_token_count,
|
||||||
"output_token_count": self.last_output_token_count,
|
"output_token_count": self.last_output_token_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
stop_sequences: Optional[List[str]] = None,
|
|
||||||
grammar: Optional[str] = None,
|
|
||||||
max_tokens: int = 1500,
|
|
||||||
):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_tool_call(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
available_tools: List[Tool],
|
|
||||||
stop_sequences,
|
|
||||||
):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
|
@ -256,6 +239,10 @@ class HfApiModel(Model):
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""
|
||||||
|
Gets an LLM output message for the given list of input messages.
|
||||||
|
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
||||||
|
"""
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
@ -293,6 +280,10 @@ class TransformersModel(Model):
|
||||||
|
|
||||||
def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
|
def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ImportError("Please install torch in order to use TransformersModel.")
|
||||||
|
import torch
|
||||||
|
|
||||||
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
|
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
|
||||||
if model_id is None:
|
if model_id is None:
|
||||||
model_id = default_model_id
|
model_id = default_model_id
|
||||||
|
|
|
@ -27,7 +27,6 @@ from functools import lru_cache, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Optional, Union, get_type_hints
|
from typing import Callable, Dict, Optional, Union, get_type_hints
|
||||||
|
|
||||||
import torch
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
create_repo,
|
create_repo,
|
||||||
get_collection,
|
get_collection,
|
||||||
|
@ -37,7 +36,6 @@ from huggingface_hub import (
|
||||||
)
|
)
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import AutoProcessor
|
|
||||||
from transformers.dynamic_module_utils import get_imports
|
from transformers.dynamic_module_utils import get_imports
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
TypeHintParsingException,
|
TypeHintParsingException,
|
||||||
|
@ -54,13 +52,14 @@ from .utils import instance_to_source
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if is_accelerate_available():
|
||||||
|
from accelerate import PartialState
|
||||||
|
from accelerate.utils import send_to_device
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
pass
|
from transformers import AutoProcessor
|
||||||
|
else:
|
||||||
if is_accelerate_available():
|
AutoProcessor = object
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
TOOL_CONFIG_FILE = "tool_config.json"
|
TOOL_CONFIG_FILE = "tool_config.json"
|
||||||
|
|
||||||
|
@ -1026,8 +1025,6 @@ class PipelineTool(Tool):
|
||||||
"""
|
"""
|
||||||
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
||||||
"""
|
"""
|
||||||
from accelerate import PartialState
|
|
||||||
|
|
||||||
if isinstance(self.pre_processor, str):
|
if isinstance(self.pre_processor, str):
|
||||||
self.pre_processor = self.pre_processor_class.from_pretrained(
|
self.pre_processor = self.pre_processor_class.from_pretrained(
|
||||||
self.pre_processor, **self.hub_kwargs
|
self.pre_processor, **self.hub_kwargs
|
||||||
|
@ -1066,6 +1063,8 @@ class PipelineTool(Tool):
|
||||||
"""
|
"""
|
||||||
Sends the inputs through the `model`.
|
Sends the inputs through the `model`.
|
||||||
"""
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return self.model(**inputs)
|
return self.model(**inputs)
|
||||||
|
|
||||||
|
@ -1076,6 +1075,8 @@ class PipelineTool(Tool):
|
||||||
return self.post_processor(outputs)
|
return self.post_processor(outputs)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
import torch
|
||||||
|
|
||||||
args, kwargs = handle_agent_input_types(*args, **kwargs)
|
args, kwargs = handle_agent_input_types(*args, **kwargs)
|
||||||
|
|
||||||
if not self.is_initialized:
|
if not self.is_initialized:
|
||||||
|
@ -1083,9 +1084,6 @@ class PipelineTool(Tool):
|
||||||
|
|
||||||
encoded_inputs = self.encode(*args, **kwargs)
|
encoded_inputs = self.encode(*args, **kwargs)
|
||||||
|
|
||||||
import torch
|
|
||||||
from accelerate.utils import send_to_device
|
|
||||||
|
|
||||||
tensor_inputs = {
|
tensor_inputs = {
|
||||||
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
|
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,10 +22,10 @@ from io import BytesIO
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
is_soundfile_availble,
|
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ if is_torch_available():
|
||||||
else:
|
else:
|
||||||
Tensor = object
|
Tensor = object
|
||||||
|
|
||||||
if is_soundfile_availble():
|
if _is_package_available("soundfile"):
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,7 +189,7 @@ class AgentAudio(AgentType, str):
|
||||||
def __init__(self, value, samplerate=16_000):
|
def __init__(self, value, samplerate=16_000):
|
||||||
super().__init__(value)
|
super().__init__(value)
|
||||||
|
|
||||||
if not is_soundfile_availble():
|
if not _is_package_available("soundfile"):
|
||||||
raise ImportError("soundfile must be installed in order to handle audio.")
|
raise ImportError("soundfile must be installed in order to handle audio.")
|
||||||
|
|
||||||
self._path = None
|
self._path = None
|
||||||
|
@ -253,7 +253,7 @@ AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAu
|
||||||
INSTANCE_TYPE_MAPPING = {
|
INSTANCE_TYPE_MAPPING = {
|
||||||
str: AgentText,
|
str: AgentText,
|
||||||
ImageType: AgentImage,
|
ImageType: AgentImage,
|
||||||
torch.Tensor: AgentAudio,
|
Tensor: AgentAudio,
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|
|
@ -18,20 +18,19 @@ import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_soundfile,
|
require_soundfile,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_vision,
|
require_vision,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils.import_utils import (
|
||||||
is_soundfile_availble,
|
_is_package_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||||
|
|
||||||
if is_soundfile_availble():
|
if _is_package_available("soundfile"):
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,6 +43,8 @@ def get_new_path(suffix="") -> str:
|
||||||
@require_torch
|
@require_torch
|
||||||
class AgentAudioTests(unittest.TestCase):
|
class AgentAudioTests(unittest.TestCase):
|
||||||
def test_from_tensor(self):
|
def test_from_tensor(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||||
agent_type = AgentAudio(tensor)
|
agent_type = AgentAudio(tensor)
|
||||||
path = str(agent_type.to_string())
|
path = str(agent_type.to_string())
|
||||||
|
@ -61,6 +62,8 @@ class AgentAudioTests(unittest.TestCase):
|
||||||
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
||||||
|
|
||||||
def test_from_string(self):
|
def test_from_string(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||||
path = get_new_path(suffix=".wav")
|
path = get_new_path(suffix=".wav")
|
||||||
sf.write(path, tensor, 16000)
|
sf.write(path, tensor, 16000)
|
||||||
|
@ -75,6 +78,8 @@ class AgentAudioTests(unittest.TestCase):
|
||||||
@require_torch
|
@require_torch
|
||||||
class AgentImageTests(unittest.TestCase):
|
class AgentImageTests(unittest.TestCase):
|
||||||
def test_from_tensor(self):
|
def test_from_tensor(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
tensor = torch.randint(0, 256, (64, 64, 3))
|
tensor = torch.randint(0, 256, (64, 64, 3))
|
||||||
agent_type = AgentImage(tensor)
|
agent_type = AgentImage(tensor)
|
||||||
path = str(agent_type.to_string())
|
path = str(agent_type.to_string())
|
||||||
|
|
Loading…
Reference in New Issue