Pass more tests

This commit is contained in:
Aymeric 2025-01-06 22:04:00 +01:00
parent 01abe5921a
commit d45c63555f
4 changed files with 23 additions and 16 deletions

View File

@ -38,4 +38,5 @@ test = [
"sqlalchemy", "sqlalchemy",
"ruff>=0.5.0", "ruff>=0.5.0",
"accelerate", "accelerate",
"soundfile",
] ]

View File

@ -372,7 +372,9 @@ class MultiStepAgent:
except Exception as e: except Exception as e:
return f"Error in generating final LLM output:\n{e}" return f"Error in generating final LLM output:\n{e}"
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any: def execute_tool_call(
self, tool_name: str, arguments: Union[Dict[str, str], str]
) -> Any:
""" """
Execute tool with the provided input and returns the result. Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables. This method replaces arguments with the actual values from the state if they refer to state variables.
@ -515,7 +517,9 @@ You have been provided with these additional arguments, that you can access usin
self.planning_interval is not None self.planning_interval is not None
and step_number % self.planning_interval == 0 and step_number % self.planning_interval == 0
): ):
self.planning_step(task, is_first_step=(step_number == 0), step=step_number) self.planning_step(
task, is_first_step=(step_number == 0), step=step_number
)
console.print( console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX) Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX)
) )
@ -562,7 +566,9 @@ You have been provided with these additional arguments, that you can access usin
self.planning_interval is not None self.planning_interval is not None
and step_number % self.planning_interval == 0 and step_number % self.planning_interval == 0
): ):
self.planning_step(task, is_first_step=(step_number == 0), step=step_number) self.planning_step(
task, is_first_step=(step_number == 0), step=step_number
)
console.print( console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX) Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX)
) )

View File

@ -364,9 +364,7 @@ class AgentTests(unittest.TestCase):
def test_multiagents(self): def test_multiagents(self):
class FakeModelMultiagentsManagerAgent: class FakeModelMultiagentsManagerAgent:
def __call__( def __call__(self, messages, stop_sequences=None, grammar=None):
self, messages, stop_sequences=None, grammar=None
):
if len(messages) < 3: if len(messages) < 3:
return """ return """
Thought: Let's call our search agent. Thought: Let's call our search agent.
@ -397,10 +395,11 @@ final_answer("Final report.")
else: else:
assert "Report on the current US president" in str(messages) assert "Report on the current US president" in str(messages)
return ( return (
"final_answer", "final_answer",
"Final report.", "Final report.",
"call_0", "call_0",
) )
manager_model = FakeModelMultiagentsManagerAgent() manager_model = FakeModelMultiagentsManagerAgent()
class FakeModelMultiagentsManagedAgent: class FakeModelMultiagentsManagedAgent:
@ -412,6 +411,7 @@ final_answer("Final report.")
{"report": "Report on the current US president"}, {"report": "Report on the current US president"},
"call_0", "call_0",
) )
managed_model = FakeModelMultiagentsManagedAgent() managed_model = FakeModelMultiagentsManagedAgent()
web_agent = ToolCallingAgent( web_agent = ToolCallingAgent(

View File

@ -111,10 +111,10 @@ class TestDocs:
code_blocks = self.extractor.extract_python_code(content) code_blocks = self.extractor.extract_python_code(content)
excluded_snippets = [ excluded_snippets = [
"ToolCollection", "ToolCollection",
"image_generation_tool", # We don't want to run this expensive operation "image_generation_tool", # We don't want to run this expensive operation
"from_langchain", # Langchain is not a dependency "from_langchain", # Langchain is not a dependency
"while llm_should_continue(memory):", # This is pseudo code "while llm_should_continue(memory):", # This is pseudo code
"ollama_chat/llama3.2" # Exclude ollama building in guided tour "ollama_chat/llama3.2", # Exclude ollama building in guided tour
] ]
code_blocks = [ code_blocks = [
block block