# Scraped page header (not code): author stillerman (HF Staff),
# commit 1e310b7 "docs minus assets".
from game import AgentPlayer, SQLiteDB, Game
import os
import json
import asyncio
import argparse
class Proctor:
    """Schedule, execute and aggregate a batch of agent evaluation runs.

    One ``Run`` is created for every ordered (start, destination) pair of
    distinct entries in ``article_list``, repeated ``num_trials`` times; trial
    ``n`` uses seed ``starting_seed + n``. Runs execute concurrently with at
    most ``num_workers`` in flight, and their per-run JSON outputs are merged
    into ``{output_dir}/{proctor_id}-final-results.json``.
    """

    def __init__(
        self,
        article_list: list[str],
        num_trials: int,
        num_workers: int,
        max_steps: int,
        agent_settings: dict,
        db_path: str,
        verbose: bool = True,
        output_dir: str = "./proctor_tmp",
        proctor_id: str = "proctor_1",
        starting_seed: int = 42,
    ):
        """Store the configuration, open the database and build all runs.

        Args:
            article_list: Articles to evaluate; every ordered pair of distinct
                entries becomes a start/destination pair.
            num_trials: Number of runs per pair.
            num_workers: Maximum number of runs executing concurrently.
            max_steps: Step budget passed to each game.
            agent_settings: Must provide "model", "api_base", "max_links" and
                "max_tries" (read by ``Run``).
            db_path: Path handed to ``SQLiteDB``.
            verbose: Print per-run progress messages.
            output_dir: Directory for per-run and final JSON files (created
                if missing).
            proctor_id: Prefix for run ids and the final-results filename.
            starting_seed: Seed of trial 0; trial ``n`` uses
                ``starting_seed + n``.

        Raises:
            ValueError: If ``num_workers`` < 1. A zero-permit semaphore would
                block every run forever.
        """
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")
        self.article_list = article_list
        self.num_trials = num_trials
        self.num_workers = num_workers
        self.max_steps = max_steps
        self.agent_settings = agent_settings
        self.db_path = db_path
        self.verbose = verbose
        self.output_dir = output_dir
        self.proctor_id = proctor_id
        self.db = SQLiteDB(self.db_path)
        self.starting_seed = starting_seed
        os.makedirs(self.output_dir, exist_ok=True)
        self.runs = []
        self.setup_runs()

    def setup_runs(self):
        """Append one ``Run`` per ordered distinct article pair and trial."""
        for start in self.article_list:
            for destination in self.article_list:
                if start == destination:
                    continue
                for n in range(self.num_trials):
                    run_id = f"{self.proctor_id}_{start}_{destination}_{n}"
                    self.runs.append(
                        Run(
                            start,
                            destination,
                            self.max_steps,
                            self.agent_settings,
                            self.db,
                            self.output_dir,
                            self.verbose,
                            run_id,
                            self.starting_seed + n,
                        )
                    )
                    if self.verbose:
                        print(f"Setup run {run_id}")

    async def run(self):
        """Execute every run with bounded concurrency, then aggregate results.

        A run that raises does not abort the others. Its exception is
        reported and the run ends up under ``missing_runs`` in the final
        results. ``Run.run`` skips runs whose output file already exists, so
        calling this again resumes an interrupted evaluation.
        """
        semaphore = asyncio.Semaphore(self.num_workers)

        async def run_with_semaphore(run_instance):
            async with semaphore:
                if self.verbose:
                    print(f"Starting run {run_instance.id}")
                await run_instance.run()
                if self.verbose:
                    print(f"Finished run {run_instance.id}")

        tasks = [
            asyncio.create_task(run_with_semaphore(run_instance))
            for run_instance in self.runs
        ]
        # Without return_exceptions=True, the first failure propagates out of
        # gather(), and asyncio.run() then cancels every run still in flight.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for run_instance, outcome in zip(self.runs, outcomes):
            if isinstance(outcome, BaseException):
                print(f"Run {run_instance.id} failed: {outcome!r}")
        self.analyze_runs()

    def analyze_runs(self):
        """Aggregate all per-run JSON outputs into a single results file.

        Writes ``{output_dir}/{proctor_id}-final-results.json`` containing the
        evaluation settings, every run's output, the hop counts of won games,
        their average, and the win/lose rates.

        Runs without an output file (failed or never executed) are listed
        under ``missing_runs`` and excluded from the rates. Ratios with a zero
        denominator are written as ``None`` (JSON ``null``) instead of raising
        ``ZeroDivisionError``.
        """
        final_results = {
            "article_list": self.article_list,
            "num_trials": self.num_trials,
            "num_workers": self.num_workers,
            "max_steps": self.max_steps,
            "agent_settings": self.agent_settings,
            "runs": [],
        }
        win_count = 0
        lose_count = 0
        hops_distribution = []
        missing_runs = []
        for run_instance in self.runs:
            try:
                with open(run_instance.output_file, "r") as f:
                    result = json.load(f)
            except FileNotFoundError:
                missing_runs.append(run_instance.id)
                continue
            final_results["runs"].append(result)
            if result["result"] == "win":
                win_count += 1
                # Hops exclude one step (presumably the starting article,
                # which is not a hop). TODO: confirm against Game.run output.
                hops_distribution.append(len(result["steps"]) - 1)
            else:
                lose_count += 1
        completed = win_count + lose_count
        final_results["hops_distribution"] = hops_distribution
        final_results["average_hops"] = (
            sum(hops_distribution) / len(hops_distribution)
            if hops_distribution
            else None
        )
        final_results["win_rate"] = win_count / completed if completed else None
        final_results["lose_rate"] = lose_count / completed if completed else None
        final_results["num_completed"] = completed
        final_results["missing_runs"] = missing_runs
        with open(f"{self.output_dir}/{self.proctor_id}-final-results.json", "w") as f:
            json.dump(final_results, f, indent=4)
class Run:
    """A single agent game between two articles, persisted to one JSON file.

    The result goes to ``{output_dir}/run_{id}.json``. Path separators in
    ``id`` are replaced by ``_`` in the filename only, so article titles such
    as ``"AC/DC"`` cannot point the file into a non-existent sub-directory.
    ``self.id`` itself is kept unchanged.
    """

    def __init__(
        self,
        start_article: str,
        destination_article: str,
        max_steps: int,
        agent_settings: dict,
        db: "SQLiteDB",
        output_dir: str,
        verbose: bool,
        id: str,
        seed: int,
    ):
        """Store the run configuration and compute the output file path.

        Args:
            start_article: Article the game starts from.
            destination_article: Article the agent must reach.
            max_steps: Step budget passed to ``Game``.
            agent_settings: Must provide "model", "api_base", "max_links" and
                "max_tries".
            db: Database handle passed to ``Game``.
            output_dir: Directory where the result JSON is written.
            verbose: Stored for callers; the game and player are always
                created with ``verbose=False``.
            id: Run identifier used in messages and (sanitized) in the
                output filename.
            seed: Seed passed to ``AgentPlayer``.
        """
        self.start_article = start_article
        self.destination_article = destination_article
        self.max_steps = max_steps
        self.agent_settings = agent_settings
        self.db = db
        self.output_dir = output_dir
        self.verbose = verbose
        self.id = id
        self.seed = seed
        self.output_file = f"{self.output_dir}/run_{self._filename_safe(self.id)}.json"

    @staticmethod
    def _filename_safe(name: str) -> str:
        """Return ``name`` with every path separator replaced by ``_``."""
        for sep in {"/", os.sep, os.altsep or "/"}:
            name = name.replace(sep, "_")
        return name

    async def run(self):
        """Play the game with an ``AgentPlayer`` and write its JSON result.

        Returns immediately when the output file already exists, which lets
        an interrupted evaluation resume. The written JSON records the agent
        settings, both articles, the seed, every step returned by
        ``Game.run()`` and ``result`` (the ``"type"`` of the last step).

        NOTE(review): assumes ``Game.run()`` returns a non-empty list of step
        dicts, each with a "type" key. Confirm against ``game.Game``.
        """
        if os.path.exists(self.output_file):
            return
        player = AgentPlayer(
            model=self.agent_settings["model"],
            api_base=self.agent_settings["api_base"],
            max_links=self.agent_settings["max_links"],
            max_tries=self.agent_settings["max_tries"],
            verbose=False,
            seed=self.seed,
        )
        game = Game(
            self.start_article,
            self.destination_article,
            self.db,
            self.max_steps,
            player,
            verbose=False,
        )
        steps = await game.run()
        output = {
            "model": self.agent_settings["model"],
            "api_base": self.agent_settings["api_base"],
            "max_links": self.agent_settings["max_links"],
            "max_tries": self.agent_settings["max_tries"],
            "start_article": self.start_article,
            "destination_article": self.destination_article,
            "steps": steps,
            "seed": self.seed,
            "result": steps[-1]["type"],
        }
        # Write to a temp file and rename atomically. An interrupted write
        # then never leaves a truncated JSON file, which the resume check
        # above would skip and which aggregation could not parse.
        tmp_file = f"{self.output_file}.tmp"
        with open(tmp_file, "w") as f:
            json.dump(output, f, indent=4)
        os.replace(tmp_file, self.output_file)
        print(f"Run {self.id} completed in {len(steps)} steps")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run parallel Wikispeedia evaluations")
parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for agent")
parser.add_argument("--api-base", type=str, default=None, help="API base URL for hosted models")
parser.add_argument("--workers", type=int, default=20, help="Number of parallel workers")
parser.add_argument("--trials", type=int, default=1, help="Number of trials per start-destination pair")
parser.add_argument("--max-steps", type=int, default=20, help="Maximum steps per game")
parser.add_argument("--max-links", type=int, default=200, help="Maximum links per page for agent")
parser.add_argument("--max-tries", type=int, default=3, help="Maximum retries for agent")
parser.add_argument("--db-path", type=str, default="wikihop.db", help="Path to the wikihop database")
parser.add_argument("--output-dir", type=str, default="./proctor_tmp", help="Directory for output files")
parser.add_argument("--proctor-id", type=str, default="proctor_1", help="Unique identifier for this proctor run")
parser.add_argument("--seed", type=int, default=42, help="Starting random seed")
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
parser.add_argument("--article-list", type=str, default="supernodes.json",
help="Path to JSON file with list of articles to test")
args = parser.parse_args()
# check if db exists
if not os.path.exists(args.db_path):
raise FileNotFoundError(f"Database file not found at {args.db_path}")
# check if article list exists
if not os.path.exists(args.article_list):
raise FileNotFoundError(f"Article list file not found at {args.article_list}")
# Read article list from file
with open(args.article_list, "r") as f:
article_list = json.load(f)
agent_settings = {
"model": args.model,
"api_base": args.api_base,
"max_links": args.max_links,
"max_tries": args.max_tries,
}
proctor = Proctor(
article_list=article_list,
num_trials=args.trials,
num_workers=args.workers,
max_steps=args.max_steps,
agent_settings=agent_settings,
db_path=args.db_path,
verbose=args.verbose,
output_dir=args.output_dir,
proctor_id=args.proctor_id,
starting_seed=args.seed,
)
asyncio.run(proctor.run())