Pass more tests

This commit is contained in:
Aymeric 2025-01-06 22:04:00 +01:00
parent 01abe5921a
commit d45c63555f
4 changed files with 23 additions and 16 deletions

View File

@ -38,4 +38,5 @@ test = [
"sqlalchemy",
"ruff>=0.5.0",
"accelerate",
"soundfile",
]

View File

@ -372,7 +372,9 @@ class MultiStepAgent:
except Exception as e:
return f"Error in generating final LLM output:\n{e}"
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
def execute_tool_call(
self, tool_name: str, arguments: Union[Dict[str, str], str]
) -> Any:
"""
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
@ -515,7 +517,9 @@ You have been provided with these additional arguments, that you can access usin
self.planning_interval is not None
and step_number % self.planning_interval == 0
):
self.planning_step(task, is_first_step=(step_number == 0), step=step_number)
self.planning_step(
task, is_first_step=(step_number == 0), step=step_number
)
console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX)
)
@ -562,7 +566,9 @@ You have been provided with these additional arguments, that you can access usin
self.planning_interval is not None
and step_number % self.planning_interval == 0
):
self.planning_step(task, is_first_step=(step_number == 0), step=step_number)
self.planning_step(
task, is_first_step=(step_number == 0), step=step_number
)
console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX)
)

View File

@ -364,9 +364,7 @@ class AgentTests(unittest.TestCase):
def test_multiagents(self):
class FakeModelMultiagentsManagerAgent:
def __call__(
self, messages, stop_sequences=None, grammar=None
):
def __call__(self, messages, stop_sequences=None, grammar=None):
if len(messages) < 3:
return """
Thought: Let's call our search agent.
@ -401,6 +399,7 @@ final_answer("Final report.")
"Final report.",
"call_0",
)
manager_model = FakeModelMultiagentsManagerAgent()
class FakeModelMultiagentsManagedAgent:
@ -412,6 +411,7 @@ final_answer("Final report.")
{"report": "Report on the current US president"},
"call_0",
)
managed_model = FakeModelMultiagentsManagedAgent()
web_agent = ToolCallingAgent(

View File

@ -114,7 +114,7 @@ class TestDocs:
"image_generation_tool", # We don't want to run this expensive operation
"from_langchain", # Langchain is not a dependency
"while llm_should_continue(memory):", # This is pseudo code
"ollama_chat/llama3.2" # Exclude ollama building in guided tour
"ollama_chat/llama3.2", # Exclude ollama building in guided tour
]
code_blocks = [
block