Pass more tests

This commit is contained in:
Aymeric 2025-01-06 22:04:00 +01:00
parent 01abe5921a
commit d45c63555f
4 changed files with 23 additions and 16 deletions

View File

@ -38,4 +38,5 @@ test = [
"sqlalchemy",
"ruff>=0.5.0",
"accelerate",
"soundfile",
]

View File

@ -372,7 +372,9 @@ class MultiStepAgent:
except Exception as e:
return f"Error in generating final LLM output:\n{e}"
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
def execute_tool_call(
self, tool_name: str, arguments: Union[Dict[str, str], str]
) -> Any:
"""
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
@ -515,7 +517,9 @@ You have been provided with these additional arguments, that you can access usin
self.planning_interval is not None
and step_number % self.planning_interval == 0
):
self.planning_step(task, is_first_step=(step_number == 0), step=step_number)
self.planning_step(
task, is_first_step=(step_number == 0), step=step_number
)
console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX)
)
@ -562,7 +566,9 @@ You have been provided with these additional arguments, that you can access usin
self.planning_interval is not None
and step_number % self.planning_interval == 0
):
self.planning_step(task, is_first_step=(step_number == 0), step=step_number)
self.planning_step(
task, is_first_step=(step_number == 0), step=step_number
)
console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX)
)

View File

@ -62,7 +62,7 @@ class FakeToolCallModelImage:
else: # We're at step 2
return "final_answer", "image.png", "call_1"
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
@ -364,9 +364,7 @@ class AgentTests(unittest.TestCase):
def test_multiagents(self):
class FakeModelMultiagentsManagerAgent:
def __call__(
self, messages, stop_sequences=None, grammar=None
):
def __call__(self, messages, stop_sequences=None, grammar=None):
if len(messages) < 3:
return """
Thought: Let's call our search agent.
@ -397,10 +395,11 @@ final_answer("Final report.")
else:
assert "Report on the current US president" in str(messages)
return (
"final_answer",
"Final report.",
"call_0",
)
"final_answer",
"Final report.",
"call_0",
)
manager_model = FakeModelMultiagentsManagerAgent()
class FakeModelMultiagentsManagedAgent:
@ -412,6 +411,7 @@ final_answer("Final report.")
{"report": "Report on the current US president"},
"call_0",
)
managed_model = FakeModelMultiagentsManagedAgent()
web_agent = ToolCallingAgent(
@ -435,7 +435,7 @@ final_answer("Final report.")
report = manager_code_agent.run("Fake question.")
assert report == "Final report."
manager_toolcalling_agent = ToolCallingAgent(
tools=[],
model=manager_model,

View File

@ -111,10 +111,10 @@ class TestDocs:
code_blocks = self.extractor.extract_python_code(content)
excluded_snippets = [
"ToolCollection",
"image_generation_tool", # We don't want to run this expensive operation
"from_langchain", # Langchain is not a dependency
"while llm_should_continue(memory):", # This is pseudo code
"ollama_chat/llama3.2" # Exclude ollama building in guided tour
"image_generation_tool", # We don't want to run this expensive operation
"from_langchain", # Langchain is not a dependency
"while llm_should_continue(memory):", # This is pseudo code
"ollama_chat/llama3.2", # Exclude ollama building in guided tour
]
code_blocks = [
block