"""End-to-end test: run a scripted oracle agent on ``miniwob.click-test``."""

import dataclasses
import logging
import re
import tempfile

from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.agent import Agent
from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result
from browsergym.utils.obs import flatten_axtree_to_str

logger = logging.getLogger(__name__)

# Matches an accessibility-tree line that starts with "[<digits>]" and contains
# "button" later on the same line (case-insensitive), capturing the digits.
# MULTILINE makes ``^`` anchor at the start of every line, not just the text.
_BUTTON_BID_RE = re.compile(r"^\s*\[(\d+)\].*button", re.MULTILINE | re.IGNORECASE)


class MiniwobTestAgent(Agent):
    """Scripted oracle agent that clicks the first button-like line of the axtree."""

    action_set = HighLevelActionSet(subsets="bid")

    def obs_preprocessor(self, obs: dict) -> dict:
        """Reduce the raw observation to a flattened accessibility-tree string.

        Args:
            obs: raw observation; only its ``"axtree_object"`` entry is used.

        Returns:
            A dict with a single key, ``"axtree_txt"``.
        """
        return {"axtree_txt": flatten_axtree_to_str(obs["axtree_object"])}

    def get_action(self, obs: dict) -> tuple[str, dict]:
        """Return a ``click`` action targeting the first matching button bid.

        Args:
            obs: preprocessed observation (the output of ``obs_preprocessor``).

        Returns:
            ``(action, agent_info)``, where ``action`` has the form
            ``'click("<bid>")'`` and ``agent_info`` holds a ``think`` message.

        Raises:
            ValueError: if no line of the accessibility tree matches the
                button pattern. (It is a subclass of ``Exception``, so callers
                that caught the previous bare ``Exception`` still work.)
        """
        match = _BUTTON_BID_RE.search(obs["axtree_txt"])
        if match is None:
            raise ValueError("Can't find the button's bid")
        bid = match.group(1)
        return f'click("{bid}")', dict(think="I'm clicking the button as requested.")


@dataclasses.dataclass
class MiniwobTestAgentArgs(AbstractAgentArgs):
    """Experiment-loop factory that builds a :class:`MiniwobTestAgent`."""

    def make_agent(self):
        """Instantiate a fresh oracle agent."""
        return MiniwobTestAgent()


def test_run_exp():
    """Run the oracle on ``miniwob.click-test`` and check the recorded outcome."""
    exp_args = ExpArgs(
        agent_args=MiniwobTestAgentArgs(),
        env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42),
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        exp_args.prepare(tmp_dir)
        exp_args.run()
        # Results are loaded from ``exp_args.exp_dir``. That directory is
        # presumably created under ``tmp_dir`` by ``prepare()``, so everything
        # is read before the temporary directory is deleted.
        exp_result = get_exp_result(exp_args.exp_dir)
        exp_record = exp_result.get_exp_record()

        target = {
            "env_args.task_name": "miniwob.click-test",
            "env_args.task_seed": 42,
            "env_args.headless": True,
            "env_args.record_video": False,
            "n_steps": 1,
            "cum_reward": 1.0,
            "terminated": True,
            "truncated": False,
        }

        # NOTE(review): n_steps is 1 but steps_info has 2 entries; presumably
        # the initial (reset) step plus one action step -- TODO confirm.
        assert len(exp_result.steps_info) == 2

        for key, target_val in target.items():
            assert key in exp_record, f"missing key {key!r} in experiment record"
            actual = exp_record[key]
            assert actual == target_val, f"{key}: expected {target_val!r}, got {actual!r}"

        elapsed = exp_record["stats.cum_step_elapsed"]
        # TODO investigate why it's taking almost 5 seconds to solve
        assert elapsed < 5, f"miniwob.click-test took {elapsed:.2f}s (>= 5s)"
        if elapsed > 3:
            logger.warning(
                "miniwob.click-test is taking %.2fs (> 3s) to solve with an oracle.",
                elapsed,
            )