Get rid of necessity to declare tools in their own .py file
This commit is contained in:
parent
aef0510e68
commit
0eb582bdba
|
@ -38,10 +38,10 @@ The custom tool needs:
|
||||||
|
|
||||||
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`].
|
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`].
|
||||||
|
|
||||||
|
Also, all imports should be put within the tool's forward function, else you will get an error.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from transformers import Tool
|
from transformers import Tool
|
||||||
from huggingface_hub import list_models
|
|
||||||
|
|
||||||
class HFModelDownloadsTool(Tool):
|
class HFModelDownloadsTool(Tool):
|
||||||
name = "model_download_counter"
|
name = "model_download_counter"
|
||||||
|
@ -58,26 +58,27 @@ class HFModelDownloadsTool(Tool):
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def forward(self, task: str):
|
def forward(self, task: str):
|
||||||
|
from huggingface_hub import list_models
|
||||||
|
|
||||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||||
return model.id
|
return model.id
|
||||||
```
|
|
||||||
|
|
||||||
Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
|
|
||||||
|
|
||||||
|
|
||||||
```python
|
|
||||||
from model_downloads import HFModelDownloadsTool
|
|
||||||
|
|
||||||
tool = HFModelDownloadsTool()
|
tool = HFModelDownloadsTool()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Now the custom `HfModelDownloadsTool` class is ready.
|
||||||
|
|
||||||
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
|
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tool.push_to_hub("{your_username}/hf-model-downloads")
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
tool.push_to_hub("m-ric/hf-model-downloads", token=os.getenv("HF_TOKEN"))
|
||||||
```
|
```
|
||||||
|
|
||||||
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
||||||
|
Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from transformers import load_tool, CodeAgent
|
from transformers import load_tool, CodeAgent
|
||||||
|
@ -159,7 +160,7 @@ We love Langchain and think it has a very compelling suite of tools.
|
||||||
To import a tool from LangChain, use the `from_langchain()` method.
|
To import a tool from LangChain, use the `from_langchain()` method.
|
||||||
|
|
||||||
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
|
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
|
||||||
This tool will need `pip install google-search-results` to work properly.
|
This tool will need `pip install langchain google-search-results -q` to work properly.
|
||||||
```python
|
```python
|
||||||
from langchain.agents import load_tools
|
from langchain.agents import load_tools
|
||||||
from agents import Tool, CodeAgent
|
from agents import Tool, CodeAgent
|
||||||
|
@ -191,7 +192,6 @@ agent.run(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
| **Audio** |
|
| **Audio** |
|
||||||
|------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
|
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
|
||||||
|
|
|
@ -44,7 +44,6 @@ from transformers.dynamic_module_utils import (
|
||||||
)
|
)
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
CONFIG_NAME,
|
|
||||||
TypeHintParsingException,
|
TypeHintParsingException,
|
||||||
cached_file,
|
cached_file,
|
||||||
get_json_schema,
|
get_json_schema,
|
||||||
|
@ -53,6 +52,8 @@ from transformers.utils import (
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
|
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
|
||||||
|
from .utils import ImportFinder
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -104,7 +105,7 @@ def setup_default_tools():
|
||||||
|
|
||||||
# docstyle-ignore
|
# docstyle-ignore
|
||||||
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
|
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
|
||||||
from {module_name} import {class_name}
|
from tool import {class_name}
|
||||||
|
|
||||||
launch_gradio_demo({class_name})
|
launch_gradio_demo({class_name})
|
||||||
"""
|
"""
|
||||||
|
@ -304,28 +305,44 @@ class Tool:
|
||||||
output_dir (`str`): The folder in which you want to save your tool.
|
output_dir (`str`): The folder in which you want to save your tool.
|
||||||
"""
|
"""
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
# Save module file
|
class_name = self.__class__.__name__
|
||||||
if self.__module__ == "__main__":
|
|
||||||
raise ValueError(
|
|
||||||
f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
|
|
||||||
"have to put this code in a separate module so we can include it in the saved folder."
|
|
||||||
)
|
|
||||||
module_files = custom_object_save(self, output_dir)
|
|
||||||
|
|
||||||
module_name = self.__class__.__module__
|
# Save tool file
|
||||||
last_module = module_name.split(".")[-1]
|
forward_source_code = inspect.getsource(self.forward)
|
||||||
full_name = f"{last_module}.{self.__class__.__name__}"
|
tool_code = textwrap.dedent(f"""
|
||||||
|
from agents import Tool
|
||||||
|
|
||||||
|
class {class_name}(Tool):
|
||||||
|
name = "{self.name}"
|
||||||
|
description = "{self.description}"
|
||||||
|
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
|
||||||
|
output_type = "{self.output_type}"
|
||||||
|
""").strip()
|
||||||
|
|
||||||
|
import re
|
||||||
|
def add_self_argument(source_code: str) -> str:
|
||||||
|
"""Add 'self' as first argument to a function definition if not present."""
|
||||||
|
pattern = r'def forward\(((?!self)[^)]*)\)'
|
||||||
|
|
||||||
|
def replacement(match):
|
||||||
|
args = match.group(1).strip()
|
||||||
|
if args: # If there are other arguments
|
||||||
|
return f'def forward(self, {args})'
|
||||||
|
return 'def forward(self)'
|
||||||
|
|
||||||
|
return re.sub(pattern, replacement, source_code)
|
||||||
|
|
||||||
|
forward_source_code = forward_source_code.replace(self.name, "forward")
|
||||||
|
forward_source_code = add_self_argument(forward_source_code)
|
||||||
|
forward_source_code = forward_source_code.replace("@tool", "").strip()
|
||||||
|
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
|
||||||
|
with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f:
|
||||||
|
f.write(tool_code)
|
||||||
|
|
||||||
# Save config file
|
# Save config file
|
||||||
config_file = os.path.join(output_dir, "tool_config.json")
|
config_file = os.path.join(output_dir, "tool_config.json")
|
||||||
if os.path.isfile(config_file):
|
|
||||||
with open(config_file, "r", encoding="utf-8") as f:
|
|
||||||
tool_config = json.load(f)
|
|
||||||
else:
|
|
||||||
tool_config = {}
|
|
||||||
|
|
||||||
tool_config = {
|
tool_config = {
|
||||||
"tool_class": full_name,
|
"tool_class": self.__class__.__name__,
|
||||||
"description": self.description,
|
"description": self.description,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"inputs": self.inputs,
|
"inputs": self.inputs,
|
||||||
|
@ -339,131 +356,20 @@ class Tool:
|
||||||
with open(app_file, "w", encoding="utf-8") as f:
|
with open(app_file, "w", encoding="utf-8") as f:
|
||||||
f.write(
|
f.write(
|
||||||
APP_FILE_TEMPLATE.format(
|
APP_FILE_TEMPLATE.format(
|
||||||
module_name=last_module, class_name=self.__class__.__name__
|
class_name=class_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save requirements file
|
# Save requirements file
|
||||||
requirements_file = os.path.join(output_dir, "requirements.txt")
|
requirements_file = os.path.join(output_dir, "requirements.txt")
|
||||||
imports = []
|
|
||||||
for module in module_files:
|
tree = ast.parse(forward_source_code)
|
||||||
imports.extend(get_imports(module))
|
import_finder = ImportFinder()
|
||||||
imports = list(set(imports))
|
import_finder.visit(tree)
|
||||||
|
|
||||||
|
imports = list(set(import_finder.packages))
|
||||||
with open(requirements_file, "w", encoding="utf-8") as f:
|
with open(requirements_file, "w", encoding="utf-8") as f:
|
||||||
f.write("\n".join(imports) + "\n")
|
f.write("agents_package\n" + "\n".join(imports) + "\n")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_hub(
|
|
||||||
cls,
|
|
||||||
repo_id: str,
|
|
||||||
token: Optional[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Loads a tool defined on the Hub.
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
Loading a tool from the Hub means that you'll download the tool and execute it locally.
|
|
||||||
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
|
|
||||||
installing a package using pip/npm/apt.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repo_id (`str`):
|
|
||||||
The name of the repo on the Hub where your tool is defined.
|
|
||||||
token (`str`, *optional*):
|
|
||||||
The token to identify you on hf.co. If unset, will use the token generated when running
|
|
||||||
`huggingface-cli login` (stored in `~/.huggingface`).
|
|
||||||
kwargs (additional keyword arguments, *optional*):
|
|
||||||
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
|
|
||||||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
|
|
||||||
others will be passed along to its init.
|
|
||||||
"""
|
|
||||||
hub_kwargs_names = [
|
|
||||||
"cache_dir",
|
|
||||||
"force_download",
|
|
||||||
"resume_download",
|
|
||||||
"proxies",
|
|
||||||
"revision",
|
|
||||||
"repo_type",
|
|
||||||
"subfolder",
|
|
||||||
"local_files_only",
|
|
||||||
]
|
|
||||||
hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
|
|
||||||
|
|
||||||
# Try to get the tool config first.
|
|
||||||
hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
|
|
||||||
resolved_config_file = cached_file(
|
|
||||||
repo_id,
|
|
||||||
TOOL_CONFIG_FILE,
|
|
||||||
token=token,
|
|
||||||
**hub_kwargs,
|
|
||||||
_raise_exceptions_for_gated_repo=False,
|
|
||||||
_raise_exceptions_for_missing_entries=False,
|
|
||||||
_raise_exceptions_for_connection_errors=False,
|
|
||||||
)
|
|
||||||
is_tool_config = resolved_config_file is not None
|
|
||||||
if resolved_config_file is None:
|
|
||||||
resolved_config_file = cached_file(
|
|
||||||
repo_id,
|
|
||||||
CONFIG_NAME,
|
|
||||||
token=token,
|
|
||||||
**hub_kwargs,
|
|
||||||
_raise_exceptions_for_gated_repo=False,
|
|
||||||
_raise_exceptions_for_missing_entries=False,
|
|
||||||
_raise_exceptions_for_connection_errors=False,
|
|
||||||
)
|
|
||||||
if resolved_config_file is None:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
|
||||||
config = json.load(reader)
|
|
||||||
|
|
||||||
if not is_tool_config:
|
|
||||||
if "custom_tool" not in config:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
|
|
||||||
)
|
|
||||||
custom_tool = config["custom_tool"]
|
|
||||||
else:
|
|
||||||
custom_tool = config
|
|
||||||
|
|
||||||
tool_class = custom_tool["tool_class"]
|
|
||||||
tool_class = get_class_from_dynamic_module(
|
|
||||||
tool_class, repo_id, token=token, **hub_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(tool_class.name) == 0:
|
|
||||||
tool_class.name = custom_tool["name"]
|
|
||||||
if tool_class.name != custom_tool["name"]:
|
|
||||||
logger.warning(
|
|
||||||
f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
|
|
||||||
"configuration name."
|
|
||||||
)
|
|
||||||
tool_class.name = custom_tool["name"]
|
|
||||||
|
|
||||||
if len(tool_class.description) == 0:
|
|
||||||
tool_class.description = custom_tool["description"]
|
|
||||||
if tool_class.description != custom_tool["description"]:
|
|
||||||
logger.warning(
|
|
||||||
f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
|
|
||||||
"tool configuration description."
|
|
||||||
)
|
|
||||||
tool_class.description = custom_tool["description"]
|
|
||||||
|
|
||||||
if tool_class.inputs != custom_tool["inputs"]:
|
|
||||||
tool_class.inputs = custom_tool["inputs"]
|
|
||||||
if tool_class.output_type != custom_tool["output_type"]:
|
|
||||||
tool_class.output_type = custom_tool["output_type"]
|
|
||||||
|
|
||||||
if not isinstance(tool_class.inputs, dict):
|
|
||||||
tool_class.inputs = ast.literal_eval(tool_class.inputs)
|
|
||||||
|
|
||||||
return tool_class(**kwargs)
|
|
||||||
|
|
||||||
def push_to_hub(
|
def push_to_hub(
|
||||||
self,
|
self,
|
||||||
|
@ -512,6 +418,9 @@ class Tool:
|
||||||
with tempfile.TemporaryDirectory() as work_dir:
|
with tempfile.TemporaryDirectory() as work_dir:
|
||||||
# Save all files.
|
# Save all files.
|
||||||
self.save(work_dir)
|
self.save(work_dir)
|
||||||
|
print(work_dir)
|
||||||
|
with open(work_dir + "/tool.py", "r") as f:
|
||||||
|
print('\n'.join(f.readlines()))
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
|
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
|
||||||
)
|
)
|
||||||
|
@ -524,6 +433,110 @@ class Tool:
|
||||||
repo_type="space",
|
repo_type="space",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hub(
|
||||||
|
cls,
|
||||||
|
repo_id: str,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Loads a tool defined on the Hub.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
Loading a tool from the Hub means that you'll download the tool and execute it locally.
|
||||||
|
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
|
||||||
|
installing a package using pip/npm/apt.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
The name of the repo on the Hub where your tool is defined.
|
||||||
|
token (`str`, *optional*):
|
||||||
|
The token to identify you on hf.co. If unset, will use the token generated when running
|
||||||
|
`huggingface-cli login` (stored in `~/.huggingface`).
|
||||||
|
kwargs (additional keyword arguments, *optional*):
|
||||||
|
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
|
||||||
|
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
|
||||||
|
others will be passed along to its init.
|
||||||
|
"""
|
||||||
|
assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
|
||||||
|
|
||||||
|
hub_kwargs_names = [
|
||||||
|
"cache_dir",
|
||||||
|
"force_download",
|
||||||
|
"resume_download",
|
||||||
|
"proxies",
|
||||||
|
"revision",
|
||||||
|
"repo_type",
|
||||||
|
"subfolder",
|
||||||
|
"local_files_only",
|
||||||
|
]
|
||||||
|
hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
|
||||||
|
|
||||||
|
tool_file = "tool.py"
|
||||||
|
|
||||||
|
# Get the tool's tool.py file.
|
||||||
|
hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
|
||||||
|
resolved_tool_file = cached_file(
|
||||||
|
repo_id,
|
||||||
|
tool_file,
|
||||||
|
token=token,
|
||||||
|
**hub_kwargs,
|
||||||
|
_raise_exceptions_for_gated_repo=False,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
_raise_exceptions_for_connection_errors=False,
|
||||||
|
)
|
||||||
|
tool_code = resolved_tool_file is not None
|
||||||
|
if resolved_tool_file is None:
|
||||||
|
resolved_tool_file = cached_file(
|
||||||
|
repo_id,
|
||||||
|
tool_file,
|
||||||
|
token=token,
|
||||||
|
**hub_kwargs,
|
||||||
|
_raise_exceptions_for_gated_repo=False,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
_raise_exceptions_for_connection_errors=False,
|
||||||
|
)
|
||||||
|
if resolved_tool_file is None:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(resolved_tool_file, encoding="utf-8") as reader:
|
||||||
|
tool_code = "".join(reader.readlines())
|
||||||
|
|
||||||
|
# Find the Tool subclass in the namespace
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Save the code to a file
|
||||||
|
module_path = os.path.join(temp_dir, "tool.py")
|
||||||
|
with open(module_path, "w") as f:
|
||||||
|
f.write(tool_code)
|
||||||
|
|
||||||
|
# Load module from file path
|
||||||
|
spec = importlib.util.spec_from_file_location("custom_tool", module_path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
# Find and instantiate the Tool class
|
||||||
|
for item_name in dir(module):
|
||||||
|
item = getattr(module, item_name)
|
||||||
|
if isinstance(item, type) and issubclass(item, Tool) and item != Tool:
|
||||||
|
tool_class = item
|
||||||
|
break
|
||||||
|
|
||||||
|
if tool_class is None:
|
||||||
|
raise ValueError("No Tool subclass found in the code")
|
||||||
|
|
||||||
|
if not isinstance(tool_class.inputs, dict):
|
||||||
|
tool_class.inputs = ast.literal_eval(tool_class.inputs)
|
||||||
|
|
||||||
|
return tool_class(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_space(
|
def from_space(
|
||||||
space_id: str,
|
space_id: str,
|
||||||
|
@ -967,7 +980,8 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
raise TypeHintParsingException(
|
raise TypeHintParsingException(
|
||||||
"Tool return type not found: make sure your function has a return type hint!"
|
"Tool return type not found: make sure your function has a return type hint!"
|
||||||
)
|
)
|
||||||
class_name = f"{parameters['name'].capitalize()}Tool"
|
class_name = ''.join([el.title() for el in parameters['name'].split('_')])
|
||||||
|
|
||||||
if parameters["return"]["type"] == "object":
|
if parameters["return"]["type"] == "object":
|
||||||
parameters["return"]["type"] = "any"
|
parameters["return"]["type"] = "any"
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Tuple, Dict, Union
|
from typing import Tuple, Dict, Union
|
||||||
|
import ast
|
||||||
|
|
||||||
from transformers.utils.import_utils import _is_package_available
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
|
@ -109,5 +110,21 @@ def truncate_content(
|
||||||
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
|
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
|
||||||
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
|
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ImportFinder(ast.NodeVisitor):
|
||||||
|
def __init__(self):
|
||||||
|
self.packages = set()
|
||||||
|
|
||||||
|
def visit_Import(self, node):
|
||||||
|
for alias in node.names:
|
||||||
|
# Get the base package name (before any dots)
|
||||||
|
base_package = alias.name.split('.')[0]
|
||||||
|
self.packages.add(base_package)
|
||||||
|
|
||||||
|
def visit_ImportFrom(self, node):
|
||||||
|
if node.module: # for "from x import y" statements
|
||||||
|
# Get the base package name (before any dots)
|
||||||
|
base_package = node.module.split('.')[0]
|
||||||
|
self.packages.add(base_package)
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
Loading…
Reference in New Issue