This commit is contained in:
Aymeric 2024-12-27 16:18:19 +01:00
parent 710fb75559
commit c880f2f5b6
13 changed files with 115 additions and 48 deletions

View File

@ -122,9 +122,9 @@ def sql_engine(query: str) -> str:
Now let us create an agent that leverages this tool.
We use the CodeAgent, which is transformers.agents' main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
We use the `CodeAgent`, which is transformers.agents' main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
The model is the LLM that powers the agent system. HfModel allows you to call LLMs using HF's Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
The model is the LLM that powers the agent system. HfApiModel allows you to call LLMs using HF's Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
```py
from smolagents import CodeAgent, HfApiModel
@ -180,14 +180,14 @@ for table in ["receipts", "waiters"]:
print(updated_description)
```
Since this request is a bit harder than the previous one, we'll switch the LLM engine to use the more powerful [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)!
Since this request is a bit harder than the previous one, we'll switch the LLM engine to use the more powerful [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct)!
```py
sql_engine.description = updated_description
agent = CodeAgent(
tools=[sql_engine],
model=HfApiModel("Qwen/Qwen2.5-72B-Instruct"),
model=HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct"),
)
agent.run("Which waiter got more total money from tips?")

View File

@ -19,7 +19,6 @@ dependencies = [
"pandas>=2.2.3",
"jinja2>=3.1.4",
"pillow>=11.0.0",
"llama-cpp-python>=0.3.4",
"markdownify>=0.14.1",
"gradio>=5.8.0",
"duckduckgo-search>=6.3.7",
@ -30,9 +29,6 @@ dependencies = [
]
[project.optional-dependencies]
dev = [
"anthropic",
]
test = [
"gradio-tools"
]

View File

@ -389,12 +389,16 @@ class MultiStepAgent:
try:
if isinstance(arguments, str):
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
observation = available_tools[tool_name].__call__(
arguments, sanitize_inputs_outputs=True
)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
observation = available_tools[tool_name].__call__(
**arguments, sanitize_inputs_outputs=True
)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
raise AgentExecutionError(error_msg)
@ -774,10 +778,14 @@ class ToolCallingAgent(MultiStepAgent):
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
final_answer = self.state[answer]
console.print(f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.")
console.print(
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'."
)
else:
final_answer = answer
console.print(Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"))
console.print(
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}")
)
log_entry.action_output = final_answer
return final_answer
@ -891,7 +899,12 @@ class CodeAgent(MultiStepAgent):
align="left",
style="orange",
),
Syntax(llm_output, lexer="markdown", theme="github-dark", word_wrap=True),
Syntax(
llm_output,
lexer="markdown",
theme="github-dark",
word_wrap=True,
),
)
)

View File

@ -163,10 +163,12 @@ class DuckDuckGoSearchTool(Tool):
)
self.ddgs = DDGS()
def forward(self, query: str) -> str:
results = self.ddgs.text(query, max_results=10)
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
postprocessed_results = [
f"[{result['title']}]({result['href']})\n{result['body']}"
for result in results
]
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
@ -301,7 +303,12 @@ class SpeechToTextTool(PipelineTool):
pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}}
inputs = {
"audio": {
"type": "audio",
"description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
}
}
output_type = "string"
def encode(self, audio):

View File

@ -110,7 +110,6 @@ locals().update({key: value for key, value in pickle_dict.items()})
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
console.print(execution_logs)
execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results:

View File

@ -129,7 +129,8 @@ def get_clean_message_list(
final_message_list.append(message)
return final_message_list
class Model():
class Model:
def __init__(self):
self.last_input_token_count = None
self.last_output_token_count = None
@ -313,9 +314,16 @@ class TransformersModel(Model):
self.stream = ""
def __call__(self, input_ids, scores, **kwargs):
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
generated = self.tokenizer.decode(
input_ids[0][-1], skip_special_tokens=True
)
self.stream += generated
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
if any(
[
self.stream.endswith(stop_string)
for stop_string in self.stop_strings
]
):
return True
return False
@ -458,7 +466,7 @@ __all__ = [
"MessageRole",
"tool_role_conversions",
"get_clean_message_list",
"HfModel",
"Model",
"TransformersModel",
"HfApiModel",
"LiteLLMModel",

View File

@ -32,7 +32,7 @@ class Monitor:
def get_total_token_counts(self):
return {
"input": self.total_input_token_count,
"output": self.total_output_token_count
"output": self.total_output_token_count,
}
def reset(self):

View File

@ -1,6 +1,5 @@
import ast
import inspect
import importlib.util
import builtins
from typing import Set
import textwrap

View File

@ -108,6 +108,7 @@ def validate_after_init(cls):
cls.__init__ = new_init
return cls
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
@ -119,10 +120,13 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
properties[param_name]["nullable"] = True
for param_name in signature.parameters.keys():
if signature.parameters[param_name].default != inspect.Parameter.empty:
if param_name not in properties: # this can happen if the param has no type hint but a default value
if (
param_name not in properties
): # this can happen if the param has no type hint but a default value
properties[param_name] = {"nullable": True}
return properties
AUTHORIZED_TYPES = [
"string",
"boolean",
@ -202,7 +206,10 @@ class Tool:
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
# Validate forward function signature, except for PipelineTool
if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True):
if not (
hasattr(self, "is_pipeline_tool")
and getattr(self, "is_pipeline_tool") is True
):
signature = inspect.signature(self.forward)
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
@ -213,9 +220,13 @@ class Tool:
json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items():
if "nullable" in value:
assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
assert (
key in json_schema and "nullable" in json_schema[key]
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
if key in json_schema and "nullable" in json_schema[key]:
assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
assert (
"nullable" in value
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
def forward(self, *args, **kwargs):
    """Placeholder for the tool's logic; subclasses must override it.

    Raises:
        NotImplementedError: always, because the base class has no
            behavior of its own.
    """
    # Raise rather than return the exception: returning it would hand the
    # caller an exception *object* as if it were a valid tool output.
    raise NotImplementedError("Write this method in your subclass of `Tool`.")

View File

@ -249,7 +249,11 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, torch.Tensor: AgentAudio}
INSTANCE_TYPE_MAPPING = {
str: AgentText,
ImageType: AgentImage,
torch.Tensor: AgentAudio,
}
if is_torch_available():
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio

View File

@ -106,6 +106,7 @@ final_answer("got an error")
```<end_code>
"""
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
@ -255,12 +256,13 @@ class AgentTests(unittest.TestCase):
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
agent = CodeAgent(
tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
assert " print(\"Failing due to unexpected indent\")" in str(agent.logs)
assert ' print("Failing due to unexpected indent")' in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self):
ToolCallingAgent(model=FakeToolCallModel(), tools=[])

View File

@ -16,6 +16,7 @@ import unittest
from smolagents import models, tool
from typing import Optional
class ModelTests(unittest.TestCase):
def test_get_json_schema_has_nullable_args(self):
@tool
@ -29,4 +30,10 @@ class ModelTests(unittest.TestCase):
celsius: the temperature type
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
assert (
"nullable"
in models.get_json_schema(get_weather)["function"]["parameters"][
"properties"
]["celsius"]
)

View File

@ -286,18 +286,24 @@ class ToolTests(unittest.TestCase):
def test_tool_missing_class_attributes_raises_error(self):
with pytest.raises(Exception) as e:
class GetWeatherTool(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
"celsius": {
"type": "string",
"description": "the temperature type",
},
}
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool()
GetWeatherTool()
assert "You must set an attribute output_type" in str(e)
def test_tool_from_decorator_optional_args(self):
@ -312,56 +318,71 @@ class ToolTests(unittest.TestCase):
celsius: the temperature type
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert "nullable" in get_weather.inputs["celsius"]
assert get_weather.inputs["celsius"]["nullable"] == True
assert get_weather.inputs["celsius"]["nullable"]
assert "nullable" not in get_weather.inputs["location"]
def test_tool_mismatching_nullable_args_raises_error(self):
with pytest.raises(Exception) as e:
class GetWeatherTool(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
"celsius": {
"type": "string",
"description": "the temperature type",
},
}
output_type = "string"
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool()
GetWeatherTool()
assert "Nullable" in str(e)
with pytest.raises(Exception) as e:
class GetWeatherTool2(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
"celsius": {
"type": "string",
"description": "the temperature type",
},
}
output_type = "string"
def forward(self, location: str, celsius: bool = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool2()
GetWeatherTool2()
assert "Nullable" in str(e)
with pytest.raises(Exception) as e:
class GetWeatherTool3(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type", "nullable": True}
"celsius": {
"type": "string",
"description": "the temperature type",
"nullable": True,
},
}
output_type = "string"
def forward(self, location, celsius: str) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool3()
GetWeatherTool3()
assert "Nullable" in str(e)