Pass tests

This commit is contained in:
Aymeric 2024-12-11 19:23:07 +01:00
parent 67deb6808f
commit 1606b9a80c
17 changed files with 188 additions and 642 deletions

View File

@@ -62,7 +62,6 @@ else:
 if TYPE_CHECKING:
     from .agents import (
         Agent,
-        CodeAgent,
         ManagedAgent,
         ReactAgent,
         CodeAgent,

View File

@@ -47,9 +47,6 @@ from .tools import (
 )

-HUGGINGFACE_DEFAULT_TOOLS = {}

 class AgentError(Exception):
     """Base class for other agent-related exceptions"""
@@ -145,14 +142,18 @@ Here is a list of the team members that you can call:"""
 def format_prompt_with_managed_agents_descriptions(
-    prompt_template, managed_agents=None
+    prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None
 ) -> str:
-    if managed_agents is not None:
+    if agent_descriptions_placeholder is None:
+        agent_descriptions_placeholder = "{{managed_agents_descriptions}}"
+    if agent_descriptions_placeholder not in prompt_template:
+        raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'")
+    if len(managed_agents.keys()) > 0:
         return prompt_template.replace(
-            "<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)
+            agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
         )
     else:
-        return prompt_template.replace("<<managed_agents_descriptions>>", "")
+        return prompt_template.replace(agent_descriptions_placeholder, "")


 def format_prompt_with_imports(
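Note (illustrative, not part of the commit): a minimal sketch of how the reworked helper behaves under its new signature, with a stand-in `show_agents_descriptions` assumed to join each managed agent's name and description:

```py
from typing import Optional

def show_agents_descriptions(managed_agents: dict) -> str:
    # Assumed stand-in for the real helper: one line per managed agent.
    return "\n".join(f"- {name}: {desc}" for name, desc in managed_agents.items())

def format_prompt_with_managed_agents_descriptions(
    prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None
) -> str:
    if agent_descriptions_placeholder is None:
        agent_descriptions_placeholder = "{{managed_agents_descriptions}}"
    if agent_descriptions_placeholder not in prompt_template:
        # New behavior: a template missing the placeholder now fails loudly.
        raise ValueError(f"Prompt template is missing the placeholder '{agent_descriptions_placeholder}'")
    if len(managed_agents.keys()) > 0:
        return prompt_template.replace(agent_descriptions_placeholder, show_agents_descriptions(managed_agents))
    return prompt_template.replace(agent_descriptions_placeholder, "")

template = "You can use these team members:\n{{managed_agents_descriptions}}"
print(format_prompt_with_managed_agents_descriptions(template, {}))  # placeholder stripped
print(format_prompt_with_managed_agents_descriptions(template, {"web": "searches the web"}))
```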
@@ -220,12 +221,8 @@ class BaseAgent:
         self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
         self._toolbox.add_tool(FinalAnswerTool())
-        self.system_prompt = format_prompt_with_tools(
-            self._toolbox, self.system_prompt_template, self.tool_description_template
-        )
-        self.system_prompt = format_prompt_with_managed_agents_descriptions(
-            self.system_prompt, self.managed_agents
-        )
+        self.system_prompt = self.initialize_system_prompt()
+        print("SYS0:", self.system_prompt)
         self.prompt_messages = None
         self.logs = []
         self.task = None
@@ -353,7 +350,7 @@ class BaseAgent:
                 split[-2],
                 split[-1],
             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
-        except Exception as e:
+        except Exception:
             raise AgentParsingError(
                 f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
             )
@@ -909,8 +906,9 @@ class CodeAgent(ReactAgent):
         self.authorized_imports = list(
             set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
         )
+        print("SYSS:", self.system_prompt)
         self.system_prompt = self.system_prompt.replace(
-            "<<authorized_imports>>", str(self.authorized_imports)
+            "{{authorized_imports}}", str(self.authorized_imports)
         )
         self.custom_tools = {}
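Note (assumed illustration, not part of the commit): with the placeholders unified on double-brace style, filling a CodeAgent-like system prompt is a plain string replacement. The module list below is an invented subset:

```py
LIST_SAFE_MODULES = ["math", "random", "statistics"]  # assumed subset for illustration
additional_authorized_imports = ["requests"]

system_prompt = (
    "You can import from these modules only: {{authorized_imports}}\n"
    "{{managed_agents_descriptions}}"
)

# Same pattern as the diff: build the final list, then substitute the placeholders.
authorized_imports = list(set(LIST_SAFE_MODULES) | set(additional_authorized_imports))
system_prompt = system_prompt.replace("{{authorized_imports}}", str(authorized_imports))
system_prompt = system_prompt.replace("{{managed_agents_descriptions}}", "")  # no managed agents here
print(system_prompt)
```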

View File

@@ -135,6 +135,8 @@ final_answer(caption)
 Above example were using tools that might not exist for you. You only have access to these tools:
 {{tool_names}}
+{{managed_agents_descriptions}}

 Remember to make sure that variables you use are all defined. In particular don't import packages!
 Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error.
 DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
@@ -260,8 +262,11 @@ Action:
 Above example were using notional tools that might not exist for you. You only have access to these tools:
 {{tool_descriptions}}
+{{managed_agents_descriptions}}

 Here are the rules you should always follow to solve your task:
 1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail.
 2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead.
@@ -355,7 +360,7 @@ Above example were using notional tools that might not exist for you. On top of
 {{tool_descriptions}}
-<<managed_agents_descriptions>>
+{{managed_agents_descriptions}}

 Here are the rules you should always follow to solve your task:
 1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.

View File

@@ -643,8 +643,10 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
 def get_tool_description_with_args(
-    tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+    tool: Tool, description_template: Optional[str] = None
 ) -> str:
+    if description_template is None:
+        description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
     compiled_template = compile_jinja_template(description_template)
     rendered = compiled_template.render(
         tool=tool,
@@ -1080,6 +1082,9 @@ def tool(tool_function: Callable) -> Tool:
     return SpecificTool()

+HUGGINGFACE_DEFAULT_TOOLS = {}

 class Toolbox:
     """
     The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
@@ -1110,7 +1115,7 @@ class Toolbox:
         """Get all tools currently in the toolbox"""
         return self._tools

-    def show_tool_descriptions(self, tool_description_template: str = None) -> str:
+    def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
         """
         Returns the description of all tools in the toolbox
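Note (a sketch under assumptions, not the library's code): the new `Optional[str] = None` defaults mean the concrete template is resolved inside the function. Below, `DummyTool` and the template string are invented, and plain `jinja2.Template` stands in for the repo's `compile_jinja_template`:

```py
from typing import Optional
from jinja2 import Template

DEFAULT_TOOL_DESCRIPTION_TEMPLATE = (
    "- {{ tool.name }}: {{ tool.description }} (inputs: {{ tool.inputs | list | join(', ') }})"
)

class DummyTool:
    # Hypothetical tool with the attributes the template expects.
    name = "translator"
    description = "Translates text between languages."
    inputs = {"text": {}, "src_lang": {}, "tgt_lang": {}}

def get_tool_description_with_args(tool, description_template: Optional[str] = None) -> str:
    # Mirrors the diff: the default template is only picked when none is passed in.
    if description_template is None:
        description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
    return Template(description_template).render(tool=tool)

print(get_tool_description_with_args(DummyTool()))
```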

View File

@@ -1,279 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from .tools import PipelineTool
LANGUAGE_CODES = {
"Acehnese Arabic": "ace_Arab",
"Acehnese Latin": "ace_Latn",
"Mesopotamian Arabic": "acm_Arab",
"Ta'izzi-Adeni Arabic": "acq_Arab",
"Tunisian Arabic": "aeb_Arab",
"Afrikaans": "afr_Latn",
"South Levantine Arabic": "ajp_Arab",
"Akan": "aka_Latn",
"Amharic": "amh_Ethi",
"North Levantine Arabic": "apc_Arab",
"Modern Standard Arabic": "arb_Arab",
"Modern Standard Arabic Romanized": "arb_Latn",
"Najdi Arabic": "ars_Arab",
"Moroccan Arabic": "ary_Arab",
"Egyptian Arabic": "arz_Arab",
"Assamese": "asm_Beng",
"Asturian": "ast_Latn",
"Awadhi": "awa_Deva",
"Central Aymara": "ayr_Latn",
"South Azerbaijani": "azb_Arab",
"North Azerbaijani": "azj_Latn",
"Bashkir": "bak_Cyrl",
"Bambara": "bam_Latn",
"Balinese": "ban_Latn",
"Belarusian": "bel_Cyrl",
"Bemba": "bem_Latn",
"Bengali": "ben_Beng",
"Bhojpuri": "bho_Deva",
"Banjar Arabic": "bjn_Arab",
"Banjar Latin": "bjn_Latn",
"Standard Tibetan": "bod_Tibt",
"Bosnian": "bos_Latn",
"Buginese": "bug_Latn",
"Bulgarian": "bul_Cyrl",
"Catalan": "cat_Latn",
"Cebuano": "ceb_Latn",
"Czech": "ces_Latn",
"Chokwe": "cjk_Latn",
"Central Kurdish": "ckb_Arab",
"Crimean Tatar": "crh_Latn",
"Welsh": "cym_Latn",
"Danish": "dan_Latn",
"German": "deu_Latn",
"Southwestern Dinka": "dik_Latn",
"Dyula": "dyu_Latn",
"Dzongkha": "dzo_Tibt",
"Greek": "ell_Grek",
"English": "eng_Latn",
"Esperanto": "epo_Latn",
"Estonian": "est_Latn",
"Basque": "eus_Latn",
"Ewe": "ewe_Latn",
"Faroese": "fao_Latn",
"Fijian": "fij_Latn",
"Finnish": "fin_Latn",
"Fon": "fon_Latn",
"French": "fra_Latn",
"Friulian": "fur_Latn",
"Nigerian Fulfulde": "fuv_Latn",
"Scottish Gaelic": "gla_Latn",
"Irish": "gle_Latn",
"Galician": "glg_Latn",
"Guarani": "grn_Latn",
"Gujarati": "guj_Gujr",
"Haitian Creole": "hat_Latn",
"Hausa": "hau_Latn",
"Hebrew": "heb_Hebr",
"Hindi": "hin_Deva",
"Chhattisgarhi": "hne_Deva",
"Croatian": "hrv_Latn",
"Hungarian": "hun_Latn",
"Armenian": "hye_Armn",
"Igbo": "ibo_Latn",
"Ilocano": "ilo_Latn",
"Indonesian": "ind_Latn",
"Icelandic": "isl_Latn",
"Italian": "ita_Latn",
"Javanese": "jav_Latn",
"Japanese": "jpn_Jpan",
"Kabyle": "kab_Latn",
"Jingpho": "kac_Latn",
"Kamba": "kam_Latn",
"Kannada": "kan_Knda",
"Kashmiri Arabic": "kas_Arab",
"Kashmiri Devanagari": "kas_Deva",
"Georgian": "kat_Geor",
"Central Kanuri Arabic": "knc_Arab",
"Central Kanuri Latin": "knc_Latn",
"Kazakh": "kaz_Cyrl",
"Kabiyè": "kbp_Latn",
"Kabuverdianu": "kea_Latn",
"Khmer": "khm_Khmr",
"Kikuyu": "kik_Latn",
"Kinyarwanda": "kin_Latn",
"Kyrgyz": "kir_Cyrl",
"Kimbundu": "kmb_Latn",
"Northern Kurdish": "kmr_Latn",
"Kikongo": "kon_Latn",
"Korean": "kor_Hang",
"Lao": "lao_Laoo",
"Ligurian": "lij_Latn",
"Limburgish": "lim_Latn",
"Lingala": "lin_Latn",
"Lithuanian": "lit_Latn",
"Lombard": "lmo_Latn",
"Latgalian": "ltg_Latn",
"Luxembourgish": "ltz_Latn",
"Luba-Kasai": "lua_Latn",
"Ganda": "lug_Latn",
"Luo": "luo_Latn",
"Mizo": "lus_Latn",
"Standard Latvian": "lvs_Latn",
"Magahi": "mag_Deva",
"Maithili": "mai_Deva",
"Malayalam": "mal_Mlym",
"Marathi": "mar_Deva",
"Minangkabau Arabic ": "min_Arab",
"Minangkabau Latin": "min_Latn",
"Macedonian": "mkd_Cyrl",
"Plateau Malagasy": "plt_Latn",
"Maltese": "mlt_Latn",
"Meitei Bengali": "mni_Beng",
"Halh Mongolian": "khk_Cyrl",
"Mossi": "mos_Latn",
"Maori": "mri_Latn",
"Burmese": "mya_Mymr",
"Dutch": "nld_Latn",
"Norwegian Nynorsk": "nno_Latn",
"Norwegian Bokmål": "nob_Latn",
"Nepali": "npi_Deva",
"Northern Sotho": "nso_Latn",
"Nuer": "nus_Latn",
"Nyanja": "nya_Latn",
"Occitan": "oci_Latn",
"West Central Oromo": "gaz_Latn",
"Odia": "ory_Orya",
"Pangasinan": "pag_Latn",
"Eastern Panjabi": "pan_Guru",
"Papiamento": "pap_Latn",
"Western Persian": "pes_Arab",
"Polish": "pol_Latn",
"Portuguese": "por_Latn",
"Dari": "prs_Arab",
"Southern Pashto": "pbt_Arab",
"Ayacucho Quechua": "quy_Latn",
"Romanian": "ron_Latn",
"Rundi": "run_Latn",
"Russian": "rus_Cyrl",
"Sango": "sag_Latn",
"Sanskrit": "san_Deva",
"Santali": "sat_Olck",
"Sicilian": "scn_Latn",
"Shan": "shn_Mymr",
"Sinhala": "sin_Sinh",
"Slovak": "slk_Latn",
"Slovenian": "slv_Latn",
"Samoan": "smo_Latn",
"Shona": "sna_Latn",
"Sindhi": "snd_Arab",
"Somali": "som_Latn",
"Southern Sotho": "sot_Latn",
"Spanish": "spa_Latn",
"Tosk Albanian": "als_Latn",
"Sardinian": "srd_Latn",
"Serbian": "srp_Cyrl",
"Swati": "ssw_Latn",
"Sundanese": "sun_Latn",
"Swedish": "swe_Latn",
"Swahili": "swh_Latn",
"Silesian": "szl_Latn",
"Tamil": "tam_Taml",
"Tatar": "tat_Cyrl",
"Telugu": "tel_Telu",
"Tajik": "tgk_Cyrl",
"Tagalog": "tgl_Latn",
"Thai": "tha_Thai",
"Tigrinya": "tir_Ethi",
"Tamasheq Latin": "taq_Latn",
"Tamasheq Tifinagh": "taq_Tfng",
"Tok Pisin": "tpi_Latn",
"Tswana": "tsn_Latn",
"Tsonga": "tso_Latn",
"Turkmen": "tuk_Latn",
"Tumbuka": "tum_Latn",
"Turkish": "tur_Latn",
"Twi": "twi_Latn",
"Central Atlas Tamazight": "tzm_Tfng",
"Uyghur": "uig_Arab",
"Ukrainian": "ukr_Cyrl",
"Umbundu": "umb_Latn",
"Urdu": "urd_Arab",
"Northern Uzbek": "uzn_Latn",
"Venetian": "vec_Latn",
"Vietnamese": "vie_Latn",
"Waray": "war_Latn",
"Wolof": "wol_Latn",
"Xhosa": "xho_Latn",
"Eastern Yiddish": "ydd_Hebr",
"Yoruba": "yor_Latn",
"Yue Chinese": "yue_Hant",
"Chinese Simplified": "zho_Hans",
"Chinese Traditional": "zho_Hant",
"Standard Malay": "zsm_Latn",
"Zulu": "zul_Latn",
}
class TranslationTool(PipelineTool):
"""
Example:
```py
from transformers.agents import TranslationTool
translator = TranslationTool()
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
```
"""
lang_to_code = LANGUAGE_CODES
default_checkpoint = "facebook/nllb-200-distilled-600M"
description = (
"This is a tool that translates text from a language to another."
f"Both `src_lang`and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
)
name = "translator"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = {
"text": {"type": "string", "description": "The text to translate"},
"src_lang": {
"type": "string",
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
},
"tgt_lang": {
"type": "string",
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
},
}
output_type = "string"
def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code:
raise ValueError(f"{src_lang} is not a supported language.")
if tgt_lang not in self.lang_to_code:
raise ValueError(f"{tgt_lang} is not a supported language.")
src_lang = self.lang_to_code[src_lang]
tgt_lang = self.lang_to_code[tgt_lang]
return self.pre_processor._build_translation_inputs(
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
)
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)

View File

@@ -1,10 +1,10 @@
 from agents import load_tool, CodeAgent, HfApiEngine
+from agents.search import DuckDuckGoSearchTool

 # Import tool from Hub
 image_generation_tool = load_tool("m-ric/text-to-image", cache=False)

 # Import tool from LangChain
-from agents.search import DuckDuckGoSearchTool
 search_tool = DuckDuckGoSearchTool()

View File

@@ -1,4 +1,5 @@
 from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent
+import gradio as gr

 image_generation_tool = load_tool("m-ric/text-to-image")
@@ -6,9 +7,6 @@ llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
 agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)

-import gradio as gr

 def interact_with_agent(prompt, messages):
     messages.append(gr.ChatMessage(role="user", content=prompt))
     yield messages

View File

@@ -1,14 +1,19 @@
-from agents.llm_engine import TransformersEngine
-from agents import CodeAgent, JsonAgent
+from agents import JsonAgent
+from agents import tool
+import webbrowser
 import requests
 from datetime import datetime
+import random
+from llama_cpp import Llama
+from agents import tool
+import webbrowser
+from typing import List, Generator, Dict, Any
+import json
+import re

 model_repo="andito/SmolLM2-1.7B-Instruct-F16-GGUF"
 model_filename="smollm2-1.7b-8k-dpo-f16.gguf"
-import random
-from llama_cpp import Llama

 model = Llama.from_pretrained(
     repo_id=model_repo,
@@ -55,10 +60,6 @@ The example format is as follows. Please make sure the parameter type is correct
 ... (more tool calls as required)
 ]</tool_call>"""

-from agents import tool
-import webbrowser

 @tool
 def get_random_number_between(min: int, max: int) -> int:
     """
@@ -110,9 +111,7 @@ def open_webbrowser(url: str) -> str:
     webbrowser.open(url)
     return f"I opened {url.replace('https://', '').replace('www.', '')} in the browser."

-from typing import List, Dict, Generator, Any
-import re
-import json

 def _parse_response(self, text: str) -> List[Dict[str, Any]]:
     pattern = r"<tool_call>(.*?)</tool_call>"
     matches = re.findall(pattern, text, re.DOTALL)
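Note (standalone sketch, not part of the commit): the `<tool_call>` parsing shown above relies on a non-greedy regex plus JSON decoding of each block. A self-contained version of that idea:

```py
import json
import re
from typing import Any, Dict, List

def parse_tool_calls(text: str) -> List[Dict[str, Any]]:
    # Grab every <tool_call>...</tool_call> block, then decode the JSON payload inside.
    matches = re.findall(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    calls: List[Dict[str, Any]] = []
    for match in matches:
        payload = json.loads(match.strip())
        # A block may contain either a single call or a list of calls.
        calls.extend(payload if isinstance(payload, list) else [payload])
    return calls

sample = '<tool_call>[{"name": "get_random_number_between", "arguments": {"min": 1, "max": 10}}]</tool_call>'
print(parse_tool_calls(sample))
```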

View File

@@ -1,4 +1,4 @@
-from agents import load_tool, CodeAgent, JsonAgent, HfApiEngine
+from agents import load_tool, CodeAgent, HfApiEngine
 from agents.prompts import ONESHOT_CODE_SYSTEM_PROMPT

 # Import tool from Hub

BIN  tests/fixtures/000000039769.png vendored Normal file (binary file not shown; 678 KiB)

View File

@@ -27,8 +27,6 @@ from transformers.testing_utils import (
 )
 from transformers.utils import (
     is_soundfile_availble,
-    is_torch_available,
-    is_vision_available,
 )

 import torch
@@ -93,7 +91,7 @@ class AgentImageTests(unittest.TestCase):
         self.assertTrue(os.path.exists(path))

     def test_from_string(self):
-        path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
+        path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
         image = Image.open(path)
         agent_type = AgentImage(path)
@@ -105,7 +103,7 @@ class AgentImageTests(unittest.TestCase):
         self.assertTrue(os.path.exists(path))

     def test_from_image(self):
-        path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
+        path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
         image = Image.open(path)
         agent_type = AgentImage(image)

View File

@@ -22,7 +22,6 @@ import pytest
 from agents.agent_types import AgentText
 from agents.agents import (
     AgentMaxIterationsError,
-    CodeAgent,
     ManagedAgent,
     CodeAgent,
     JsonAgent,
@@ -162,14 +161,14 @@ class AgentTests(unittest.TestCase):
         output = agent.run("What is 2 multiplied by 3.6452?")
         assert isinstance(output, str)
         assert output == "7.2904"
-        assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
-        assert agent.logs[1]["observation"] == "7.2904"
+        assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
+        assert agent.logs[2].observation == "7.2904"
         assert (
-            agent.logs[1]["rationale"].strip()
+            agent.logs[2].rationale.strip()
             == "Thought: I should multiply 2 by 3.6452. special_marker"
         )
         assert (
-            agent.logs[2]["llm_output"]
+            agent.logs[3].llm_output
             == """
 Thought: I can now answer the initial question
 Action:
@@ -187,8 +186,8 @@ Action:
         output = agent.run("What is 2 multiplied by 3.6452?")
         assert isinstance(output, float)
         assert output == 7.2904
-        assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
-        assert agent.logs[2]["tool_call"] == {
+        assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
+        assert agent.logs[3].tool_call == {
             "tool_arguments": "final_answer(7.2904)",
             "tool_name": "code interpreter",
         }
@@ -212,10 +211,9 @@ Action:
             max_iterations=5,
         )
         agent.run("What is 2 multiplied by 3.6452?")
-        assert len(agent.logs) == 7
-        assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError
+        assert len(agent.logs) == 8
+        assert type(agent.logs[-1].error) is AgentMaxIterationsError

-    @require_torch
     def test_init_agent_with_different_toolsets(self):
         toolset_1 = []
         agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
@@ -245,8 +243,8 @@ Action:
         # check that python_interpreter base tool does not get added to code agents
         agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
         assert (
-            len(agent.toolbox.tools) == 7
-        )  # added final_answer tool + 6 base tools (excluding interpreter)
+            len(agent.toolbox.tools) == 2
+        )  # added final_answer tool + search

     def test_function_persistence_across_steps(self):
         agent = CodeAgent(
@@ -273,7 +271,8 @@ Action:
             managed_agents=[managed_agent],
         )
         assert "You can also give requests to team members." not in agent.system_prompt
-        assert "<<managed_agents_descriptions>>" not in agent.system_prompt
+        print("ok1")
+        assert "{{managed_agents_descriptions}}" not in agent.system_prompt
         assert (
             "You can also give requests to team members." in manager_agent.system_prompt
         )
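Note (assumption drawn from the shifted indices, not from the commit itself): the updated assertions read log entries as attributes rather than dict keys, and start at index 1, which suggests a system-prompt step now occupies slot 0. A minimal sketch of the step-record shape the tests imply:

```py
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class ActionStep:
    # Hypothetical record; field names follow the attributes the tests access.
    task: Optional[str] = None
    rationale: Optional[str] = None
    llm_output: Optional[str] = None
    observation: Optional[str] = None
    tool_call: Optional[Dict[str, Any]] = None
    error: Optional[Exception] = None

logs = [
    ActionStep(),  # slot 0: assumed system-prompt / setup step
    ActionStep(task="What is 2 multiplied by 3.6452?"),
    ActionStep(observation="7.2904"),
]
assert logs[1].task == "What is 2 multiplied by 3.6452?"
assert logs[2].observation == "7.2904"
```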

View File

@@ -20,25 +20,9 @@ import tempfile
 import unittest
 from pathlib import Path
 from unittest import mock, skip
+from typing import List

-import torch
-from accelerate.test_utils.examples import compare_against_test
-from accelerate.test_utils.testing import (
-    TempDirTestCase,
-    get_launch_command,
-    require_huggingface_suite,
-    require_multi_device,
-    require_multi_gpu,
-    require_non_xpu,
-    require_pippy,
-    require_schedulefree,
-    require_trackers,
-    run_command,
-    slow,
-)
-from accelerate.utils import write_basic_config
+from .test_utils import slow, skip, get_launch_command, TempDirTestCase

 # DataLoaders built from `test_samples/MRPC` for quick testing
 # Should mock `{script_name}.get_dataloaders` via:
@@ -62,242 +46,51 @@ EXCLUDE_EXAMPLES = [
     "profiler.py",
 ]

+import subprocess
+
+
+class SubprocessCallException(Exception):
+    pass
+
+
+def run_command(command: List[str], return_stdout=False, env=None):
+    """
+    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
+    if an error occured while running `command`
+    """
+    # Cast every path in `command` to a string
+    for i, c in enumerate(command):
+        if isinstance(c, Path):
+            command[i] = str(c)
+    if env is None:
+        env = os.environ.copy()
+    try:
+        output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
+        if return_stdout:
+            if hasattr(output, "decode"):
+                output = output.decode("utf-8")
+            return output
+    except subprocess.CalledProcessError as e:
+        raise SubprocessCallException(
+            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+        ) from e
+

-class ExampleDifferenceTests(unittest.TestCase):
-    """
-    This TestCase checks that all of the `complete_*` scripts contain all of the
-    information found in the `by_feature` scripts, line for line. If one fails,
-    then a complete example does not contain all of the features in the features
-    scripts, and should be updated.
-
-    Each example script should be a single test (such as `test_nlp_example`),
-    and should run `one_complete_example` twice: once with `parser_only=True`,
-    and the other with `parser_only=False`. This is so that when the test
-    failures are returned to the user, they understand if the discrepancy lies in
-    the `main` function, or the `training_loop` function. Otherwise it will be
-    unclear.
-
-    Also, if there are any expected differences between the base script used and
-    `complete_nlp_example.py` (the canonical base script), these should be included in
-    `special_strings`. These would be differences in how something is logged, print statements,
-    etc (such as calls to `Accelerate.log()`)
-    """
-
-    by_feature_path = Path("examples", "by_feature").resolve()
-    examples_path = Path("examples").resolve()
-
-    def one_complete_example(
-        self,
-        complete_file_name: str,
-        parser_only: bool,
-        secondary_filename: str = None,
-        special_strings: list = None,
-    ):
-        """
-        Tests a single `complete` example against all of the implemented `by_feature` scripts
-
-        Args:
-            complete_file_name (`str`):
-                The filename of a complete example
-            parser_only (`bool`):
-                Whether to look at the main training function, or the argument parser
-            secondary_filename (`str`, *optional*):
-                A potential secondary base file to strip all script information not relevant for checking,
-                such as "cv_example.py" when testing "complete_cv_example.py"
-            special_strings (`list`, *optional*):
-                A list of strings to potentially remove before checking no differences are left. These should be
-                diffs that are file specific, such as different logging variations between files.
-        """
-        self.maxDiff = None
-        for item in os.listdir(self.by_feature_path):
-            if item not in EXCLUDE_EXAMPLES:
-                item_path = self.by_feature_path / item
-                if item_path.is_file() and item_path.suffix == ".py":
-                    with self.subTest(
-                        tested_script=complete_file_name,
-                        feature_script=item,
-                        tested_section="main()"
-                        if parser_only
-                        else "training_function()",
-                    ):
-                        diff = compare_against_test(
-                            self.examples_path / complete_file_name,
-                            item_path,
-                            parser_only,
-                            secondary_filename,
-                        )
-                        diff = "\n".join(diff)
-                        if special_strings is not None:
-                            for string in special_strings:
-                                diff = diff.replace(string, "")
-                        assert diff == ""
-
-    def test_nlp_examples(self):
-        self.one_complete_example("complete_nlp_example.py", True)
-        self.one_complete_example("complete_nlp_example.py", False)
-
-    def test_cv_examples(self):
-        cv_path = (self.examples_path / "cv_example.py").resolve()
-        special_strings = [
-            " " * 16 + "{\n\n",
-            " " * 20 + '"accuracy": eval_metric["accuracy"],\n\n',
-            " " * 20 + '"f1": eval_metric["f1"],\n\n',
-            " " * 20 + '"train_loss": total_loss.item() / len(train_dataloader),\n\n',
-            " " * 20 + '"epoch": epoch,\n\n',
-            " " * 16 + "},\n\n",
-            " " * 16 + "step=epoch,\n",
-            " " * 12,
-            " " * 8 + "for step, batch in enumerate(active_dataloader):\n",
-        ]
-        self.one_complete_example(
-            "complete_cv_example.py", True, cv_path, special_strings
-        )
-        self.one_complete_example(
-            "complete_cv_example.py", False, cv_path, special_strings
-        )
-
-
-@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})
-@require_huggingface_suite
-class FeatureExamplesTests(TempDirTestCase):
+class ExamplesTests(TempDirTestCase):
     clear_on_setup = False

     @classmethod
     def setUpClass(cls):
         super().setUpClass()
         cls._tmpdir = tempfile.mkdtemp()
-        cls.config_file = Path(cls._tmpdir) / "default_config.yml"
-        write_basic_config(save_location=cls.config_file)
-        cls.launch_args = get_launch_command(config_file=cls.config_file)
+        cls.launch_args = ["python3"]

     @classmethod
     def tearDownClass(cls):
         super().tearDownClass()
         shutil.rmtree(cls._tmpdir)

-    def test_checkpointing_by_epoch(self):
-        testargs = f"""
-        examples/by_feature/checkpointing.py
-        --checkpointing_steps epoch
-        --output_dir {self.tmpdir}
-        """.split()
-        run_command(self.launch_args + testargs)
-        assert (self.tmpdir / "epoch_0").exists()
-
-    def test_checkpointing_by_steps(self):
-        testargs = f"""
-        examples/by_feature/checkpointing.py
-        --checkpointing_steps 1
-        --output_dir {self.tmpdir}
-        """.split()
-        _ = run_command(self.launch_args + testargs)
-        assert (self.tmpdir / "step_2").exists()
-
-    def test_load_states_by_epoch(self):
-        testargs = f"""
-        examples/by_feature/checkpointing.py
-        --resume_from_checkpoint {self.tmpdir / "epoch_0"}
-        """.split()
-        output = run_command(self.launch_args + testargs, return_stdout=True)
-        assert "epoch 0:" not in output
-        assert "epoch 1:" in output
-
-    def test_load_states_by_steps(self):
-        testargs = f"""
-        examples/by_feature/checkpointing.py
-        --resume_from_checkpoint {self.tmpdir / "step_2"}
-        """.split()
-        output = run_command(self.launch_args + testargs, return_stdout=True)
-        if torch.cuda.is_available():
-            num_processes = torch.cuda.device_count()
-        else:
-            num_processes = 1
-        if num_processes > 1:
-            assert "epoch 0:" not in output
-            assert "epoch 1:" in output
-        else:
-            assert "epoch 0:" in output
-            assert "epoch 1:" in output
-
-    @slow
-    def test_cross_validation(self):
-        testargs = """
-        examples/by_feature/cross_validation.py
-        --num_folds 2
-        """.split()
-        with mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "0"}):
-            output = run_command(self.launch_args + testargs, return_stdout=True)
-            results = re.findall("({.+})", output)
-            results = [r for r in results if "accuracy" in r][-1]
-            results = ast.literal_eval(results)
-            assert results["accuracy"] >= 0.75
-
-    def test_multi_process_metrics(self):
-        testargs = ["examples/by_feature/multi_process_metrics.py"]
-        run_command(self.launch_args + testargs)
-
-    @require_schedulefree
-    def test_schedulefree(self):
-        testargs = ["examples/by_feature/schedule_free.py"]
-        run_command(self.launch_args + testargs)
-
-    @require_trackers
-    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
-    def test_tracking(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            testargs = f"""
-            examples/by_feature/tracking.py
-            --with_tracking
-            --project_dir {tmpdir}
-            """.split()
-            run_command(self.launch_args + testargs)
-            assert os.path.exists(os.path.join(tmpdir, "tracking"))
-
-    def test_gradient_accumulation(self):
-        testargs = ["examples/by_feature/gradient_accumulation.py"]
-        run_command(self.launch_args + testargs)
-
-    def test_local_sgd(self):
-        testargs = ["examples/by_feature/local_sgd.py"]
-        run_command(self.launch_args + testargs)
-
-    def test_early_stopping(self):
-        testargs = ["examples/by_feature/early_stopping.py"]
-        run_command(self.launch_args + testargs)
-
-    def test_profiler(self):
-        testargs = ["examples/by_feature/profiler.py"]
-        run_command(self.launch_args + testargs)
-
-    @require_multi_device
-    def test_ddp_comm_hook(self):
-        testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
-        run_command(self.launch_args + testargs)
-
-    @skip(
-        reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
-    )
-    @require_multi_device
-    def test_distributed_inference_examples_stable_diffusion(self):
-        testargs = ["examples/inference/distributed/stable_diffusion.py"]
-        run_command(self.launch_args + testargs)
-
-    @require_multi_device
-    def test_distributed_inference_examples_phi2(self):
-        testargs = ["examples/inference/distributed/phi2.py"]
-        run_command(self.launch_args + testargs)
-
-    @require_non_xpu
-    @require_pippy
-    @require_multi_gpu
-    def test_pippy_examples_bert(self):
-        testargs = ["examples/inference/pippy/bert.py"]
-        run_command(self.launch_args + testargs)
-
-    @require_non_xpu
-    @require_pippy
-    @require_multi_gpu
-    def test_pippy_examples_gpt2(self):
-        testargs = ["examples/inference/pippy/gpt2.py"]
+    def test_oneshot(self):
+        testargs = ["examples/oneshot.py"]
         run_command(self.launch_args + testargs)
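Note (illustration only): the slimmed-down test drives example scripts with plain `python3` through the new `run_command` helper. The simplified mirror below and the script path are assumptions, not the committed code:

```py
import os
import subprocess
from pathlib import Path
from typing import List

def run_example(command: List[str], return_stdout: bool = False, env=None):
    # Simplified mirror of the run_command helper above: stringify Paths, inherit
    # the environment, and optionally return decoded stdout.
    command = [str(c) if isinstance(c, Path) else c for c in command]
    env = os.environ.copy() if env is None else env
    output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
    return output.decode("utf-8") if return_stdout else None

launch_args = ["python3"]
testargs = [Path("examples") / "oneshot.py"]  # hypothetical script path
print(run_example(launch_args + testargs, return_stdout=True))
```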

View File

@@ -48,7 +48,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
         inputs_text = {"answer": "Text input"}
         inputs_image = {
             "answer": Image.open(
-                Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
+                Path(get_tests_dir("fixtures")) / "000000039769.png"
             ).resize((512, 512))
         }
         inputs_audio = {"answer": torch.Tensor(np.ones(3000))}

View File

@@ -50,7 +50,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
             inputs[input_name] = "Text input"
         elif input_type == "image":
             inputs[input_name] = Image.open(
-                Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
+                Path(get_tests_dir("fixtures")) / "000000039769.png"
             ).resize((512, 512))
         elif input_type == "audio":
             inputs[input_name] = np.ones(3000)

View File

@@ -1,69 +0,0 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from .test_tools_common import ToolTesterMixin, output_type
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("translation")
self.tool.setup()
self.remote_tool = load_tool("translation", remote=True)
def test_exact_match_arg(self):
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
self.assertEqual(result, "- Hé, comment ça va?")
def test_exact_match_kwarg(self):
result = self.tool(
text="Hey, what's up?", src_lang="English", tgt_lang="French"
)
self.assertEqual(result, "- Hé, comment ça va?")
def test_call(self):
inputs = ["Hey, what's up?", "English", "Spanish"]
output = self.tool(*inputs)
self.assertEqual(output_type(output), self.tool.output_type)
def test_agent_type_output(self):
inputs = ["Hey, what's up?", "English", "Spanish"]
output = self.tool(*inputs)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))
def test_agent_types_inputs(self):
example_inputs = {
"text": "Hey, what's up?",
"src_lang": "English",
"tgt_lang": "Spanish",
}
_inputs = []
for input_name in example_inputs.keys():
example_input = example_inputs[input_name]
input_description = self.tool.inputs[input_name]
input_type = input_description["type"]
_inputs.append(AGENT_TYPE_MAPPING[input_type](example_input))
# Should not raise an error
output = self.tool(**example_inputs)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))

tests/test_utils.py Normal file
View File

@@ -0,0 +1,100 @@
import os
import unittest
import shutil
import tempfile
from pathlib import Path
def str_to_bool(value) -> int:
"""
Converts a string representation of truth to `True` (1) or `False` (0).
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
"""
value = value.lower()
if value in ("y", "yes", "t", "true", "on", "1"):
return 1
elif value in ("n", "no", "f", "false", "off", "0"):
return 0
else:
raise ValueError(f"invalid truth value {value}")
def get_int_from_env(env_keys, default):
"""Returns the first positive env value found in the `env_keys` list or the default."""
for e in env_keys:
val = int(os.environ.get(e, -1))
if val >= 0:
return val
return default
def parse_flag_from_env(key, default=False):
"""Returns truthy value for `key` from the env if available else the default."""
value = os.environ.get(key, str(default))
return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int...
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def skip(test_case):
"Decorator that skips a test unconditionally"
return unittest.skip("Test was skipped")(test_case)
def slow(test_case):
"""
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def get_launch_command(**kwargs) -> list:
"""
Wraps around `kwargs` to help simplify launching from `subprocess`.
Example:
```python
# returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2']
get_launch_command(num_processes=2, device_count=2)
```
"""
command = ["accelerate", "launch"]
for k, v in kwargs.items():
if isinstance(v, bool) and v:
command.append(f"--{k}")
elif v is not None:
command.append(f"--{k}={v}")
return command
class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
data at the start of a test, and then destroyes it at the end of the TestCase.
Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases
The temporary directory location will be stored in `self.tmpdir`
"""
clear_on_setup = True
@classmethod
def setUpClass(cls):
"Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`"
cls.tmpdir = Path(tempfile.mkdtemp())
@classmethod
def tearDownClass(cls):
"Remove `cls.tmpdir` after test suite has finished"
if os.path.exists(cls.tmpdir):
shutil.rmtree(cls.tmpdir)
def setUp(self):
"Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
if self.clear_on_setup:
for path in self.tmpdir.glob("**/*"):
if path.is_file():
path.unlink()
elif path.is_dir():
shutil.rmtree(path)
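Note (hypothetical test module, not part of the commit): a short sketch of how these helpers are meant to compose: a temp-dir test case whose slow tests only run when `RUN_SLOW` is set to a truthy value. The import path `tests.test_utils` is an assumption:

```py
import unittest

# Assumes the helpers above are importable as tests.test_utils.
from tests.test_utils import TempDirTestCase, get_launch_command, slow

class ScratchDirTests(TempDirTestCase):
    def test_writes_into_fresh_tmpdir(self):
        # self.tmpdir is wiped before each test because clear_on_setup defaults to True.
        marker = self.tmpdir / "marker.txt"
        marker.write_text("hello")
        self.assertTrue(marker.exists())

    @slow
    def test_only_runs_when_run_slow_is_set(self):
        # get_launch_command turns keyword arguments into accelerate CLI flags.
        self.assertEqual(get_launch_command(num_processes=2), ["accelerate", "launch", "--num_processes=2"])

if __name__ == "__main__":
    unittest.main()
```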