Spaces:
Running
Running
Commit
·
1e310b7
1
Parent(s):
8b6f89e
docs minus assets
Browse files- .gitignore +4 -1
- Dockerfile +1 -1
- README.md +12 -1
- index.html +1 -1
- parallel_eval/README.md +59 -0
- parallel_eval/game.py +310 -0
- parallel_eval/proctor.py +233 -0
- parallel_eval/requirements.txt +5 -0
- parallel_eval/supernodes.json +19 -0
- src/components/viewer-tab.tsx +69 -5
.gitignore
CHANGED
@@ -28,4 +28,7 @@ tmp
|
|
28 |
|
29 |
qwen3-final-results.json
|
30 |
|
31 |
-
__pycache__
|
|
|
|
|
|
|
|
28 |
|
29 |
qwen3-final-results.json
|
30 |
|
31 |
+
__pycache__
|
32 |
+
.venv
|
33 |
+
proctor_tmp
|
34 |
+
wikihop.db
|
Dockerfile
CHANGED
@@ -53,7 +53,7 @@ RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
|
|
53 |
curl https://huggingface.co/api/whoami-v2 -H "Authorization: Bearer $(cat /run/secrets/HF_TOKEN)"
|
54 |
|
55 |
RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
|
56 |
-
curl -L https://huggingface.co/HuggingFaceTB/simplewiki-pruned-text-350k/
|
57 |
|
58 |
ENV WIKISPEEDIA_DB_PATH=/home/user/app/wikihop.db
|
59 |
|
|
|
53 |
curl https://huggingface.co/api/whoami-v2 -H "Authorization: Bearer $(cat /run/secrets/HF_TOKEN)"
|
54 |
|
55 |
RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
|
56 |
+
curl -L https://huggingface.co/datasets/HuggingFaceTB/simplewiki-pruned-text-350k/resolve/main/wikihop.db -H "Authorization: Bearer $(cat /run/secrets/HF_TOKEN)" -o wikihop.db
|
57 |
|
58 |
ENV WIKISPEEDIA_DB_PATH=/home/user/app/wikihop.db
|
59 |
|
README.md
CHANGED
@@ -9,4 +9,15 @@ hf_oauth: true
|
|
9 |
hf_oauth_scopes:
|
10 |
- inference-api
|
11 |
- email
|
12 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
hf_oauth_scopes:
|
10 |
- inference-api
|
11 |
- email
|
12 |
+
---
|
13 |
+
|
14 |
+
# Can you wikirace faster than an LLM? 🏁
|
15 |
+
|
16 |
+
Go head-to-head with Qwen, Gemma, and DeepSeek on the [Huggingface Space](https://huggingface.co/spaces/HuggingFaceTB/Wikispeedia)
|
17 |
+
|
18 |
+
<!-- add gifs -->
|
19 |
+

|
20 |
+
|
21 |
+
Or run 100s of agents on any model in parallel for efficient evaluations [see README](parallel_eval)
|
22 |
+
|
23 |
+

|
index.html
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
<meta charset="UTF-8" />
|
5 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
7 |
-
<title>
|
8 |
</head>
|
9 |
<body>
|
10 |
<div id="root"></div>
|
|
|
4 |
<meta charset="UTF-8" />
|
5 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
7 |
+
<title>WikiRacing LLMs</title>
|
8 |
</head>
|
9 |
<body>
|
10 |
<div id="root"></div>
|
parallel_eval/README.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Setup env
|
2 |
+
|
3 |
+
```bash
|
4 |
+
uv venv
|
5 |
+
source .venv/bin/activate
|
6 |
+
uv pip install -r requirements.txt
|
7 |
+
|
8 |
+
# pull wikihop db
|
9 |
+
wget https://huggingface.co/datasets/HuggingFaceTB/simplewiki-pruned-text-350k/resolve/main/wikihop.db -O wikihop.db
|
10 |
+
```
|
11 |
+
|
12 |
+
## Which models does it support?
|
13 |
+
Under the hood it uses [LiteLLM](https://github.com/BerriAI/litellm) so you can use any major model (don't forget to export the appropriate API key), or host any model on huggingface via [vLLM](https://github.com/vllm-project/vllm).
|
14 |
+
|
15 |
+
|
16 |
+
## Play the game
|
17 |
+
```
|
18 |
+
# play the game with cli
|
19 |
+
python game.py --human --start 'Saint Lucia' --end 'Italy' --db wikihop.db
|
20 |
+
|
21 |
+
# have the agent play the game (gpt-4o)
|
22 |
+
export OPENAI_API_KEY=sk_xxxxx
|
23 |
+
python game.py --agent --start 'Saint Lucia' --end 'Italy' --db wikihop.db --model gpt-4o --max-steps 20
|
24 |
+
|
25 |
+
# run an evaluation suite with qwen3 hosted on vLLM, 200 workers
|
26 |
+
python proctor.py --model "hosted_vllm/Qwen/Qwen3-30B-A3B" --api-base "http://localhost:8000/v1" --workers 200
|
27 |
+
|
28 |
+
# this will produce a `proctor_tmp/proctor_1-final-results.json` that can be visualized in the space, as well as the individual reasoning traces for each run. This is resumable if it is stopped and is idempotent.
|
29 |
+
```
|
30 |
+
|
31 |
+
## JQ command to strip out reasoning traces
|
32 |
+
This output file will be very large because it contains all the reasoning traces. You can shrink it down and still be able to visualize it with
|
33 |
+
|
34 |
+
```bash
|
35 |
+
jq '{
|
36 |
+
article_list: .article_list,
|
37 |
+
num_trials: .num_trials,
|
38 |
+
num_workers: .num_workers,
|
39 |
+
max_steps: .max_steps,
|
40 |
+
agent_settings: .agent_settings,
|
41 |
+
runs: [.runs[] | {
|
42 |
+
model: .model,
|
43 |
+
api_base: .api_base,
|
44 |
+
max_links: .max_links,
|
45 |
+
max_tries: .max_tries, result: .result,
|
46 |
+
start_article: .start_article,
|
47 |
+
destination_article: .destination_article,
|
48 |
+
steps: [.steps[] | {
|
49 |
+
type: .type,
|
50 |
+
article: .article,
|
51 |
+
metadata: (if .metadata.conversation then
|
52 |
+
.metadata | del(.conversation)
|
53 |
+
else
|
54 |
+
.metadata
|
55 |
+
end)
|
56 |
+
}]
|
57 |
+
}]
|
58 |
+
}' proctor_tmp/proctor_1-final-results.json > cleaned_data.json
|
59 |
+
```
|
parallel_eval/game.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Dict, Optional
|
2 |
+
import sqlite3
|
3 |
+
import json
|
4 |
+
import litellm
|
5 |
+
import re
|
6 |
+
import asyncio
|
7 |
+
import argparse
|
8 |
+
from functools import lru_cache
|
9 |
+
class SQLiteDB:
    """Read-only accessor for the wikihop SQLite database.

    Article lookups are memoized per instance so that the many concurrent
    games run by the proctor do not repeatedly query the same pages.
    """

    def __init__(self, db_path: str):
        """Open the database at ``db_path`` and count the available articles."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()
        # BUG FIX: decorating the method with @lru_cache at class level keys
        # cache entries on `self`, sharing one global cache across instances
        # and keeping every instance alive for the cache's lifetime. Wrap the
        # bound method here instead so each instance owns its own cache.
        self.get_article_with_links = lru_cache(maxsize=8192)(self._get_article_with_links)
        self._article_count = self._get_article_count()
        print(f"Connected to SQLite database with {self._article_count} articles")

    def _get_article_count(self) -> int:
        """Return the total number of rows in core_articles."""
        self.cursor.execute("SELECT COUNT(*) FROM core_articles")
        return self.cursor.fetchone()[0]

    def _get_article_with_links(self, article_title: str) -> Tuple[Optional[str], List[str]]:
        """Look up an article by exact title.

        Returns (title, links). When the title is not found, returns
        (None, []) rather than raising.
        """
        self.cursor.execute(
            "SELECT title, links_json FROM core_articles WHERE title = ?",
            (article_title,),
        )
        article = self.cursor.fetchone()
        if not article:
            return None, []

        links = json.loads(article["links_json"])
        return article["title"], links
|
35 |
+
|
36 |
+
|
37 |
+
class Player:
    """Interactive human player: link choices are read from stdin."""

    def __init__(self, name: str):
        self.name = name

    async def get_move(self, game_state: List[Dict]) -> Tuple[str, Dict]:
        """Print the current article's links and ask the user to pick one.

        Returns (chosen_link, metadata) where the chosen link is selected
        by 0-based index.
        """
        print("Link choices:")
        for i, link in enumerate(game_state[-1]["links"]):
            print(f"{i}: {link}")

        idx = int(input("Enter the index of the link you want to select: "))
        # BUG FIX: the message previously interpolated the loop variable `i`
        # (always the last index printed) instead of the selected index `idx`.
        return game_state[-1]["links"][idx], {
            "message": f"{self.name} selected link #{idx}"
        }


class AgentPlayer(Player):
    """LLM-backed player that chooses links via litellm chat completions.

    Parameters:
        model: litellm model identifier (e.g. "gpt-4o").
        api_base: base URL of the completion endpoint (None for provider default).
        verbose: kept for interface compatibility.
        max_links: when set, only the first `max_links` links of the current
            article are offered to and selectable by the model. BUG FIX: this
            parameter was previously accepted (and documented in the CLI help)
            but never applied.
        max_tries: number of re-prompts allowed after a malformed answer.
        target_article: destination page; normally injected by Game.
        seed: sampling seed forwarded to the completion API.
    """

    def __init__(
        self,
        model: str,
        api_base: str,
        verbose: bool = True,
        max_links=None,
        max_tries=10,
        target_article=None,
        seed=None,
    ):
        super().__init__(model)
        self.model = model
        self.api_base = api_base
        self.verbose = verbose
        self.max_links = max_links
        self.max_tries = max_tries
        self.target_article = target_article
        self.seed = seed

    def _candidate_links(self, game_state: List[Dict]) -> List[str]:
        """Links offered to the model, truncated to max_links when set."""
        links = game_state[-1]["links"]
        return links[: self.max_links] if self.max_links else links

    async def get_move(self, game_state: List[Dict]) -> Tuple[str, Dict]:
        """Ask the model for a link choice, retrying on malformed answers.

        Returns (chosen_link, metadata). After max_tries failed attempts
        returns (-1, metadata) so the Game can record a loss.
        """
        links = self._candidate_links(game_state)
        prompt = self.construct_prompt(game_state)

        conversation = [
            {"role": "user", "content": prompt}
        ]

        for try_number in range(self.max_tries):
            response = await litellm.acompletion(
                model=self.model,
                api_base=self.api_base,
                messages=conversation,
                seed=self.seed
            )
            response = response.choices[0].message.content

            conversation.append({"role": "assistant", "content": response})

            answer, message = self._attempt_to_extract_answer(response, maximum_answer=len(links))

            # there was a problem with the answer so give the model another chance
            if answer == -1:
                conversation.append({"role": "user", "content": message})
                continue

            assert 1 <= answer <= len(links), f"Answer {answer} is out of range"

            # we found an answer so we can return it (answers are 1-based)
            return links[answer - 1], {"tries": try_number, "conversation": conversation}

        # we tried the max number of times and still didn't find an answer
        return -1, {"tries": self.max_tries, "conversation": conversation}

    def construct_prompt(self, game_state: List[Dict]) -> str:
        """Build the navigation prompt for the current game state."""
        current = game_state[-1]["article"]
        target = self.target_article
        available_links = self._candidate_links(game_state)
        formatted_links = "\n".join(f"{i+1}. {link}" for i, link in enumerate(available_links))
        path_so_far = [step["article"] for step in game_state]

        try:
            formatted_path = ' -> '.join(path_so_far)
        except Exception as e:
            print(f"Error formatting path: {e}")
            print(game_state)
            print("Path so far: ", path_so_far)
            raise e

        return f"""You are playing WikiRun, trying to navigate from one Wikipedia article to another using only links.

IMPORTANT: You MUST put your final answer in <answer>NUMBER</answer> tags, where NUMBER is the link number.
For example, if you want to choose link 3, output <answer>3</answer>.

Current article: {current}
Target article: {target}
Available links (numbered):
{formatted_links}

Your path so far: {formatted_path}

Think about which link is most likely to lead you toward the target article.
First, analyze each link briefly and how it connects to your goal, then select the most promising one.

Remember to format your final answer by explicitly writing out the xml number tags like this: <answer>NUMBER</answer>
"""

    def _attempt_to_extract_answer(self, response: str, maximum_answer: Optional[int] = None) -> Tuple[int, str]:
        """Parse a 1-based link number out of <answer>N</answer> tags.

        Returns (answer, "") on success, or (-1, feedback_message) when the
        response is missing, ambiguous, non-numeric, or out of range.
        """
        # Extract choice using format <answer>N</answer>
        choice_match = re.search(r"<answer>(\d+)</answer>", response)

        if choice_match is None:
            return -1, f"No answer found in response. Please respond with a number between 1 and {maximum_answer} in <answer>NUMBER</answer> tags."

        # check if there are multiple answers
        multiple_answers = re.findall(r"<answer>(\d+)</answer>", response)
        if len(multiple_answers) > 1:
            return -1, "Multiple answers found in response. Please respond with just one."

        answer = choice_match.group(1)

        # try to convert to int
        try:
            answer = int(answer)
        except ValueError:
            return -1, f"You answered with {answer} but it could not be converted to an integer. Please respond with a number between 1 and {maximum_answer}."

        # check if the answer is too high or too low
        if answer > maximum_answer or answer < 1:
            return -1, f"You answered with {answer} but you have to select a number between 1 and {maximum_answer}."

        return answer, ""  # we found an answer so we don't need to return a message
|
165 |
+
|
166 |
+
class Game:
    """Drives one WikiRun: repeatedly asks `player` for a move until the
    target article is reached, the player fails, or the step budget is
    spent. `run()` returns the full list of recorded step dicts."""

    def __init__(
        self,
        start_article: str,
        target_article: str,
        db: SQLiteDB,
        max_allowed_steps: int,
        player: Player,
        verbose: bool = True,
    ):
        self.start_article = start_article
        self.target_article = target_article
        self.db = db
        self.max_allowed_steps = max_allowed_steps
        self.player = player
        self.verbose = verbose
        self.steps = []
        self.steps_taken = 0
        # Ensure the player knows the target article
        if isinstance(self.player, AgentPlayer):
            self.player.target_article = self.target_article

    async def run(self):
        """Play the game to completion and return the recorded steps."""
        if self.verbose:
            print(f"Starting game from {self.start_article} to {self.target_article}")

        # Record the starting position and its outgoing links.
        _, start_links = self.db.get_article_with_links(self.start_article)
        self.steps.append(
            {
                "type": "start",
                "article": self.start_article,
                "links": start_links,
                "metadata": {"message": "Game started"},
            }
        )

        # Keep moving until we win, lose, or exhaust the step budget.
        while self.steps_taken < self.max_allowed_steps:
            self.steps_taken += 1

            # Await the async player move
            chosen, move_metadata = await self.player.get_move(self.steps)

            # The player could not produce a valid link choice.
            if chosen == -1:
                self.steps.append(
                    {"type": "lose", "article": chosen, "metadata": move_metadata}
                )
                break

            if self.verbose:
                print(f" -> Step {self.steps_taken}: {chosen}")

            # Reaching the target ends the game immediately.
            if chosen == self.target_article:
                self.steps.append(
                    {"type": "win", "article": chosen, "metadata": move_metadata}
                )
                break

            # Otherwise fetch the next article's links.
            _, next_links = self.db.get_article_with_links(chosen)

            # A dead-end page (no outgoing links) is a loss.
            if not next_links:
                self.steps.append(
                    {"type": "lose", "article": chosen, "metadata": move_metadata}
                )
                break

            self.steps.append(
                {
                    "type": "move",
                    "article": chosen,
                    "links": next_links,
                    "metadata": move_metadata,
                }
            )

        return self.steps
|
249 |
+
|
250 |
+
|
251 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Play the WikiRun game")

    # Exactly one of --human / --agent must be supplied.
    player_group = parser.add_mutually_exclusive_group(required=True)
    player_group.add_argument("--human", action="store_true", help="Play as a human")
    player_group.add_argument("--agent", action="store_true", help="Use an AI agent to play")

    # Game parameters
    parser.add_argument("--start", type=str, default="British Library", help="Starting article title")
    parser.add_argument("--end", type=str, default="Saint Lucia", help="Target article title")
    parser.add_argument("--db", type=str, required=True, help="Path to SQLite database")
    parser.add_argument("--max-steps", type=int, default=10, help="Maximum number of steps allowed (default: 10)")

    # Agent parameters (only used with --agent)
    parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for the agent (default: gpt-4o)")
    parser.add_argument("--api-base", type=str, default="https://api.openai.com/v1",
                        help="API base URL (default: https://api.openai.com/v1)")
    parser.add_argument("--max-links", type=int, default=200, help="Maximum number of links to consider (default: 200)")
    parser.add_argument("--max-tries", type=int, default=3, help="Maximum number of tries for the agent (default: 3)")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")

    args = parser.parse_args()

    # Open the article database.
    database = SQLiteDB(args.db)

    # Build the requested player implementation.
    if args.human:
        chosen_player = Player("Human")
    else:  # args.agent is True
        chosen_player = AgentPlayer(
            model=args.model,
            api_base=args.api_base,
            verbose=True,
            max_links=args.max_links,
            max_tries=args.max_tries,
            target_article=args.end,
            seed=args.seed,
        )

    # Assemble and play a single game.
    wikirun = Game(
        start_article=args.start,
        target_article=args.end,
        db=database,
        max_allowed_steps=args.max_steps,
        player=chosen_player,
        verbose=True,
    )

    game_steps = asyncio.run(wikirun.run())

    # Dump a human-readable transcript of the game.
    print(f"Game over in {len(game_steps)} steps")
    for step_number, step in enumerate(game_steps):
        print(f"Step {step_number}: {step['type']}")
        print(f" Article: {step['article']}")
        print(f" Links: {step.get('links', [])}")
        print(f" Metadata: {step.get('metadata', {})}")
        print("\n\n")
|
parallel_eval/proctor.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from game import AgentPlayer, SQLiteDB, Game
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import asyncio
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
|
8 |
+
class Proctor:
    """Schedules and evaluates many WikiRun games in parallel.

    Builds one Run per ordered (start, destination) article pair per trial,
    executes them with a bounded number of concurrent workers, then merges
    the per-run JSON files into a single final-results file.
    """

    def __init__(
        self,
        article_list: list[tuple[str, str]],
        num_trials: int,
        num_workers: int,
        max_steps: int,
        agent_settings: dict,
        db_path: str,
        verbose: bool = True,
        output_dir: str = "./proctor_tmp",
        proctor_id: str = "proctor_1",
        starting_seed: int = 42,
    ):
        self.article_list = article_list
        self.num_trials = num_trials
        self.num_workers = num_workers
        self.max_steps = max_steps
        self.agent_settings = agent_settings
        self.db_path = db_path
        self.verbose = verbose
        self.output_dir = output_dir
        self.proctor_id = proctor_id
        self.db = SQLiteDB(self.db_path)
        self.starting_seed = starting_seed

        os.makedirs(self.output_dir, exist_ok=True)

        self.runs = []

        self.setup_runs()

    def setup_runs(self):
        """Create one Run per ordered pair of distinct articles per trial."""
        for start in self.article_list:
            for destination in self.article_list:
                if start == destination:
                    continue
                for n in range(self.num_trials):
                    run_id = f"{self.proctor_id}_{start}_{destination}_{n}"
                    self.runs.append(
                        Run(
                            start,
                            destination,
                            self.max_steps,
                            self.agent_settings,
                            self.db,
                            self.output_dir,
                            self.verbose,
                            run_id,
                            # vary the seed per trial so repeat trials differ
                            self.starting_seed + n,
                        )
                    )
                    print(f"Setup run {run_id}")

    async def run(self):
        """Execute all runs with at most num_workers in flight, then analyze."""
        semaphore = asyncio.Semaphore(self.num_workers)
        tasks = []

        async def run_with_semaphore(run_instance):
            # the semaphore bounds concurrency at num_workers
            async with semaphore:
                if self.verbose:
                    print(f"Starting run {run_instance.id}")
                await run_instance.run()
                if self.verbose:
                    print(f"Finished run {run_instance.id}")

        for run_instance in self.runs:
            tasks.append(asyncio.create_task(run_with_semaphore(run_instance)))

        await asyncio.gather(*tasks)

        self.analyze_runs()

    def analyze_runs(self):
        """Aggregate all per-run JSON files into a single final-results .json."""
        final_results = {
            "article_list": self.article_list,
            "num_trials": self.num_trials,
            "num_workers": self.num_workers,
            "max_steps": self.max_steps,
            "agent_settings": self.agent_settings,
            "runs": [],
        }

        win_count = 0
        lose_count = 0
        hops_distribution = []

        for run in self.runs:
            with open(run.output_file, "r") as f:
                result = json.load(f)
                final_results["runs"].append(result)
                if result["result"] == "win":
                    win_count += 1
                    # hops = number of moves, i.e. steps minus the start record
                    hops_distribution.append(len(result["steps"]) - 1)
                else:
                    lose_count += 1

        total_runs = len(self.runs)
        final_results["hops_distribution"] = hops_distribution
        # BUG FIX: guard the divisions -- a session with zero wins (or an
        # empty run list) previously raised ZeroDivisionError here and lost
        # the whole evaluation's aggregate output.
        final_results["average_hops"] = (
            sum(hops_distribution) / len(hops_distribution) if hops_distribution else 0
        )
        final_results["win_rate"] = win_count / total_runs if total_runs else 0
        final_results["lose_rate"] = lose_count / total_runs if total_runs else 0

        with open(f"{self.output_dir}/{self.proctor_id}-final-results.json", "w") as f:
            json.dump(final_results, f, indent=4)
|
113 |
+
|
114 |
+
|
115 |
+
class Run:
    """One start→destination evaluation game whose result is persisted
    to a JSON file. A run is idempotent: if its output file already
    exists it is skipped, making a proctor session resumable."""

    def __init__(
        self,
        start_article: str,
        destination_article: str,
        max_steps: int,
        agent_settings: dict,
        db: SQLiteDB,
        output_dir: str,
        verbose: bool,
        id: str,
        seed: int,
    ):
        self.start_article = start_article
        self.destination_article = destination_article
        self.max_steps = max_steps
        self.agent_settings = agent_settings
        self.db = db
        self.output_dir = output_dir
        self.verbose = verbose
        self.id = id
        self.seed = seed
        # Existence of this file marks the run as completed.
        self.output_file = f"{self.output_dir}/run_{self.id}.json"

    async def run(self):
        """Play one game and dump its result JSON; no-op if already done."""
        if os.path.exists(self.output_file):
            return

        agent = AgentPlayer(
            model=self.agent_settings["model"],
            api_base=self.agent_settings["api_base"],
            max_links=self.agent_settings["max_links"],
            max_tries=self.agent_settings["max_tries"],
            verbose=False,
            seed=self.seed,
        )

        match = Game(
            self.start_article,
            self.destination_article,
            self.db,
            self.max_steps,
            agent,
            verbose=False,
        )

        steps = await match.run()

        # Persist the settings alongside the transcript so each file is
        # self-describing; the final step's type is the game outcome.
        record = {
            "model": self.agent_settings["model"],
            "api_base": self.agent_settings["api_base"],
            "max_links": self.agent_settings["max_links"],
            "max_tries": self.agent_settings["max_tries"],
            "start_article": self.start_article,
            "destination_article": self.destination_article,
            "steps": steps,
            "seed": self.seed,
            "result": steps[-1]["type"],
        }

        with open(self.output_file, "w") as f:
            json.dump(record, f, indent=4)

        print(f"Run {self.id} completed in {len(steps)} steps")
|
180 |
+
|
181 |
+
|
182 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run parallel Wikispeedia evaluations")
    parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for agent")
    parser.add_argument("--api-base", type=str, default=None, help="API base URL for hosted models")
    parser.add_argument("--workers", type=int, default=20, help="Number of parallel workers")
    parser.add_argument("--trials", type=int, default=1, help="Number of trials per start-destination pair")
    parser.add_argument("--max-steps", type=int, default=20, help="Maximum steps per game")
    parser.add_argument("--max-links", type=int, default=200, help="Maximum links per page for agent")
    parser.add_argument("--max-tries", type=int, default=3, help="Maximum retries for agent")
    parser.add_argument("--db-path", type=str, default="wikihop.db", help="Path to the wikihop database")
    parser.add_argument("--output-dir", type=str, default="./proctor_tmp", help="Directory for output files")
    parser.add_argument("--proctor-id", type=str, default="proctor_1", help="Unique identifier for this proctor run")
    parser.add_argument("--seed", type=int, default=42, help="Starting random seed")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("--article-list", type=str, default="supernodes.json",
                        help="Path to JSON file with list of articles to test")

    args = parser.parse_args()

    # Fail fast on missing inputs before spinning anything up.
    if not os.path.exists(args.db_path):
        raise FileNotFoundError(f"Database file not found at {args.db_path}")

    if not os.path.exists(args.article_list):
        raise FileNotFoundError(f"Article list file not found at {args.article_list}")

    # Load the set of articles used as both starts and destinations.
    with open(args.article_list, "r") as f:
        article_list = json.load(f)

    # Settings forwarded verbatim to every AgentPlayer.
    agent_settings = {
        "model": args.model,
        "api_base": args.api_base,
        "max_links": args.max_links,
        "max_tries": args.max_tries,
    }

    evaluation = Proctor(
        article_list=article_list,
        num_trials=args.trials,
        num_workers=args.workers,
        max_steps=args.max_steps,
        agent_settings=agent_settings,
        db_path=args.db_path,
        verbose=args.verbose,
        output_dir=args.output_dir,
        proctor_id=args.proctor_id,
        starting_seed=args.seed,
    )

    asyncio.run(evaluation.run())
|
parallel_eval/requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
litellm>=1.10.0
|
2 |
+
# NOTE: asyncio is part of the Python standard library; installing the PyPI
# "asyncio" backport can shadow the stdlib module on Python 3 -- do not add it.
|
3 |
+
tqdm
|
4 |
+
# NOTE: sqlite3 ships with the Python standard library; no extra package needed.
|
5 |
+
aiohttp
|
parallel_eval/supernodes.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
"Soviet Union",
|
3 |
+
"Frank Lloyd Wright",
|
4 |
+
"Major League Baseball",
|
5 |
+
"R (programming language)",
|
6 |
+
"Hinduism",
|
7 |
+
"Singapore General Hospital",
|
8 |
+
"Nepenthes",
|
9 |
+
"Google AI",
|
10 |
+
"Freedom, Pennsylvania",
|
11 |
+
"Iron Man 3",
|
12 |
+
"Central Bank of Nigeria",
|
13 |
+
"Pok\u00e9mon",
|
14 |
+
"Nintendo",
|
15 |
+
"Bachelor of Arts",
|
16 |
+
"Polynesian languages",
|
17 |
+
"France",
|
18 |
+
"Jennifer Aniston"
|
19 |
+
]
|
src/components/viewer-tab.tsx
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
import q3Results from "../../results/qwen3.json"
|
4 |
import q3_30B_A3B_Results from "../../results/qwen3-30B-A3-results.json"
|
5 |
// import mockResults from "../../qwen3-final-results.json"
|
6 |
-
import { useMemo, useState, useEffect } from "react";
|
7 |
import { Card } from "@/components/ui/card";
|
8 |
import ForceDirectedGraph from "@/components/force-directed-graph";
|
9 |
import RunsList from "@/components/runs-list";
|
@@ -16,8 +16,10 @@ import {
|
|
16 |
} from "@/components/ui/select";
|
17 |
import { Run as ForceGraphRun } from "@/components/reasoning-trace";
|
18 |
import { Badge } from "@/components/ui/badge";
|
|
|
|
|
19 |
|
20 |
-
const
|
21 |
"Qwen3-14B": q3Results,
|
22 |
"Qwen3-30B-A3B": q3_30B_A3B_Results,
|
23 |
}
|
@@ -51,10 +53,12 @@ export default function ViewerTab({
|
|
51 |
const [runs, setRuns] = useState<Run[]>([]);
|
52 |
const [selectedModel, setSelectedModel] = useState<string>("Qwen3-14B");
|
53 |
const [modelStats, setModelStats] = useState<ModelStats | null>(null);
|
|
|
|
|
54 |
|
55 |
useEffect(() => {
|
56 |
// Convert the model data to the format expected by RunsList
|
57 |
-
const convertedRuns = models[selectedModel]
|
58 |
start_article: string;
|
59 |
destination_article: string;
|
60 |
steps: { type: string; article: string }[];
|
@@ -64,7 +68,7 @@ export default function ViewerTab({
|
|
64 |
destination_article: run.destination_article,
|
65 |
steps: run.steps.map((step: { article: string }) => step.article),
|
66 |
result: run.result
|
67 |
-
}));
|
68 |
setRuns(convertedRuns);
|
69 |
|
70 |
// Calculate model statistics
|
@@ -105,7 +109,7 @@ export default function ViewerTab({
|
|
105 |
minSteps,
|
106 |
maxSteps
|
107 |
});
|
108 |
-
}, [selectedModel]);
|
109 |
|
110 |
const handleRunSelect = (runId: number) => {
|
111 |
setSelectedRun(runId);
|
@@ -124,6 +128,49 @@ export default function ViewerTab({
|
|
124 |
}));
|
125 |
}, [filterRuns]);
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
return (
|
128 |
<div className="grid grid-cols-1 md:grid-cols-12 gap-4 h-[calc(100vh-200px)] max-h-[calc(100vh-200px)] overflow-hidden p-2">
|
129 |
<Card className="p-3 col-span-12 row-start-1">
|
@@ -143,6 +190,23 @@ export default function ViewerTab({
|
|
143 |
</Select>
|
144 |
</div>
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
{modelStats && (
|
147 |
<div className="flex flex-wrap gap-1.5 items-center">
|
148 |
<Badge variant="outline" className="px-2 py-0.5 flex gap-1 items-center">
|
|
|
3 |
import q3Results from "../../results/qwen3.json"
|
4 |
import q3_30B_A3B_Results from "../../results/qwen3-30B-A3-results.json"
|
5 |
// import mockResults from "../../qwen3-final-results.json"
|
6 |
+
import { useMemo, useState, useEffect, useRef } from "react";
|
7 |
import { Card } from "@/components/ui/card";
|
8 |
import ForceDirectedGraph from "@/components/force-directed-graph";
|
9 |
import RunsList from "@/components/runs-list";
|
|
|
16 |
} from "@/components/ui/select";
|
17 |
import { Run as ForceGraphRun } from "@/components/reasoning-trace";
|
18 |
import { Badge } from "@/components/ui/badge";
|
19 |
+
import { Button } from "@/components/ui/button";
|
20 |
+
import { UploadIcon } from "lucide-react";
|
21 |
|
22 |
+
const defaultModels = {
|
23 |
"Qwen3-14B": q3Results,
|
24 |
"Qwen3-30B-A3B": q3_30B_A3B_Results,
|
25 |
}
|
|
|
53 |
const [runs, setRuns] = useState<Run[]>([]);
|
54 |
const [selectedModel, setSelectedModel] = useState<string>("Qwen3-14B");
|
55 |
const [modelStats, setModelStats] = useState<ModelStats | null>(null);
|
56 |
+
const [models, setModels] = useState(defaultModels);
|
57 |
+
const fileInputRef = useRef<HTMLInputElement>(null);
|
58 |
|
59 |
useEffect(() => {
|
60 |
// Convert the model data to the format expected by RunsList
|
61 |
+
const convertedRuns = models[selectedModel]?.runs?.map((run: {
|
62 |
start_article: string;
|
63 |
destination_article: string;
|
64 |
steps: { type: string; article: string }[];
|
|
|
68 |
destination_article: run.destination_article,
|
69 |
steps: run.steps.map((step: { article: string }) => step.article),
|
70 |
result: run.result
|
71 |
+
})) || [];
|
72 |
setRuns(convertedRuns);
|
73 |
|
74 |
// Calculate model statistics
|
|
|
109 |
minSteps,
|
110 |
maxSteps
|
111 |
});
|
112 |
+
}, [selectedModel, models]);
|
113 |
|
114 |
const handleRunSelect = (runId: number) => {
|
115 |
setSelectedRun(runId);
|
|
|
128 |
}));
|
129 |
}, [filterRuns]);
|
130 |
|
131 |
+
const handleFileUpload = (event: React.ChangeEvent<HTMLInputElement>) => {
|
132 |
+
const file = event.target.files?.[0];
|
133 |
+
if (!file) return;
|
134 |
+
|
135 |
+
const reader = new FileReader();
|
136 |
+
reader.onload = (e) => {
|
137 |
+
try {
|
138 |
+
const jsonData = JSON.parse(e.target?.result as string);
|
139 |
+
|
140 |
+
// Validate the JSON structure has the required fields
|
141 |
+
if (!jsonData.runs || !Array.isArray(jsonData.runs)) {
|
142 |
+
alert("Invalid JSON format. File must contain a 'runs' array.");
|
143 |
+
return;
|
144 |
+
}
|
145 |
+
|
146 |
+
// Create a filename-based model name, removing extension and path
|
147 |
+
const fileName = file.name.replace(/\.[^/.]+$/, "");
|
148 |
+
const modelName = `Custom: ${fileName}`;
|
149 |
+
|
150 |
+
// Add the new model to the models object
|
151 |
+
setModels(prev => ({
|
152 |
+
...prev,
|
153 |
+
[modelName]: jsonData
|
154 |
+
}));
|
155 |
+
|
156 |
+
// Select the newly added model
|
157 |
+
setSelectedModel(modelName);
|
158 |
+
} catch (error) {
|
159 |
+
alert(`Error parsing JSON file: ${error.message}`);
|
160 |
+
}
|
161 |
+
};
|
162 |
+
reader.readAsText(file);
|
163 |
+
|
164 |
+
// Reset the file input
|
165 |
+
if (fileInputRef.current) {
|
166 |
+
fileInputRef.current.value = '';
|
167 |
+
}
|
168 |
+
};
|
169 |
+
|
170 |
+
const handleUploadClick = () => {
|
171 |
+
fileInputRef.current?.click();
|
172 |
+
};
|
173 |
+
|
174 |
return (
|
175 |
<div className="grid grid-cols-1 md:grid-cols-12 gap-4 h-[calc(100vh-200px)] max-h-[calc(100vh-200px)] overflow-hidden p-2">
|
176 |
<Card className="p-3 col-span-12 row-start-1">
|
|
|
190 |
</Select>
|
191 |
</div>
|
192 |
|
193 |
+
<Button
|
194 |
+
variant="outline"
|
195 |
+
size="sm"
|
196 |
+
className="flex items-center gap-1"
|
197 |
+
onClick={handleUploadClick}
|
198 |
+
>
|
199 |
+
<UploadIcon size={14} />
|
200 |
+
<span>Upload JSON</span>
|
201 |
+
<input
|
202 |
+
type="file"
|
203 |
+
ref={fileInputRef}
|
204 |
+
accept=".json"
|
205 |
+
className="hidden"
|
206 |
+
onChange={handleFileUpload}
|
207 |
+
/>
|
208 |
+
</Button>
|
209 |
+
|
210 |
{modelStats && (
|
211 |
<div className="flex flex-wrap gap-1.5 items-center">
|
212 |
<Badge variant="outline" className="px-2 py-0.5 flex gap-1 items-center">
|