Spaces:
Running
Running
File size: 3,321 Bytes
dd39c08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from typing import Tuple
import playwright
import pytest
from browsergym.core.env import BrowserEnv
from browsergym.core.task import AbstractBrowserTask
class MockImageGoalTask(AbstractBrowserTask):
    """Minimal browsergym task whose goal is a multi-part (text + image) message.

    The goal is a list with one text item and one image item (a base64 PNG
    data URL). The task never gives a non-zero reward. It reports ``done`` as
    soon as the chat contains a user message whose text is exactly ``"exit"``.
    """

    @classmethod
    def get_task_id(cls) -> str:
        """Return the identifier of this task."""
        return "mockimagegoal"

    def __init__(self, seed: int = 0, start_url: str = "https://www.google.com") -> None:
        """
        Args:
            seed: random seed, forwarded to ``AbstractBrowserTask.__init__``.
            start_url: the URL the page navigates to in ``setup``.
        """
        super().__init__(seed)
        self.start_url = start_url
        # Multi-part goal: a text item followed by an image item whose payload
        # is an inline base64-encoded PNG data URL.
        self.goal = [
            {"type": "text", "text": "This is a mock task with an image goal."},
            {
                "type": "image_url",
                "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=",
            },
        ]

    def setup(self, page: playwright.sync_api.Page) -> tuple[list[dict], dict]:
        """Navigate ``page`` to ``start_url`` and return ``(goal, info)``.

        The goal is the multi-part list built in ``__init__`` (not a plain
        string). The info dict is always empty.
        """
        page.goto(self.start_url, timeout=10000)  # timeout is in milliseconds
        return self.goal, {}

    def teardown(self) -> None:
        """No resources are held by this task, so there is nothing to release."""
        pass

    def validate(
        self, page: playwright.sync_api.Page, chat_messages: list[dict]
    ) -> Tuple[float, bool, str, dict]:
        """Check whether the user has asked to end the episode.

        Args:
            page: the current page (unused by this mock task).
            chat_messages: chat history. Each entry is indexed by the
                ``"role"`` and ``"message"`` keys.

        Returns:
            ``(reward, done, msg, info)``. The reward is always ``0.0``, and
            ``msg``/``info`` are always empty. ``done`` is True once any
            message has role ``"user"`` and text exactly ``"exit"``.
        """
        # any() stops at the first match, exactly like the original break.
        done = any(
            message["role"] == "user" and message["message"] == "exit"
            for message in chat_messages
        )
        return 0.0, done, "", {}
def test_mock_image_goal_task():
    """Smoke-test that ``BrowserEnv`` exposes a multi-part (text + image) goal.

    The test runs in three steps:

    1. Reset an env built on ``MockImageGoalTask``.
    2. Check that ``obs["goal_object"]`` holds the text part followed by the
       image part.
    3. Post the user message ``"exit"`` and take one step. That step must
       report termination with zero reward.
    """
    env = BrowserEnv(MockImageGoalTask)
    # try/finally guarantees the browser is shut down even when an assertion
    # (or the env itself) raises part-way through the test.
    try:
        obs, _ = env.reset()
        assert "goal_object" in obs
        goal = obs["goal_object"]
        assert len(goal) == 2
        assert goal[0]["type"] == "text"
        assert goal[0]["text"] == "This is a mock task with an image goal."
        assert goal[1]["type"] == "image_url"

        # MockImageGoalTask.validate marks the episode done once the user says "exit".
        env.chat.add_message("user", "exit")
        obs, reward, terminated, _, _ = env.step("send_msg_to_user('bye')")
        assert reward == 0
        assert terminated is True
    finally:
        env.close()
def _main() -> None:
    """Run the image-goal smoke test directly, without a pytest runner."""
    test_mock_image_goal_task()


if __name__ == "__main__":
    _main()
|