Fuse stream and direct run calls (#296)
* fuse stream and non stream calls
This commit is contained in:
parent
428aedde93
commit
2c43546d3c
|
@ -15,9 +15,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from rich import box
|
||||
from rich.console import Console, Group
|
||||
|
@ -498,13 +499,17 @@ You have been provided with these additional arguments, that you can access usin
|
|||
return result
|
||||
|
||||
if stream:
|
||||
return self.stream_run(self.task)
|
||||
else:
|
||||
return self.direct_run(self.task)
|
||||
# The steps are returned as they are executed through a generator to iterate on.
|
||||
return self._run(task=self.task)
|
||||
# Outputs are returned only at the end as a string. We only look at the last step
|
||||
return deque(self._run(task=self.task), maxlen=1)[0]
|
||||
|
||||
def stream_run(self, task: str):
|
||||
def _run(self, task: str) -> Generator[str, None, None]:
|
||||
"""
|
||||
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
|
||||
Runs the agent in streaming mode and returns a generator of all the steps.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform.
|
||||
"""
|
||||
final_answer = None
|
||||
self.step_number = 0
|
||||
|
@ -555,59 +560,7 @@ You have been provided with these additional arguments, that you can access usin
|
|||
|
||||
yield handle_agent_output_types(final_answer)
|
||||
|
||||
def direct_run(self, task: str):
|
||||
"""
|
||||
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
|
||||
"""
|
||||
final_answer = None
|
||||
self.step_number = 0
|
||||
while final_answer is None and self.step_number < self.max_steps:
|
||||
step_start_time = time.time()
|
||||
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||
try:
|
||||
if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
|
||||
self.planning_step(
|
||||
task,
|
||||
is_first_step=(self.step_number == 0),
|
||||
step=self.step_number,
|
||||
)
|
||||
self.logger.log(
|
||||
Rule(
|
||||
f"[bold]Step {self.step_number}",
|
||||
characters="━",
|
||||
style=YELLOW_HEX,
|
||||
),
|
||||
level=LogLevel.INFO,
|
||||
)
|
||||
|
||||
# Run one step!
|
||||
final_answer = self.step(step_log)
|
||||
|
||||
except AgentError as e:
|
||||
step_log.error = e
|
||||
finally:
|
||||
step_end_time = time.time()
|
||||
step_log.end_time = step_end_time
|
||||
step_log.duration = step_end_time - step_start_time
|
||||
self.logs.append(step_log)
|
||||
for callback in self.step_callbacks:
|
||||
callback(step_log)
|
||||
self.step_number += 1
|
||||
|
||||
if final_answer is None and self.step_number == self.max_steps:
|
||||
error_message = "Reached max steps."
|
||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
||||
self.logs.append(final_step_log)
|
||||
final_answer = self.provide_final_answer(task)
|
||||
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
|
||||
final_step_log.action_output = final_answer
|
||||
final_step_log.duration = 0
|
||||
for callback in self.step_callbacks:
|
||||
callback(final_step_log)
|
||||
|
||||
return handle_agent_output_types(final_answer)
|
||||
|
||||
def planning_step(self, task, is_first_step: bool, step: int):
|
||||
def planning_step(self, task, is_first_step: bool, step: int) -> None:
|
||||
"""
|
||||
Used periodically by the agent to plan the next steps to reach the objective.
|
||||
|
||||
|
|
Loading…
Reference in New Issue