Solve additional args not being passed to task
This commit is contained in:
parent
b38d842c2d
commit
ba87dd98c8
|
@ -528,7 +528,7 @@ class ReactAgent(BaseAgent):
|
|||
self.logs.append(system_prompt_step)
|
||||
|
||||
console.print(Group(Rule("[bold]New task", characters="="), Text(self.task)))
|
||||
self.logs.append(TaskStep(task=task))
|
||||
self.logs.append(TaskStep(task=self.task))
|
||||
|
||||
if oneshot:
|
||||
step_start_time = time.time()
|
||||
|
@ -541,9 +541,9 @@ class ReactAgent(BaseAgent):
|
|||
return result
|
||||
|
||||
if stream:
|
||||
return self.stream_run(task)
|
||||
return self.stream_run(self.task)
|
||||
else:
|
||||
return self.direct_run(task)
|
||||
return self.direct_run(self.task)
|
||||
|
||||
def stream_run(self, task: str):
|
||||
"""
|
||||
|
|
|
@ -230,6 +230,12 @@ Action:
|
|||
tool_arguments="final_answer(7.2904)",
|
||||
)
|
||||
|
||||
def test_additional_args_added_to_task(self):
|
||||
agent = CodeAgent(tools=[], llm_engine=fake_code_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?", additional_instruction="Remember this.")
|
||||
assert "Remember this" in agent.task
|
||||
assert "Remember this" in str(agent.prompt_messages)
|
||||
|
||||
def test_reset_conversations(self):
|
||||
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
|
||||
|
|
Loading…
Reference in New Issue