Add tests for all docs
This commit is contained in:
parent
9232528232
commit
57b35337c2
|
@ -51,7 +51,7 @@ Particular guidelines to follow:
|
||||||
For instance, here's a tool that :
|
For instance, here's a tool that :
|
||||||
|
|
||||||
First, here's a poor version:
|
First, here's a poor version:
|
||||||
```py
|
```python
|
||||||
from my_weather_api return convert_location_to_coordinates, get_weather_report_at_coordinates
|
from my_weather_api return convert_location_to_coordinates, get_weather_report_at_coordinates
|
||||||
# Let's say "get_weather_report_at_coordinates" returns a list of [temperature in °C, risk of rain on a scale 0-1, wave height in m]
|
# Let's say "get_weather_report_at_coordinates" returns a list of [temperature in °C, risk of rain on a scale 0-1, wave height in m]
|
||||||
import datetime
|
import datetime
|
||||||
|
@ -79,7 +79,7 @@ Why is it bad?
|
||||||
If the tool call fails, the error trace logged in memory can help the LLM reverse engineer the tool to fix the errors. But why leave it so much heavy lifting to do?
|
If the tool call fails, the error trace logged in memory can help the LLM reverse engineer the tool to fix the errors. But why leave it so much heavy lifting to do?
|
||||||
|
|
||||||
A better way to build this tool would have been the following:
|
A better way to build this tool would have been the following:
|
||||||
```py
|
```python
|
||||||
from my_weather_api return convert_location_to_coordinates, get_weather_report_at_coordinates
|
from my_weather_api return convert_location_to_coordinates, get_weather_report_at_coordinates
|
||||||
# Let's say "get_weather_report_at_coordinates" returns a list of [temperature in °C, risk of rain on a scale 0-1, wave height in m]
|
# Let's say "get_weather_report_at_coordinates" returns a list of [temperature in °C, risk of rain on a scale 0-1, wave height in m]
|
||||||
import datetime
|
import datetime
|
||||||
|
@ -104,7 +104,9 @@ def get_weather_api(location (str), date_time: str) -> str:
|
||||||
|
|
||||||
In general, to ease the load on your LLM, the good question to ask yourself is: "How easy would it be for me, if I was dumb and using this tool for thsefirst time ever, to program with this tool and correct my own errors?".
|
In general, to ease the load on your LLM, the good question to ask yourself is: "How easy would it be for me, if I was dumb and using this tool for thsefirst time ever, to program with this tool and correct my own errors?".
|
||||||
|
|
||||||
### How to debug your agent
|
## How to debug your agent
|
||||||
|
|
||||||
|
### 1. Use a stronger LLM
|
||||||
|
|
||||||
In an agentic workflows, some of the errors are actual errors, some other are the fault of your LLM engine not reasoning properly.
|
In an agentic workflows, some of the errors are actual errors, some other are the fault of your LLM engine not reasoning properly.
|
||||||
For instance, consider this trace for an `CodeAgent` that I asked to make me a car picture:
|
For instance, consider this trace for an `CodeAgent` that I asked to make me a car picture:
|
||||||
|
@ -140,6 +142,8 @@ Thus it cannot access the image again except by leveraging the path that was log
|
||||||
|
|
||||||
The first step to debugging your agent is thus "Use a more powerful LLM". Alternatives like `Qwen2/5-72B-Instruct` wouldn't have made that mistake.
|
The first step to debugging your agent is thus "Use a more powerful LLM". Alternatives like `Qwen2/5-72B-Instruct` wouldn't have made that mistake.
|
||||||
|
|
||||||
|
### 2. Provide more guidance / more information
|
||||||
|
|
||||||
Then you can also use less powerful models but guide them better.
|
Then you can also use less powerful models but guide them better.
|
||||||
|
|
||||||
To provide extra information, we do not recommend modifying the system prompt compared to default : there are many adjustments there that you do not want to mess up except if you understand the prompt very well.
|
To provide extra information, we do not recommend modifying the system prompt compared to default : there are many adjustments there that you do not want to mess up except if you understand the prompt very well.
|
||||||
|
@ -147,4 +151,6 @@ Better ways to guide your LLM engine are:
|
||||||
- If it 's about the task to solve: add all these details to the task. The task could be 100s of pages long.
|
- If it 's about the task to solve: add all these details to the task. The task could be 100s of pages long.
|
||||||
- If it's about how to use tools: the description attribute of your tools.
|
- If it's about how to use tools: the description attribute of your tools.
|
||||||
|
|
||||||
|
### 3. Extra planning
|
||||||
|
|
||||||
|
We provide a model for a supplementary planning step, that an agent can run regularly in-between normal action steps. In this step, there is no tool call, the LLM is simply asked to update a list of facts it knows and to reflect on what steps it should take next based on those facts.
|
||||||
|
|
|
@ -36,7 +36,7 @@ The custom tool needs:
|
||||||
- An `output_type` attribute, which specifies the output type.
|
- An `output_type` attribute, which specifies the output type.
|
||||||
- A `forward` method which contains the inference code to be executed.
|
- A `forward` method which contains the inference code to be executed.
|
||||||
|
|
||||||
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: `["string", "boolean", "integer", "number", "audio", "image", "any"]`.
|
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`].
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -93,7 +93,7 @@ You only need to provide the id of the Space on the Hub, its name, and a descrip
|
||||||
|
|
||||||
For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image.
|
For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image.
|
||||||
|
|
||||||
```
|
```python
|
||||||
from transformers import Tool
|
from transformers import Tool
|
||||||
|
|
||||||
image_generation_tool = Tool.from_space(
|
image_generation_tool = Tool.from_space(
|
||||||
|
@ -144,7 +144,7 @@ Import and instantiate the tool, then pass it to the `Tool.from_gradio` method:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from gradio_tools import StableDiffusionPromptGeneratorTool
|
from gradio_tools import StableDiffusionPromptGeneratorTool
|
||||||
from transformers import Tool, load_tool, CodeAgent
|
from agents import Tool, load_tool, CodeAgent
|
||||||
|
|
||||||
gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
|
gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
|
||||||
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
|
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
|
||||||
|
@ -162,7 +162,7 @@ Here is how you can use it to recreate the intro's search result using a LangCha
|
||||||
This tool will need `pip install google-search-results` to work properly.
|
This tool will need `pip install google-search-results` to work properly.
|
||||||
```python
|
```python
|
||||||
from langchain.agents import load_tools
|
from langchain.agents import load_tools
|
||||||
from transformers import Tool, CodeAgent
|
from agents import Tool, CodeAgent
|
||||||
|
|
||||||
search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
|
search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
|
||||||
|
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
from agents import OpenAIEngine, AnthropicEngine, HfApiEngine, CodeAgent
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
openai_engine = OpenAIEngine(model_name="gpt-4o")
|
|
||||||
|
|
||||||
agent = CodeAgent([], llm_engine=openai_engine)
|
|
||||||
|
|
||||||
print("\n\n##############")
|
|
||||||
print("Running OpenAI agent:")
|
|
||||||
agent.run("What is the 10th Fibonacci Number?")
|
|
||||||
|
|
||||||
|
|
||||||
anthropic_engine = AnthropicEngine()
|
|
||||||
|
|
||||||
agent = CodeAgent([], llm_engine=anthropic_engine)
|
|
||||||
|
|
||||||
print("\n\n##############")
|
|
||||||
print("Running Anthropic agent:")
|
|
||||||
agent.run("What is the 10th Fibonacci Number?")
|
|
||||||
|
|
||||||
# Here, our token stored as HF_TOKEN environment variable has accesses 'Make calls to the serverless Inference API' and 'Read access to contents of all public gated repos you can access'
|
|
||||||
llama_engine = HfApiEngine(model="meta-llama/Llama-3.3-70B-Instruct")
|
|
||||||
|
|
||||||
agent = CodeAgent([], llm_engine=llama_engine)
|
|
||||||
|
|
||||||
print("\n\n##############")
|
|
||||||
print("Running Llama3.3-70B agent:")
|
|
||||||
agent.run("What is the 10th Fibonacci Number?")
|
|
|
@ -121,6 +121,15 @@ def validate_after_init(cls, do_validate_forward: bool = True):
|
||||||
cls.__init__ = new_init
|
cls.__init__ = new_init
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
AUTHORIZED_TYPES = [
|
||||||
|
"string",
|
||||||
|
"boolean",
|
||||||
|
"integer",
|
||||||
|
"number",
|
||||||
|
"image",
|
||||||
|
"audio",
|
||||||
|
"any",
|
||||||
|
]
|
||||||
|
|
||||||
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
||||||
|
|
||||||
|
@ -166,15 +175,6 @@ class Tool:
|
||||||
"inputs": dict,
|
"inputs": dict,
|
||||||
"output_type": str,
|
"output_type": str,
|
||||||
}
|
}
|
||||||
authorized_types = [
|
|
||||||
"string",
|
|
||||||
"integer",
|
|
||||||
"number",
|
|
||||||
"image",
|
|
||||||
"audio",
|
|
||||||
"any",
|
|
||||||
"boolean",
|
|
||||||
]
|
|
||||||
|
|
||||||
for attr, expected_type in required_attributes.items():
|
for attr, expected_type in required_attributes.items():
|
||||||
attr_value = getattr(self, attr, None)
|
attr_value = getattr(self, attr, None)
|
||||||
|
@ -191,12 +191,12 @@ class Tool:
|
||||||
assert (
|
assert (
|
||||||
"type" in input_content and "description" in input_content
|
"type" in input_content and "description" in input_content
|
||||||
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||||
if input_content["type"] not in authorized_types:
|
if input_content["type"] not in AUTHORIZED_TYPES:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
|
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
|
||||||
)
|
)
|
||||||
|
|
||||||
assert getattr(self, "output_type", None) in authorized_types
|
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
|
||||||
if do_validate_forward:
|
if do_validate_forward:
|
||||||
if not isinstance(self, PipelineTool):
|
if not isinstance(self, PipelineTool):
|
||||||
signature = inspect.signature(self.forward)
|
signature = inspect.signature(self.forward)
|
||||||
|
@ -1077,7 +1077,6 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
] + list(original_signature.parameters.values())
|
] + list(original_signature.parameters.values())
|
||||||
new_signature = original_signature.replace(parameters=new_parameters)
|
new_signature = original_signature.replace(parameters=new_parameters)
|
||||||
SpecificTool.forward.__signature__ = new_signature
|
SpecificTool.forward.__signature__ = new_signature
|
||||||
|
|
||||||
SpecificTool.__name__ = class_name
|
SpecificTool.__name__ = class_name
|
||||||
return SpecificTool()
|
return SpecificTool()
|
||||||
|
|
||||||
|
@ -1186,4 +1185,4 @@ class Toolbox:
|
||||||
toolbox_description += f"\t{tool.name}: {tool.description}\n"
|
toolbox_description += f"\t{tool.name}: {tool.description}\n"
|
||||||
return toolbox_description
|
return toolbox_description
|
||||||
|
|
||||||
__all__ = ["Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
|
__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
|
|
@ -0,0 +1,149 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import subprocess
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class SubprocessCallException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(command: List[str], return_stdout=False, env=None):
|
||||||
|
"""
|
||||||
|
Runs command with subprocess.check_output and returns stdout if requested.
|
||||||
|
Properly captures and handles errors during command execution.
|
||||||
|
"""
|
||||||
|
for i, c in enumerate(command):
|
||||||
|
if isinstance(c, Path):
|
||||||
|
command[i] = str(c)
|
||||||
|
|
||||||
|
if env is None:
|
||||||
|
env = os.environ.copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
|
||||||
|
if return_stdout:
|
||||||
|
if hasattr(output, "decode"):
|
||||||
|
output = output.decode("utf-8")
|
||||||
|
return output
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
raise SubprocessCallException(
|
||||||
|
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
class DocCodeExtractor:
|
||||||
|
"""Handles extraction and validation of Python code from markdown files."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_python_code(content: str) -> List[str]:
|
||||||
|
"""Extract Python code blocks from markdown content."""
|
||||||
|
pattern = r'```(?:python|py)\n(.*?)\n```'
|
||||||
|
matches = re.finditer(pattern, content, re.DOTALL)
|
||||||
|
return [match.group(1).strip() for match in matches]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_test_script(code_blocks: List[str], tmp_dir: str) -> Path:
|
||||||
|
"""Create a temporary Python script from code blocks."""
|
||||||
|
combined_code = "\n\n".join(code_blocks)
|
||||||
|
assert len(combined_code) > 0, "Code is empty!"
|
||||||
|
tmp_file = Path(tmp_dir) / "test_script.py"
|
||||||
|
|
||||||
|
with open(tmp_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(combined_code)
|
||||||
|
|
||||||
|
return tmp_file
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocs:
|
||||||
|
"""Test case for documentation code testing."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
cls._tmpdir = tempfile.mkdtemp()
|
||||||
|
cls.launch_args = ["python3"]
|
||||||
|
cls.docs_dir = Path(__file__).parent.parent / "docs" / "source"
|
||||||
|
cls.extractor = DocCodeExtractor()
|
||||||
|
|
||||||
|
# Verify docs directory exists
|
||||||
|
if not cls.docs_dir.exists():
|
||||||
|
raise ValueError(f"Docs directory not found at {cls.docs_dir}")
|
||||||
|
|
||||||
|
# Verify we have markdown files
|
||||||
|
cls.md_files = list(cls.docs_dir.glob("*.md"))
|
||||||
|
if not cls.md_files:
|
||||||
|
raise ValueError(f"No markdown files found in {cls.docs_dir}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
shutil.rmtree(cls._tmpdir)
|
||||||
|
|
||||||
|
def test_single_doc(self, doc_path: Path):
|
||||||
|
"""Test a single documentation file."""
|
||||||
|
with open(doc_path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
code_blocks = self.extractor.extract_python_code(content)
|
||||||
|
if not code_blocks:
|
||||||
|
pytest.skip(f"No Python code blocks found in {doc_path.name}")
|
||||||
|
|
||||||
|
# Validate syntax of each block individually by parsing it
|
||||||
|
for i, block in enumerate(code_blocks, 1):
|
||||||
|
ast.parse(block)
|
||||||
|
|
||||||
|
# Create and execute test script
|
||||||
|
try:
|
||||||
|
test_script = self.extractor.create_test_script(code_blocks, self._tmpdir)
|
||||||
|
run_command(self.launch_args + [str(test_script)])
|
||||||
|
|
||||||
|
except SubprocessCallException as e:
|
||||||
|
pytest.fail(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error testing {doc_path.name}: {str(e)}")
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _setup(self):
|
||||||
|
"""Fixture to ensure temporary directory exists for each test."""
|
||||||
|
os.makedirs(self._tmpdir, exist_ok=True)
|
||||||
|
yield
|
||||||
|
# Clean up test files after each test
|
||||||
|
for file in Path(self._tmpdir).glob("*"):
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
"""Generate test cases for each markdown file."""
|
||||||
|
if "doc_path" in metafunc.fixturenames:
|
||||||
|
test_class = metafunc.cls
|
||||||
|
|
||||||
|
# Initialize the class if needed
|
||||||
|
if not hasattr(test_class, "md_files"):
|
||||||
|
test_class.setup_class()
|
||||||
|
|
||||||
|
# Parameterize with the markdown files
|
||||||
|
metafunc.parametrize(
|
||||||
|
"doc_path",
|
||||||
|
test_class.md_files,
|
||||||
|
ids=[f.stem for f in test_class.md_files]
|
||||||
|
)
|
|
@ -61,11 +61,3 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
output = self.tool(**input)
|
output = self.tool(**input)
|
||||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||||
self.assertTrue(isinstance(output, agent_type))
|
self.assertTrue(isinstance(output, agent_type))
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_agent_types_inputs(self):
|
|
||||||
inputs = self.create_inputs()
|
|
||||||
for input_type, input in inputs.items():
|
|
||||||
output = self.tool(**input)
|
|
||||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
|
||||||
self.assertTrue(isinstance(output, agent_type))
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ from agents.types import (
|
||||||
AgentImage,
|
AgentImage,
|
||||||
AgentText,
|
AgentText,
|
||||||
)
|
)
|
||||||
from agents.tools import Tool, tool
|
from agents.tools import Tool, tool, AUTHORIZED_TYPES
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,9 +37,6 @@ if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
|
|
||||||
|
|
||||||
|
|
||||||
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||||
inputs = {}
|
inputs = {}
|
||||||
|
|
||||||
|
@ -101,14 +98,6 @@ class ToolTesterMixin:
|
||||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||||
self.assertTrue(isinstance(output, agent_type))
|
self.assertTrue(isinstance(output, agent_type))
|
||||||
|
|
||||||
def test_agent_types_inputs(self):
|
|
||||||
inputs = create_inputs(self.tool.inputs)
|
|
||||||
_inputs = []
|
|
||||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
|
||||||
input_type = expected_input["type"]
|
|
||||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
|
||||||
|
|
||||||
|
|
||||||
class ToolTests(unittest.TestCase):
|
class ToolTests(unittest.TestCase):
|
||||||
def test_tool_init_with_decorator(self):
|
def test_tool_init_with_decorator(self):
|
||||||
@tool
|
@tool
|
||||||
|
|
|
@ -4,6 +4,7 @@ import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def str_to_bool(value) -> int:
|
def str_to_bool(value) -> int:
|
||||||
"""
|
"""
|
||||||
Converts a string representation of truth to `True` (1) or `False` (0).
|
Converts a string representation of truth to `True` (1) or `False` (0).
|
||||||
|
@ -48,24 +49,6 @@ def slow(test_case):
|
||||||
"""
|
"""
|
||||||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||||
|
|
||||||
def get_launch_command(**kwargs) -> list:
|
|
||||||
"""
|
|
||||||
Wraps around `kwargs` to help simplify launching from `subprocess`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
# returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2']
|
|
||||||
get_launch_command(num_processes=2, device_count=2)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
command = ["accelerate", "launch"]
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if isinstance(v, bool) and v:
|
|
||||||
command.append(f"--{k}")
|
|
||||||
elif v is not None:
|
|
||||||
command.append(f"--{k}={v}")
|
|
||||||
return command
|
|
||||||
|
|
||||||
|
|
||||||
class TempDirTestCase(unittest.TestCase):
|
class TempDirTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue