diff --git a/agents/__init__.py b/agents/__init__.py index 3aa8701..c2b44e2 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -62,7 +62,6 @@ else: if TYPE_CHECKING: from .agents import ( Agent, - CodeAgent, ManagedAgent, ReactAgent, CodeAgent, diff --git a/agents/agents.py b/agents/agents.py index 36108b3..9652257 100644 --- a/agents/agents.py +++ b/agents/agents.py @@ -47,9 +47,6 @@ from .tools import ( ) -HUGGINGFACE_DEFAULT_TOOLS = {} - - class AgentError(Exception): """Base class for other agent-related exceptions""" @@ -145,14 +142,18 @@ Here is a list of the team members that you can call:""" def format_prompt_with_managed_agents_descriptions( - prompt_template, managed_agents=None + prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None ) -> str: - if managed_agents is not None: + if agent_descriptions_placeholder is None: + agent_descriptions_placeholder = "{{managed_agents_descriptions}}" + if agent_descriptions_placeholder not in prompt_template: + raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'") + if len(managed_agents.keys()) > 0: return prompt_template.replace( - "<>", show_agents_descriptions(managed_agents) + agent_descriptions_placeholder, show_agents_descriptions(managed_agents) ) else: - return prompt_template.replace("<>", "") + return prompt_template.replace(agent_descriptions_placeholder, "") def format_prompt_with_imports( @@ -220,12 +221,8 @@ class BaseAgent: self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) self._toolbox.add_tool(FinalAnswerTool()) - self.system_prompt = format_prompt_with_tools( - self._toolbox, self.system_prompt_template, self.tool_description_template - ) - self.system_prompt = format_prompt_with_managed_agents_descriptions( - self.system_prompt, self.managed_agents - ) + self.system_prompt = self.initialize_system_prompt() + print("SYS0:", self.system_prompt) self.prompt_messages = None self.logs = [] self.task = None @@ -353,7 +350,7 @@ class BaseAgent: split[-2], split[-1], ) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output - except Exception as e: + except Exception: raise AgentParsingError( f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" ) @@ -909,8 +906,9 @@ class CodeAgent(ReactAgent): self.authorized_imports = list( set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) ) + print("SYSS:", self.system_prompt) self.system_prompt = self.system_prompt.replace( - "<>", str(self.authorized_imports) + "{{authorized_imports}}", str(self.authorized_imports) ) self.custom_tools = {} diff --git a/agents/prompts.py b/agents/prompts.py index e358f7d..1112238 100644 --- a/agents/prompts.py +++ b/agents/prompts.py @@ -135,6 +135,8 @@ final_answer(caption) Above example were using tools that might not exist for you. You only have access to these tools: {{tool_names}} +{{managed_agents_descriptions}} + Remember to make sure that variables you use are all defined. In particular don't import packages! Be sure to provide a 'Code:\n```' sequence before the code and '```' after, else you will get an error. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'. @@ -260,8 +262,11 @@ Action: Above example were using notional tools that might not exist for you. You only have access to these tools: + {{tool_descriptions}} +{{managed_agents_descriptions}} + Here are the rules you should always follow to solve your task: 1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with , else you will fail. 2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead. @@ -355,7 +360,7 @@ Above example were using notional tools that might not exist for you. On top of {{tool_descriptions}} -<> +{{managed_agents_descriptions}} Here are the rules you should always follow to solve your task: 1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```' sequence, else you will fail. diff --git a/agents/tools.py b/agents/tools.py index 7357ab2..2b29f9c 100644 --- a/agents/tools.py +++ b/agents/tools.py @@ -643,8 +643,10 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """ def get_tool_description_with_args( - tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE + tool: Tool, description_template: Optional[str] = None ) -> str: + if description_template is None: + description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE compiled_template = compile_jinja_template(description_template) rendered = compiled_template.render( tool=tool, @@ -1080,6 +1082,9 @@ def tool(tool_function: Callable) -> Tool: return SpecificTool() +HUGGINGFACE_DEFAULT_TOOLS = {} + + class Toolbox: """ The toolbox contains all tools that the agent can perform operations with, as well as a few methods to @@ -1110,7 +1115,7 @@ class Toolbox: """Get all tools currently in the toolbox""" return self._tools - def show_tool_descriptions(self, tool_description_template: str = None) -> str: + def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str: """ Returns the description of all tools in the toolbox diff --git a/agents/translation.py b/agents/translation.py deleted file mode 100644 index b03c804..0000000 --- a/agents/translation.py +++ /dev/null @@ -1,279 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 - -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer -from .tools import PipelineTool - - -LANGUAGE_CODES = { - "Acehnese Arabic": "ace_Arab", - "Acehnese Latin": "ace_Latn", - "Mesopotamian Arabic": "acm_Arab", - "Ta'izzi-Adeni Arabic": "acq_Arab", - "Tunisian Arabic": "aeb_Arab", - "Afrikaans": "afr_Latn", - "South Levantine Arabic": "ajp_Arab", - "Akan": "aka_Latn", - "Amharic": "amh_Ethi", - "North Levantine Arabic": "apc_Arab", - "Modern Standard Arabic": "arb_Arab", - "Modern Standard Arabic Romanized": "arb_Latn", - "Najdi Arabic": "ars_Arab", - "Moroccan Arabic": "ary_Arab", - "Egyptian Arabic": "arz_Arab", - "Assamese": "asm_Beng", - "Asturian": "ast_Latn", - "Awadhi": "awa_Deva", - "Central Aymara": "ayr_Latn", - "South Azerbaijani": "azb_Arab", - "North Azerbaijani": "azj_Latn", - "Bashkir": "bak_Cyrl", - "Bambara": "bam_Latn", - "Balinese": "ban_Latn", - "Belarusian": "bel_Cyrl", - "Bemba": "bem_Latn", - "Bengali": "ben_Beng", - "Bhojpuri": "bho_Deva", - "Banjar Arabic": "bjn_Arab", - "Banjar Latin": "bjn_Latn", - "Standard Tibetan": "bod_Tibt", - "Bosnian": "bos_Latn", - "Buginese": "bug_Latn", - "Bulgarian": "bul_Cyrl", - "Catalan": "cat_Latn", - "Cebuano": "ceb_Latn", - "Czech": "ces_Latn", - "Chokwe": "cjk_Latn", - "Central Kurdish": "ckb_Arab", - "Crimean Tatar": "crh_Latn", - "Welsh": "cym_Latn", - "Danish": "dan_Latn", - "German": "deu_Latn", - "Southwestern Dinka": "dik_Latn", - "Dyula": "dyu_Latn", - "Dzongkha": "dzo_Tibt", - "Greek": "ell_Grek", - "English": "eng_Latn", - "Esperanto": "epo_Latn", - "Estonian": "est_Latn", - "Basque": "eus_Latn", - "Ewe": "ewe_Latn", - "Faroese": "fao_Latn", - "Fijian": "fij_Latn", - "Finnish": "fin_Latn", - "Fon": "fon_Latn", - "French": "fra_Latn", - "Friulian": "fur_Latn", - "Nigerian Fulfulde": "fuv_Latn", - "Scottish Gaelic": "gla_Latn", - "Irish": "gle_Latn", - "Galician": "glg_Latn", - "Guarani": "grn_Latn", - "Gujarati": "guj_Gujr", - "Haitian Creole": "hat_Latn", - "Hausa": "hau_Latn", - "Hebrew": "heb_Hebr", - "Hindi": "hin_Deva", - "Chhattisgarhi": "hne_Deva", - "Croatian": "hrv_Latn", - "Hungarian": "hun_Latn", - "Armenian": "hye_Armn", - "Igbo": "ibo_Latn", - "Ilocano": "ilo_Latn", - "Indonesian": "ind_Latn", - "Icelandic": "isl_Latn", - "Italian": "ita_Latn", - "Javanese": "jav_Latn", - "Japanese": "jpn_Jpan", - "Kabyle": "kab_Latn", - "Jingpho": "kac_Latn", - "Kamba": "kam_Latn", - "Kannada": "kan_Knda", - "Kashmiri Arabic": "kas_Arab", - "Kashmiri Devanagari": "kas_Deva", - "Georgian": "kat_Geor", - "Central Kanuri Arabic": "knc_Arab", - "Central Kanuri Latin": "knc_Latn", - "Kazakh": "kaz_Cyrl", - "Kabiyè": "kbp_Latn", - "Kabuverdianu": "kea_Latn", - "Khmer": "khm_Khmr", - "Kikuyu": "kik_Latn", - "Kinyarwanda": "kin_Latn", - "Kyrgyz": "kir_Cyrl", - "Kimbundu": "kmb_Latn", - "Northern Kurdish": "kmr_Latn", - "Kikongo": "kon_Latn", - "Korean": "kor_Hang", - "Lao": "lao_Laoo", - "Ligurian": "lij_Latn", - "Limburgish": "lim_Latn", - "Lingala": "lin_Latn", - "Lithuanian": "lit_Latn", - "Lombard": "lmo_Latn", - "Latgalian": "ltg_Latn", - "Luxembourgish": "ltz_Latn", - "Luba-Kasai": "lua_Latn", - "Ganda": "lug_Latn", - "Luo": "luo_Latn", - "Mizo": "lus_Latn", - "Standard Latvian": "lvs_Latn", - "Magahi": "mag_Deva", - "Maithili": "mai_Deva", - "Malayalam": "mal_Mlym", - "Marathi": "mar_Deva", - "Minangkabau Arabic ": "min_Arab", - "Minangkabau Latin": "min_Latn", - "Macedonian": "mkd_Cyrl", - "Plateau Malagasy": "plt_Latn", - "Maltese": "mlt_Latn", - "Meitei Bengali": "mni_Beng", - "Halh Mongolian": "khk_Cyrl", - "Mossi": "mos_Latn", - "Maori": "mri_Latn", - "Burmese": "mya_Mymr", - "Dutch": "nld_Latn", - "Norwegian Nynorsk": "nno_Latn", - "Norwegian Bokmål": "nob_Latn", - "Nepali": "npi_Deva", - "Northern Sotho": "nso_Latn", - "Nuer": "nus_Latn", - "Nyanja": "nya_Latn", - "Occitan": "oci_Latn", - "West Central Oromo": "gaz_Latn", - "Odia": "ory_Orya", - "Pangasinan": "pag_Latn", - "Eastern Panjabi": "pan_Guru", - "Papiamento": "pap_Latn", - "Western Persian": "pes_Arab", - "Polish": "pol_Latn", - "Portuguese": "por_Latn", - "Dari": "prs_Arab", - "Southern Pashto": "pbt_Arab", - "Ayacucho Quechua": "quy_Latn", - "Romanian": "ron_Latn", - "Rundi": "run_Latn", - "Russian": "rus_Cyrl", - "Sango": "sag_Latn", - "Sanskrit": "san_Deva", - "Santali": "sat_Olck", - "Sicilian": "scn_Latn", - "Shan": "shn_Mymr", - "Sinhala": "sin_Sinh", - "Slovak": "slk_Latn", - "Slovenian": "slv_Latn", - "Samoan": "smo_Latn", - "Shona": "sna_Latn", - "Sindhi": "snd_Arab", - "Somali": "som_Latn", - "Southern Sotho": "sot_Latn", - "Spanish": "spa_Latn", - "Tosk Albanian": "als_Latn", - "Sardinian": "srd_Latn", - "Serbian": "srp_Cyrl", - "Swati": "ssw_Latn", - "Sundanese": "sun_Latn", - "Swedish": "swe_Latn", - "Swahili": "swh_Latn", - "Silesian": "szl_Latn", - "Tamil": "tam_Taml", - "Tatar": "tat_Cyrl", - "Telugu": "tel_Telu", - "Tajik": "tgk_Cyrl", - "Tagalog": "tgl_Latn", - "Thai": "tha_Thai", - "Tigrinya": "tir_Ethi", - "Tamasheq Latin": "taq_Latn", - "Tamasheq Tifinagh": "taq_Tfng", - "Tok Pisin": "tpi_Latn", - "Tswana": "tsn_Latn", - "Tsonga": "tso_Latn", - "Turkmen": "tuk_Latn", - "Tumbuka": "tum_Latn", - "Turkish": "tur_Latn", - "Twi": "twi_Latn", - "Central Atlas Tamazight": "tzm_Tfng", - "Uyghur": "uig_Arab", - "Ukrainian": "ukr_Cyrl", - "Umbundu": "umb_Latn", - "Urdu": "urd_Arab", - "Northern Uzbek": "uzn_Latn", - "Venetian": "vec_Latn", - "Vietnamese": "vie_Latn", - "Waray": "war_Latn", - "Wolof": "wol_Latn", - "Xhosa": "xho_Latn", - "Eastern Yiddish": "ydd_Hebr", - "Yoruba": "yor_Latn", - "Yue Chinese": "yue_Hant", - "Chinese Simplified": "zho_Hans", - "Chinese Traditional": "zho_Hant", - "Standard Malay": "zsm_Latn", - "Zulu": "zul_Latn", -} - - -class TranslationTool(PipelineTool): - """ - Example: - - ```py - from transformers.agents import TranslationTool - - translator = TranslationTool() - translator("This is a super nice API!", src_lang="English", tgt_lang="French") - ``` - """ - - lang_to_code = LANGUAGE_CODES - default_checkpoint = "facebook/nllb-200-distilled-600M" - description = ( - "This is a tool that translates text from a language to another." - f"Both `src_lang`and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}." - ) - name = "translator" - pre_processor_class = AutoTokenizer - model_class = AutoModelForSeq2SeqLM - - inputs = { - "text": {"type": "string", "description": "The text to translate"}, - "src_lang": { - "type": "string", - "description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'", - }, - "tgt_lang": { - "type": "string", - "description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'", - }, - } - output_type = "string" - - def encode(self, text, src_lang, tgt_lang): - if src_lang not in self.lang_to_code: - raise ValueError(f"{src_lang} is not a supported language.") - if tgt_lang not in self.lang_to_code: - raise ValueError(f"{tgt_lang} is not a supported language.") - src_lang = self.lang_to_code[src_lang] - tgt_lang = self.lang_to_code[tgt_lang] - return self.pre_processor._build_translation_inputs( - text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang - ) - - def forward(self, inputs): - return self.model.generate(**inputs) - - def decode(self, outputs): - return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True) diff --git a/examples/agent_with_tools.py b/examples/agent_with_tools.py index 8b179c9..bd6f98e 100644 --- a/examples/agent_with_tools.py +++ b/examples/agent_with_tools.py @@ -1,10 +1,10 @@ from agents import load_tool, CodeAgent, HfApiEngine +from agents.search import DuckDuckGoSearchTool # Import tool from Hub image_generation_tool = load_tool("m-ric/text-to-image", cache=False) # Import tool from LangChain -from agents.search import DuckDuckGoSearchTool search_tool = DuckDuckGoSearchTool() diff --git a/examples/gradio_example.py b/examples/gradio_example.py index 757de81..d159ff0 100644 --- a/examples/gradio_example.py +++ b/examples/gradio_example.py @@ -1,4 +1,5 @@ from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent +import gradio as gr image_generation_tool = load_tool("m-ric/text-to-image") @@ -6,9 +7,6 @@ llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct") agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine) -import gradio as gr - - def interact_with_agent(prompt, messages): messages.append(gr.ChatMessage(role="user", content=prompt)) yield messages diff --git a/examples/local_smollm.py b/examples/local_smollm.py index cdb3bc9..fe0ed48 100644 --- a/examples/local_smollm.py +++ b/examples/local_smollm.py @@ -1,14 +1,19 @@ -from agents.llm_engine import TransformersEngine -from agents import CodeAgent, JsonAgent - +from agents import JsonAgent +from agents import tool +import webbrowser import requests from datetime import datetime +import random +from llama_cpp import Llama +from agents import tool +import webbrowser +from typing import List, Generator, Dict, Any +import json +import re model_repo="andito/SmolLM2-1.7B-Instruct-F16-GGUF" model_filename="smollm2-1.7b-8k-dpo-f16.gguf" -import random -from llama_cpp import Llama model = Llama.from_pretrained( repo_id=model_repo, @@ -55,10 +60,6 @@ The example format is as follows. Please make sure the parameter type is correct ... (more tool calls as required) ]""" - -from agents import tool -import webbrowser - @tool def get_random_number_between(min: int, max: int) -> int: """ @@ -110,9 +111,7 @@ def open_webbrowser(url: str) -> str: webbrowser.open(url) return f"I opened {url.replace('https://', '').replace('www.', '')} in the browser." -from typing import List, Dict, Generator, Any -import re -import json +‹ def _parse_response(self, text: str) -> List[Dict[str, Any]]: pattern = r"(.*?)" matches = re.findall(pattern, text, re.DOTALL) diff --git a/examples/oneshot.py b/examples/oneshot.py index dd6333d..7e45151 100644 --- a/examples/oneshot.py +++ b/examples/oneshot.py @@ -1,4 +1,4 @@ -from agents import load_tool, CodeAgent, JsonAgent, HfApiEngine +from agents import load_tool, CodeAgent, HfApiEngine from agents.prompts import ONESHOT_CODE_SYSTEM_PROMPT # Import tool from Hub diff --git a/tests/fixtures/000000039769.png b/tests/fixtures/000000039769.png new file mode 100644 index 0000000..a3b5225 Binary files /dev/null and b/tests/fixtures/000000039769.png differ diff --git a/tests/test_agent_types.py b/tests/test_agent_types.py index 47d5c0a..c274f27 100644 --- a/tests/test_agent_types.py +++ b/tests/test_agent_types.py @@ -27,8 +27,6 @@ from transformers.testing_utils import ( ) from transformers.utils import ( is_soundfile_availble, - is_torch_available, - is_vision_available, ) import torch @@ -93,7 +91,7 @@ class AgentImageTests(unittest.TestCase): self.assertTrue(os.path.exists(path)) def test_from_string(self): - path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png" + path = Path(get_tests_dir("fixtures/")) / "000000039769.png" image = Image.open(path) agent_type = AgentImage(path) @@ -105,7 +103,7 @@ class AgentImageTests(unittest.TestCase): self.assertTrue(os.path.exists(path)) def test_from_image(self): - path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png" + path = Path(get_tests_dir("fixtures/")) / "000000039769.png" image = Image.open(path) agent_type = AgentImage(image) diff --git a/tests/test_agents.py b/tests/test_agents.py index 89668c8..eaa0f56 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -22,7 +22,6 @@ import pytest from agents.agent_types import AgentText from agents.agents import ( AgentMaxIterationsError, - CodeAgent, ManagedAgent, CodeAgent, JsonAgent, @@ -162,14 +161,14 @@ class AgentTests(unittest.TestCase): output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, str) assert output == "7.2904" - assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" - assert agent.logs[1]["observation"] == "7.2904" + assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" + assert agent.logs[2].observation == "7.2904" assert ( - agent.logs[1]["rationale"].strip() + agent.logs[2].rationale.strip() == "Thought: I should multiply 2 by 3.6452. special_marker" ) assert ( - agent.logs[2]["llm_output"] + agent.logs[3].llm_output == """ Thought: I can now answer the initial question Action: @@ -187,8 +186,8 @@ Action: output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, float) assert output == 7.2904 - assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" - assert agent.logs[2]["tool_call"] == { + assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" + assert agent.logs[3].tool_call == { "tool_arguments": "final_answer(7.2904)", "tool_name": "code interpreter", } @@ -212,10 +211,9 @@ Action: max_iterations=5, ) agent.run("What is 2 multiplied by 3.6452?") - assert len(agent.logs) == 7 - assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError + assert len(agent.logs) == 8 + assert type(agent.logs[-1].error) is AgentMaxIterationsError - @require_torch def test_init_agent_with_different_toolsets(self): toolset_1 = [] agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) @@ -245,8 +243,8 @@ Action: # check that python_interpreter base tool does not get added to code agents agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) assert ( - len(agent.toolbox.tools) == 7 - ) # added final_answer tool + 6 base tools (excluding interpreter) + len(agent.toolbox.tools) == 2 + ) # added final_answer tool + search def test_function_persistence_across_steps(self): agent = CodeAgent( @@ -273,7 +271,8 @@ Action: managed_agents=[managed_agent], ) assert "You can also give requests to team members." not in agent.system_prompt - assert "<>" not in agent.system_prompt + print("ok1") + assert "{{managed_agents_descriptions}}" not in agent.system_prompt assert ( "You can also give requests to team members." in manager_agent.system_prompt ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 55ddbe0..e14ce88 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -20,25 +20,9 @@ import tempfile import unittest from pathlib import Path from unittest import mock, skip +from typing import List -import torch - -from accelerate.test_utils.examples import compare_against_test -from accelerate.test_utils.testing import ( - TempDirTestCase, - get_launch_command, - require_huggingface_suite, - require_multi_device, - require_multi_gpu, - require_non_xpu, - require_pippy, - require_schedulefree, - require_trackers, - run_command, - slow, -) -from accelerate.utils import write_basic_config - +from .test_utils import slow, skip, get_launch_command, TempDirTestCase # DataLoaders built from `test_samples/MRPC` for quick testing # Should mock `{script_name}.get_dataloaders` via: @@ -62,242 +46,51 @@ EXCLUDE_EXAMPLES = [ "profiler.py", ] +import subprocess -class ExampleDifferenceTests(unittest.TestCase): + +class SubprocessCallException(Exception): + pass + + +def run_command(command: List[str], return_stdout=False, env=None): """ - This TestCase checks that all of the `complete_*` scripts contain all of the - information found in the `by_feature` scripts, line for line. If one fails, - then a complete example does not contain all of the features in the features - scripts, and should be updated. - - Each example script should be a single test (such as `test_nlp_example`), - and should run `one_complete_example` twice: once with `parser_only=True`, - and the other with `parser_only=False`. This is so that when the test - failures are returned to the user, they understand if the discrepancy lies in - the `main` function, or the `training_loop` function. Otherwise it will be - unclear. - - Also, if there are any expected differences between the base script used and - `complete_nlp_example.py` (the canonical base script), these should be included in - `special_strings`. These would be differences in how something is logged, print statements, - etc (such as calls to `Accelerate.log()`) + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occured while running `command` """ - - by_feature_path = Path("examples", "by_feature").resolve() - examples_path = Path("examples").resolve() - - def one_complete_example( - self, - complete_file_name: str, - parser_only: bool, - secondary_filename: str = None, - special_strings: list = None, - ): - """ - Tests a single `complete` example against all of the implemented `by_feature` scripts - - Args: - complete_file_name (`str`): - The filename of a complete example - parser_only (`bool`): - Whether to look at the main training function, or the argument parser - secondary_filename (`str`, *optional*): - A potential secondary base file to strip all script information not relevant for checking, - such as "cv_example.py" when testing "complete_cv_example.py" - special_strings (`list`, *optional*): - A list of strings to potentially remove before checking no differences are left. These should be - diffs that are file specific, such as different logging variations between files. - """ - self.maxDiff = None - for item in os.listdir(self.by_feature_path): - if item not in EXCLUDE_EXAMPLES: - item_path = self.by_feature_path / item - if item_path.is_file() and item_path.suffix == ".py": - with self.subTest( - tested_script=complete_file_name, - feature_script=item, - tested_section="main()" - if parser_only - else "training_function()", - ): - diff = compare_against_test( - self.examples_path / complete_file_name, - item_path, - parser_only, - secondary_filename, - ) - diff = "\n".join(diff) - if special_strings is not None: - for string in special_strings: - diff = diff.replace(string, "") - assert diff == "" - - def test_nlp_examples(self): - self.one_complete_example("complete_nlp_example.py", True) - self.one_complete_example("complete_nlp_example.py", False) - - def test_cv_examples(self): - cv_path = (self.examples_path / "cv_example.py").resolve() - special_strings = [ - " " * 16 + "{\n\n", - " " * 20 + '"accuracy": eval_metric["accuracy"],\n\n', - " " * 20 + '"f1": eval_metric["f1"],\n\n', - " " * 20 + '"train_loss": total_loss.item() / len(train_dataloader),\n\n', - " " * 20 + '"epoch": epoch,\n\n', - " " * 16 + "},\n\n", - " " * 16 + "step=epoch,\n", - " " * 12, - " " * 8 + "for step, batch in enumerate(active_dataloader):\n", - ] - self.one_complete_example( - "complete_cv_example.py", True, cv_path, special_strings - ) - self.one_complete_example( - "complete_cv_example.py", False, cv_path, special_strings - ) + # Cast every path in `command` to a string + for i, c in enumerate(command): + if isinstance(c, Path): + command[i] = str(c) + if env is None: + env = os.environ.copy() + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e -@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"}) -@require_huggingface_suite -class FeatureExamplesTests(TempDirTestCase): +class ExamplesTests(TempDirTestCase): clear_on_setup = False @classmethod def setUpClass(cls): super().setUpClass() cls._tmpdir = tempfile.mkdtemp() - cls.config_file = Path(cls._tmpdir) / "default_config.yml" - - write_basic_config(save_location=cls.config_file) - cls.launch_args = get_launch_command(config_file=cls.config_file) + cls.launch_args = ["python3"] @classmethod def tearDownClass(cls): super().tearDownClass() shutil.rmtree(cls._tmpdir) - def test_checkpointing_by_epoch(self): - testargs = f""" - examples/by_feature/checkpointing.py - --checkpointing_steps epoch - --output_dir {self.tmpdir} - """.split() - run_command(self.launch_args + testargs) - assert (self.tmpdir / "epoch_0").exists() - - def test_checkpointing_by_steps(self): - testargs = f""" - examples/by_feature/checkpointing.py - --checkpointing_steps 1 - --output_dir {self.tmpdir} - """.split() - _ = run_command(self.launch_args + testargs) - assert (self.tmpdir / "step_2").exists() - - def test_load_states_by_epoch(self): - testargs = f""" - examples/by_feature/checkpointing.py - --resume_from_checkpoint {self.tmpdir / "epoch_0"} - """.split() - output = run_command(self.launch_args + testargs, return_stdout=True) - assert "epoch 0:" not in output - assert "epoch 1:" in output - - def test_load_states_by_steps(self): - testargs = f""" - examples/by_feature/checkpointing.py - --resume_from_checkpoint {self.tmpdir / "step_2"} - """.split() - output = run_command(self.launch_args + testargs, return_stdout=True) - if torch.cuda.is_available(): - num_processes = torch.cuda.device_count() - else: - num_processes = 1 - if num_processes > 1: - assert "epoch 0:" not in output - assert "epoch 1:" in output - else: - assert "epoch 0:" in output - assert "epoch 1:" in output - - @slow - def test_cross_validation(self): - testargs = """ - examples/by_feature/cross_validation.py - --num_folds 2 - """.split() - with mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "0"}): - output = run_command(self.launch_args + testargs, return_stdout=True) - results = re.findall("({.+})", output) - results = [r for r in results if "accuracy" in r][-1] - results = ast.literal_eval(results) - assert results["accuracy"] >= 0.75 - - def test_multi_process_metrics(self): - testargs = ["examples/by_feature/multi_process_metrics.py"] - run_command(self.launch_args + testargs) - - @require_schedulefree - def test_schedulefree(self): - testargs = ["examples/by_feature/schedule_free.py"] - run_command(self.launch_args + testargs) - - @require_trackers - @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"}) - def test_tracking(self): - with tempfile.TemporaryDirectory() as tmpdir: - testargs = f""" - examples/by_feature/tracking.py - --with_tracking - --project_dir {tmpdir} - """.split() - run_command(self.launch_args + testargs) - assert os.path.exists(os.path.join(tmpdir, "tracking")) - - def test_gradient_accumulation(self): - testargs = ["examples/by_feature/gradient_accumulation.py"] - run_command(self.launch_args + testargs) - - def test_local_sgd(self): - testargs = ["examples/by_feature/local_sgd.py"] - run_command(self.launch_args + testargs) - - def test_early_stopping(self): - testargs = ["examples/by_feature/early_stopping.py"] - run_command(self.launch_args + testargs) - - def test_profiler(self): - testargs = ["examples/by_feature/profiler.py"] - run_command(self.launch_args + testargs) - - @require_multi_device - def test_ddp_comm_hook(self): - testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"] - run_command(self.launch_args + testargs) - - @skip( - reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added." - ) - @require_multi_device - def test_distributed_inference_examples_stable_diffusion(self): - testargs = ["examples/inference/distributed/stable_diffusion.py"] - run_command(self.launch_args + testargs) - - @require_multi_device - def test_distributed_inference_examples_phi2(self): - testargs = ["examples/inference/distributed/phi2.py"] - run_command(self.launch_args + testargs) - - @require_non_xpu - @require_pippy - @require_multi_gpu - def test_pippy_examples_bert(self): - testargs = ["examples/inference/pippy/bert.py"] - run_command(self.launch_args + testargs) - - @require_non_xpu - @require_pippy - @require_multi_gpu - def test_pippy_examples_gpt2(self): - testargs = ["examples/inference/pippy/gpt2.py"] + + def test_oneshot(self): + testargs = ["examples/oneshot.py"] run_command(self.launch_args + testargs) diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py index 2da70c2..f406a80 100644 --- a/tests/test_final_answer.py +++ b/tests/test_final_answer.py @@ -48,7 +48,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): inputs_text = {"answer": "Text input"} inputs_image = { "answer": Image.open( - Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png" + Path(get_tests_dir("fixtures")) / "000000039769.png" ).resize((512, 512)) } inputs_audio = {"answer": torch.Tensor(np.ones(3000))} diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py index 289f2b4..2d887a7 100644 --- a/tests/test_tools_common.py +++ b/tests/test_tools_common.py @@ -50,7 +50,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): inputs[input_name] = "Text input" elif input_type == "image": inputs[input_name] = Image.open( - Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png" + Path(get_tests_dir("fixtures")) / "000000039769.png" ).resize((512, 512)) elif input_type == "audio": inputs[input_name] = np.ones(3000) diff --git a/tests/test_translation.py b/tests/test_translation.py deleted file mode 100644 index 06d6996..0000000 --- a/tests/test_translation.py +++ /dev/null @@ -1,69 +0,0 @@ -# coding=utf-8 -# Copyright 2024 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from transformers import load_tool -from transformers.agents.agent_types import AGENT_TYPE_MAPPING - -from .test_tools_common import ToolTesterMixin, output_type - - -class TranslationToolTester(unittest.TestCase, ToolTesterMixin): - def setUp(self): - self.tool = load_tool("translation") - self.tool.setup() - self.remote_tool = load_tool("translation", remote=True) - - def test_exact_match_arg(self): - result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French") - self.assertEqual(result, "- Hé, comment ça va?") - - def test_exact_match_kwarg(self): - result = self.tool( - text="Hey, what's up?", src_lang="English", tgt_lang="French" - ) - self.assertEqual(result, "- Hé, comment ça va?") - - def test_call(self): - inputs = ["Hey, what's up?", "English", "Spanish"] - output = self.tool(*inputs) - - self.assertEqual(output_type(output), self.tool.output_type) - - def test_agent_type_output(self): - inputs = ["Hey, what's up?", "English", "Spanish"] - output = self.tool(*inputs) - output_type = AGENT_TYPE_MAPPING[self.tool.output_type] - self.assertTrue(isinstance(output, output_type)) - - def test_agent_types_inputs(self): - example_inputs = { - "text": "Hey, what's up?", - "src_lang": "English", - "tgt_lang": "Spanish", - } - - _inputs = [] - for input_name in example_inputs.keys(): - example_input = example_inputs[input_name] - input_description = self.tool.inputs[input_name] - input_type = input_description["type"] - _inputs.append(AGENT_TYPE_MAPPING[input_type](example_input)) - - # Should not raise an error - output = self.tool(**example_inputs) - output_type = AGENT_TYPE_MAPPING[self.tool.output_type] - self.assertTrue(isinstance(output, output_type)) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..e39f851 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,100 @@ +import os +import unittest +import shutil +import tempfile + +from pathlib import Path +def str_to_bool(value) -> int: + """ + Converts a string representation of truth to `True` (1) or `False` (0). + + True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; + """ + value = value.lower() + if value in ("y", "yes", "t", "true", "on", "1"): + return 1 + elif value in ("n", "no", "f", "false", "off", "0"): + return 0 + else: + raise ValueError(f"invalid truth value {value}") + + +def get_int_from_env(env_keys, default): + """Returns the first positive env value found in the `env_keys` list or the default.""" + for e in env_keys: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return default + +def parse_flag_from_env(key, default=False): + """Returns truthy value for `key` from the env if available else the default.""" + value = os.environ.get(key, str(default)) + return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int... + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) + + +def skip(test_case): + "Decorator that skips a test unconditionally" + return unittest.skip("Test was skipped")(test_case) + + +def slow(test_case): + """ + Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a + truthy value to run them. + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + +def get_launch_command(**kwargs) -> list: + """ + Wraps around `kwargs` to help simplify launching from `subprocess`. + + Example: + ```python + # returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2'] + get_launch_command(num_processes=2, device_count=2) + ``` + """ + command = ["accelerate", "launch"] + for k, v in kwargs.items(): + if isinstance(v, bool) and v: + command.append(f"--{k}") + elif v is not None: + command.append(f"--{k}={v}") + return command + + +class TempDirTestCase(unittest.TestCase): + """ + A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its + data at the start of a test, and then destroyes it at the end of the TestCase. + + Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases + + The temporary directory location will be stored in `self.tmpdir` + """ + + clear_on_setup = True + + @classmethod + def setUpClass(cls): + "Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`" + cls.tmpdir = Path(tempfile.mkdtemp()) + + @classmethod + def tearDownClass(cls): + "Remove `cls.tmpdir` after test suite has finished" + if os.path.exists(cls.tmpdir): + shutil.rmtree(cls.tmpdir) + + def setUp(self): + "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`" + if self.clear_on_setup: + for path in self.tmpdir.glob("**/*"): + if path.is_file(): + path.unlink() + elif path.is_dir(): + shutil.rmtree(path) \ No newline at end of file