Spaces:
Running
Running
from game import AgentPlayer, SQLiteDB, Game | |
import os | |
import json | |
import asyncio | |
import argparse | |
class Proctor:
    """Run one game per ordered pair of distinct articles (per trial) and aggregate the results.

    Every ordered pair ``(start, destination)`` of distinct entries in
    ``article_list`` is played ``num_trials`` times. Each game is a ``Run``
    that writes its own JSON file into ``output_dir``. ``analyze_runs`` then
    merges those files into ``<output_dir>/<proctor_id>-final-results.json``.
    """

    def __init__(
        self,
        article_list: list[str],
        num_trials: int,
        num_workers: int,
        max_steps: int,
        agent_settings: dict,
        db_path: str,
        verbose: bool = True,
        output_dir: str = "./proctor_tmp",
        proctor_id: str = "proctor_1",
        starting_seed: int = 42,
    ):
        """
        Args:
            article_list: Article titles. Every ordered pair of distinct titles is played.
            num_trials: Runs per (start, destination) pair. Trial ``n`` uses seed
                ``starting_seed + n``.
            num_workers: Maximum number of runs executing concurrently (must be >= 1).
            max_steps: Step limit forwarded to each game.
            agent_settings: Dict read by each run for the keys ``model``,
                ``api_base``, ``max_links`` and ``max_tries``.
            db_path: Path of the database opened with ``SQLiteDB``.
            verbose: Print per-run progress messages.
            output_dir: Directory for per-run and final JSON files (created if missing).
            proctor_id: Prefix for run ids and for the final results filename.
            starting_seed: Base seed for the trials.

        Raises:
            ValueError: If ``num_workers`` < 1. A semaphore with zero permits
                would block every run forever.
        """
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")
        self.article_list = article_list
        self.num_trials = num_trials
        self.num_workers = num_workers
        self.max_steps = max_steps
        self.agent_settings = agent_settings
        self.db_path = db_path
        self.verbose = verbose
        self.output_dir = output_dir
        self.proctor_id = proctor_id
        self.db = SQLiteDB(self.db_path)
        self.starting_seed = starting_seed
        os.makedirs(self.output_dir, exist_ok=True)
        self.runs = []
        self.setup_runs()

    def setup_runs(self):
        """Append one ``Run`` per ordered pair of distinct articles per trial to ``self.runs``."""
        for start in self.article_list:
            for destination in self.article_list:
                if start == destination:
                    continue
                for n in range(self.num_trials):
                    run_id = f"{self.proctor_id}_{start}_{destination}_{n}"
                    self.runs.append(
                        Run(
                            start,
                            destination,
                            self.max_steps,
                            self.agent_settings,
                            self.db,
                            self.output_dir,
                            self.verbose,
                            run_id,
                            self.starting_seed + n,
                        )
                    )
                    # Gate on verbose, consistent with the per-run messages in run().
                    if self.verbose:
                        print(f"Setup run {run_id}")

    async def run(self):
        """Execute all runs with at most ``num_workers`` in flight, then aggregate.

        A run that raises is reported and does not abort the others. Its
        missing output is recorded by ``analyze_runs`` under ``missing_runs``.
        """
        semaphore = asyncio.Semaphore(self.num_workers)

        async def run_with_semaphore(run_instance):
            async with semaphore:
                if self.verbose:
                    print(f"Starting run {run_instance.id}")
                await run_instance.run()
                if self.verbose:
                    print(f"Finished run {run_instance.id}")

        tasks = [
            asyncio.create_task(run_with_semaphore(run_instance))
            for run_instance in self.runs
        ]
        # return_exceptions=True: a single failing game must not discard the others' results.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for run_instance, outcome in zip(self.runs, outcomes):
            if isinstance(outcome, BaseException):
                print(f"Run {run_instance.id} failed: {outcome!r}")
        self.analyze_runs()

    def analyze_runs(self):
        """Aggregate all per-run JSON files into a single final-results JSON file.

        The output is written to ``<output_dir>/<proctor_id>-final-results.json``.
        A run counts as a win when its ``result`` field is ``"win"``; anything
        else counts as a loss.

        Runs whose output file is absent or unreadable are listed under
        ``missing_runs`` and excluded from the rates. ``average_hops`` is
        ``None`` when there are no wins. Both rates are 0.0 when no run
        produced output.
        """
        final_results = {
            "article_list": self.article_list,
            "num_trials": self.num_trials,
            "num_workers": self.num_workers,
            "max_steps": self.max_steps,
            "agent_settings": self.agent_settings,
            "runs": [],
        }
        win_count = 0
        lose_count = 0
        hops_distribution = []
        missing_runs = []
        for run in self.runs:
            try:
                with open(run.output_file, "r") as f:
                    result = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError) as err:
                print(f"Warning: no usable result for run {run.id}: {err}")
                missing_runs.append(run.id)
                continue
            final_results["runs"].append(result)
            if result["result"] == "win":
                win_count += 1
                # One fewer hop than recorded steps. This assumes the first step
                # is the starting page -- TODO confirm against Game.run.
                hops_distribution.append(len(result["steps"]) - 1)
            else:
                lose_count += 1
        completed = win_count + lose_count
        final_results["hops_distribution"] = hops_distribution
        final_results["average_hops"] = (
            sum(hops_distribution) / len(hops_distribution) if hops_distribution else None
        )
        final_results["win_rate"] = win_count / completed if completed else 0.0
        final_results["lose_rate"] = lose_count / completed if completed else 0.0
        final_results["missing_runs"] = missing_runs
        with open(f"{self.output_dir}/{self.proctor_id}-final-results.json", "w") as f:
            json.dump(final_results, f, indent=4)
class Run:
    """A single game from ``start_article`` to ``destination_article`` whose outcome is saved as JSON.

    Runs are resumable: if ``output_file`` already exists, ``run`` returns
    immediately without playing.
    """

    def __init__(
        self,
        start_article: str,
        destination_article: str,
        max_steps: int,
        agent_settings: dict,
        db: "SQLiteDB",
        output_dir: str,
        verbose: bool,
        id: str,
        seed: int,
    ):
        """
        Args:
            start_article: Title the game starts from.
            destination_article: Title the agent must reach.
            max_steps: Step limit forwarded to ``Game``.
            agent_settings: Dict read for the keys ``model``, ``api_base``,
                ``max_links`` and ``max_tries``.
            db: Database handle forwarded to ``Game``.
            output_dir: Directory that receives this run's JSON file.
            verbose: Stored only. ``run`` builds the player and game with
                ``verbose=False``.
            id: Run identifier, used in the output filename and in messages.
            seed: Seed forwarded to ``AgentPlayer``.
        """
        self.start_article = start_article
        self.destination_article = destination_article
        self.max_steps = max_steps
        self.agent_settings = agent_settings
        self.db = db
        self.output_dir = output_dir
        self.verbose = verbose
        self.id = id
        self.seed = seed
        # Ids embed article titles, which may contain path separators (e.g. "AC/DC").
        # Replace them so the file lands directly in output_dir rather than a
        # nonexistent subdirectory.
        safe_id = self.id.replace(os.sep, "_").replace("/", "_")
        self.output_file = os.path.join(self.output_dir, f"run_{safe_id}.json")

    async def run(self):
        """Play the game once and write its result to ``output_file``.

        The result is skipped if the file already exists. It is first written
        to a temporary file and then atomically moved into place. An
        interrupted run therefore never leaves a truncated file that a later
        resume would mistake for a finished run.
        """
        if os.path.exists(self.output_file):
            return
        player = AgentPlayer(
            model=self.agent_settings["model"],
            api_base=self.agent_settings["api_base"],
            max_links=self.agent_settings["max_links"],
            max_tries=self.agent_settings["max_tries"],
            verbose=False,
            seed=self.seed,
        )
        game = Game(
            self.start_article,
            self.destination_article,
            self.db,
            self.max_steps,
            player,
            verbose=False,
        )
        steps = await game.run()
        output = {
            "model": self.agent_settings["model"],
            "api_base": self.agent_settings["api_base"],
            "max_links": self.agent_settings["max_links"],
            "max_tries": self.agent_settings["max_tries"],
            "start_article": self.start_article,
            "destination_article": self.destination_article,
            "steps": steps,
            "seed": self.seed,
            # The last step's "type" is the outcome (e.g. "win").
            "result": steps[-1]["type"],
        }
        tmp_file = f"{self.output_file}.tmp"
        with open(tmp_file, "w") as f:
            json.dump(output, f, indent=4)
        os.replace(tmp_file, self.output_file)
        print(f"Run {self.id} completed in {len(steps)} steps")
if __name__ == "__main__":
    # CLI entry point: parse options, validate inputs, then run every article pair.
    parser = argparse.ArgumentParser(description="Run parallel Wikispeedia evaluations")
    parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for agent")
    parser.add_argument("--api-base", type=str, default=None, help="API base URL for hosted models")
    parser.add_argument("--workers", type=int, default=20, help="Number of parallel workers")
    parser.add_argument("--trials", type=int, default=1, help="Number of trials per start-destination pair")
    parser.add_argument("--max-steps", type=int, default=20, help="Maximum steps per game")
    parser.add_argument("--max-links", type=int, default=200, help="Maximum links per page for agent")
    parser.add_argument("--max-tries", type=int, default=3, help="Maximum retries for agent")
    parser.add_argument("--db-path", type=str, default="wikihop.db", help="Path to the wikihop database")
    parser.add_argument("--output-dir", type=str, default="./proctor_tmp", help="Directory for output files")
    parser.add_argument("--proctor-id", type=str, default="proctor_1", help="Unique identifier for this proctor run")
    parser.add_argument("--seed", type=int, default=42, help="Starting random seed")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("--article-list", type=str, default="supernodes.json",
                        help="Path to JSON file with list of articles to test")
    args = parser.parse_args()

    # Fail fast on missing inputs, before any database work or runs start.
    if not os.path.exists(args.db_path):
        raise FileNotFoundError(f"Database file not found at {args.db_path}")
    if not os.path.exists(args.article_list):
        raise FileNotFoundError(f"Article list file not found at {args.article_list}")
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    # JSON text is UTF-8. Being explicit keeps non-ASCII titles decoding
    # identically regardless of the platform's default locale encoding.
    with open(args.article_list, "r", encoding="utf-8") as f:
        article_list = json.load(f)
    # Proctor pairs the entries as start/destination titles, so it needs a list of strings.
    if not isinstance(article_list, list) or not all(isinstance(a, str) for a in article_list):
        parser.error(f"{args.article_list} must contain a JSON list of article titles (strings)")

    agent_settings = {
        "model": args.model,
        "api_base": args.api_base,
        "max_links": args.max_links,
        "max_tries": args.max_tries,
    }
    proctor = Proctor(
        article_list=article_list,
        num_trials=args.trials,
        num_workers=args.workers,
        max_steps=args.max_steps,
        agent_settings=agent_settings,
        db_path=args.db_path,
        verbose=args.verbose,
        output_dir=args.output_dir,
        proctor_id=args.proctor_id,
        starting_seed=args.seed,
    )
    asyncio.run(proctor.run())