diff --git a/src/agents/agents.py b/src/agents/agents.py index 4003e40..ea11390 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -528,7 +528,7 @@ class ReactAgent(BaseAgent): self.logs.append(system_prompt_step) console.print(Group(Rule("[bold]New task", characters="="), Text(self.task))) - self.logs.append(TaskStep(task=task)) + self.logs.append(TaskStep(task=self.task)) if oneshot: step_start_time = time.time() @@ -541,9 +541,9 @@ class ReactAgent(BaseAgent): return result if stream: - return self.stream_run(task) + return self.stream_run(self.task) else: - return self.direct_run(task) + return self.direct_run(self.task) def stream_run(self, task: str): """ diff --git a/tests/test_agents.py b/tests/test_agents.py index 16b281f..539f4cf 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -230,6 +230,12 @@ Action: tool_arguments="final_answer(7.2904)", ) + def test_additional_args_added_to_task(self): + agent = CodeAgent(tools=[], llm_engine=fake_code_llm) + output = agent.run("What is 2 multiplied by 3.6452?", additional_instruction="Remember this.") + assert "Remember this" in agent.task + assert "Remember this" in str(agent.prompt_messages) + def test_reset_conversations(self): agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm) output = agent.run("What is 2 multiplied by 3.6452?", reset=True)