Fix `vlm_web_browser.py` example (#410)

* Update `vlm_web_browser.py` example
Fixed a typo which meant that the images were never being remove for lean processing

---------

Co-authored-by: Aymeric <aymeric.roucher@gmail.com>
This commit is contained in:
Abubakar Abid 2025-01-29 17:17:00 -08:00 committed by GitHub
parent dcbbe448af
commit c0abd2134e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 34 deletions

View File

@ -72,22 +72,24 @@ else:
# Prepare callback
def save_screenshot(step_log: ActionStep, agent: CodeAgent) -> None:
def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None:
sleep(1.0) # Let JavaScript animations happen before taking the screenshot
driver = helium.get_driver()
current_step = step_log.step_number
current_step = memory_step.step_number
if driver is not None:
for step_logs in agent.logs: # Remove previous screenshots from logs for lean processing
if isinstance(step_log, ActionStep) and step_log.step_number <= current_step - 2:
step_logs.observations_images = None
for previous_memory_step in agent.memory.steps: # Remove previous screenshots from logs for lean processing
if isinstance(previous_memory_step, ActionStep) and previous_memory_step.step_number <= current_step - 2:
previous_memory_step.observations_images = None
png_bytes = driver.get_screenshot_as_png()
image = Image.open(BytesIO(png_bytes))
print(f"Captured a browser screenshot: {image.size} pixels")
step_log.observations_images = [image.copy()] # Create a copy to ensure it persists, important!
memory_step.observations_images = [image.copy()] # Create a copy to ensure it persists, important!
# Update observations with current URL
url_info = f"Current url: {driver.current_url}"
step_log.observations = url_info if step_logs.observations is None else step_log.observations + "\n" + url_info
memory_step.observations = (
url_info if memory_step.observations is None else memory_step.observations + "\n" + url_info
)
return

View File

@ -225,8 +225,8 @@ class MultiStepAgent:
the LLM.
"""
messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
for step_log in self.memory.steps:
messages.extend(step_log.to_messages(summary_mode=summary_mode))
for memory_step in self.memory.steps:
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
return messages
def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]:
@ -413,12 +413,12 @@ You have been provided with these additional arguments, that you can access usin
self.memory.steps.append(TaskStep(task=self.task, task_images=images))
if single_step:
step_start_time = time.time()
step_log = ActionStep(start_time=step_start_time, observations_images=images)
step_log.end_time = time.time()
step_log.duration = step_log.end_time - step_start_time
memory_step = ActionStep(start_time=step_start_time, observations_images=images)
memory_step.end_time = time.time()
memory_step.duration = memory_step.end_time - step_start_time
# Run the agent's step
result = self.step(step_log)
result = self.step(memory_step)
return result
if stream:
@ -439,7 +439,7 @@ You have been provided with these additional arguments, that you can access usin
self.step_number = 0
while final_answer is None and self.step_number < self.max_steps:
step_start_time = time.time()
step_log = ActionStep(
memory_step = ActionStep(
step_number=self.step_number,
start_time=step_start_time,
observations_images=images,
@ -461,41 +461,40 @@ You have been provided with these additional arguments, that you can access usin
)
# Run one step!
final_answer = self.step(step_log)
final_answer = self.step(memory_step)
except AgentError as e:
step_log.error = e
memory_step.error = e
finally:
step_log.end_time = time.time()
step_log.duration = step_log.end_time - step_start_time
self.memory.steps.append(step_log)
memory_step.end_time = time.time()
memory_step.duration = memory_step.end_time - step_start_time
self.memory.steps.append(memory_step)
for callback in self.step_callbacks:
# For compatibility with old callbacks that don't take the agent as an argument
if len(inspect.signature(callback).parameters) == 1:
callback(step_log)
callback(memory_step)
else:
callback(step_log=step_log, agent=self)
callback(memory_step, agent=self)
self.step_number += 1
yield step_log
yield memory_step
if final_answer is None and self.step_number == self.max_steps:
error_message = "Reached max steps."
final_answer = self.provide_final_answer(task, images)
final_step_log = ActionStep(
final_memory_step = ActionStep(
step_number=self.step_number, error=AgentMaxStepsError(error_message, self.logger)
)
self.logger.log(final_step_log)
final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger))
final_step_log.action_output = final_answer
final_step_log.end_time = time.time()
final_step_log.duration = step_log.end_time - step_start_time
self.memory.steps.append(final_step_log)
final_memory_step = ActionStep(error=AgentMaxStepsError(error_message, self.logger))
final_memory_step.action_output = final_answer
final_memory_step.end_time = time.time()
final_memory_step.duration = memory_step.end_time - step_start_time
self.memory.steps.append(final_memory_step)
for callback in self.step_callbacks:
# For compatibility with old callbacks that don't take the agent as an argument
if len(inspect.signature(callback).parameters) == 1:
callback(final_step_log)
callback(final_memory_step)
else:
callback(step_log=final_step_log, agent=self)
yield final_step_log
callback(final_memory_step, agent=self)
yield final_memory_step
yield handle_agent_output_types(final_answer)

View File

@ -105,7 +105,8 @@ class ActionStep(MemoryStep):
)
else:
tool_response_message = Message(
role=MessageRole.TOOL_RESPONSE, content=f"Call id: {self.tool_calls[0].id}\n{message_content}"
role=MessageRole.TOOL_RESPONSE,
content=[{"type": "text", "text": f"Call id: {self.tool_calls[0].id}\n{message_content}"}],
)
messages.append(tool_response_message)
@ -114,7 +115,12 @@ class ActionStep(MemoryStep):
messages.append(
Message(
role=MessageRole.TOOL_RESPONSE,
content=f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}",
content=[
{
"type": "text",
"text": f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}",
}
],
)
)
if self.observations_images: