Spaces:
Sleeping
Sleeping
Add server
Browse files
server.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import asyncio
|
7 |
+
import json
|
8 |
+
import traceback
|
9 |
+
import bittensor as bt
|
10 |
+
|
11 |
+
from collections import Counter
|
12 |
+
|
13 |
+
from neurons.validator import Validator
|
14 |
+
from prompting.dendrite import DendriteResponseEvent
|
15 |
+
from prompting.protocol import PromptingSynapse
|
16 |
+
from prompting.utils.uids import get_random_uids
|
17 |
+
from prompting.rewards import DateRewardModel, FloatDiffModel
|
18 |
+
from aiohttp import web
|
19 |
+
from aiohttp.web_response import Response
|
20 |
+
|
21 |
+
"""
|
22 |
+
# test
|
23 |
+
```
|
24 |
+
curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hello" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["hello world"]}'
|
25 |
+
|
26 |
+
curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hey-michal" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["on what exact date did the 21st century begin?"]}'
|
27 |
+
```
|
28 |
+
|
29 |
+
TROUBLESHOOT
|
30 |
+
Make sure the port is open (allow it through the firewall):
|
31 |
+
```
|
32 |
+
sudo ufw allow 10000/tcp
|
33 |
+
sudo ufw allow 10000/tcp
|
34 |
+
```
|
35 |
+
# run
|
36 |
+
```
|
37 |
+
EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name app -- --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
|
38 |
+
```
|
39 |
+
"""
|
40 |
+
|
41 |
+
# Shared secret that callers must send in the `api_key` header. When the
# environment variable is unset this is None and chat() skips the key check.
EXPECTED_ACCESS_KEY = os.getenv('EXPECTED_ACCESS_KEY')

# Placeholder; rebound to a Validator instance when the module runs as a script.
validator = None

# Task name -> reward model whose parsing helpers ensemble_result() relies on.
reward_models = {
    'date_qa': DateRewardModel(),
    'math': FloatDiffModel(),
}
|
48 |
+
|
49 |
+
def completion_is_valid(completion: str):
    """
    Decide whether a single completion is usable.

    A completion is rejected when it contains no word characters at all, or
    when it contains one of the (case-sensitive) refusal / uncertainty phrases
    listed in the pattern below.

    Returns:
        True if the completion is usable, False otherwise.
    """
    refusal_markers = re.compile(r'I\'m sorry|unable to|I cannot|I can\'t|I am unable|I am sorry|I can not|don\'t know|not sure|don\'t understand')
    has_words = re.search(r'\w+', completion) is not None
    return has_words and refusal_markers.search(completion) is None
|
57 |
+
|
58 |
+
|
59 |
+
def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
    """
    Ensemble completions from multiple models.

    For 'qa' and 'summarization' every completion is treated as supporting.
    For 'date_qa' and 'math' a value is extracted from each completion with the
    matching reward model; the most common value becomes the accepted answer
    and the completions that produced it become the supporting set. When no
    value occurs more than once, every completion that yielded a value counts
    as support.

    Args:
        completions: Candidate completion strings.
        task_name: One of 'qa', 'summarization', 'date_qa' or 'math'.
        prefer: How to pick the returned completion from the supporting set:
            'longest', 'shortest' or 'most_common'.

    Returns:
        None when `completions` is empty, or when a 'date_qa' / 'math' task
        yields no extractable value. Otherwise a dict with the keys
        'completion', 'accepted_answer', 'support', 'support_indices'
        (positions within `completions`) and 'method'.

    Raises:
        ValueError: If `task_name` or `prefer` is not recognised.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    # TODO: use the median instead of the most common value (math)
    """
    if not completions:
        return None

    answer = None
    if task_name in ('qa', 'summarization'):
        # No special handling for QA or summarization
        support_indices = list(range(len(completions)))

    elif task_name == 'date_qa':
        # Keep only completions that contain a parseable date, remembering the
        # original position of each so support indices stay correct.
        parsed = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {parsed}")
        # NOTE(review): assumes each parsed item is (datetime-like, year) — confirm
        # against DateRewardModel. '%-d' is a platform-specific strftime directive.
        labelled = [
            (i, f"{d[0].strftime('%-d %B')} {d[1]}")
            for i, d in enumerate(parsed) if d
        ]
        if not labelled:
            return None
        answer, _, support_indices = _majority_support(labelled)

    elif task_name == 'math':
        # Keep only completions that contain a number, paired with their
        # original position (the filtered values must not be indexed with
        # positions from the unfiltered list).
        extracted = list(map(reward_models[task_name].extract_number, completions))
        # NOTE(review): the truthiness filter also drops a legitimate value of 0;
        # switch to `is not None` once extract_number's "no match" value is confirmed.
        labelled = [(i, val) for i, val in enumerate(extracted) if val]
        if not labelled:
            return None
        answer, count, support_indices = _majority_support(labelled)
        bt.logging.info(f"Most common value: {answer}, count: {count}")

    else:
        raise ValueError(f"Unknown task name: {task_name}")

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == 'longest':
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == 'shortest':
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == 'most_common':
        # Counter is linear and breaks ties by first occurrence (deterministic).
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        'completion': preferred_completion,
        'accepted_answer': answer,
        'support': len(supporting_completions),
        'support_indices': support_indices,
        'method': f'Selected the {prefer.replace("_", " ")} completion'
    }


def _majority_support(labelled: list):
    """Return (answer, count, indices) for the most common label in (index, label) pairs."""
    answer, count = Counter(label for _, label in labelled).most_common(1)[0]
    if count == 1:
        # No agreement at all: every completion that produced a value is support.
        indices = [i for i, _ in labelled]
    else:
        indices = [i for i, label in labelled if label == answer]
    return answer, count, indices
|
128 |
+
|
129 |
+
def guess_task_name(challenge: str):
    """
    Guess which task a challenge belongs to from keywords in its text.

    The patterns are tried in the order listed (summarization, date_qa, math)
    and matching is case-sensitive; the first pattern that matches anywhere in
    `challenge` wins. When none match, the task defaults to 'qa'.
    """
    ordered_patterns = (
        ('summarization', 'summar|quick rundown|overview'),
        ('date_qa', 'exact date|tell me when|on what date|on what day|was born?|died?'),
        ('math', 'math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
    )
    return next(
        (name for name, expr in ordered_patterns if re.search(expr, challenge)),
        'qa',
    )
|
140 |
+
|
141 |
+
async def chat(request: web.Request) -> Response:
    """
    Chat endpoint for the validator.

    Required headers:
    - api_key: The access key for the validator.

    Required body:
    - roles: The list of roles to query.
    - messages: The list of messages to query.
    Optional body:
    - k: The number of nodes to query.
    - exclude: The list of nodes to exclude from the query.
    - timeout: The timeout for the query.
    - prefer: Ensemble preference ('longest', 'shortest' or 'most_common').

    Returns:
        401 when EXPECTED_ACCESS_KEY is set and the `api_key` header differs,
        400 when the body is not a JSON object or lacks `roles` / `messages`,
        500 when querying or post-processing raises, and otherwise 200 with a
        JSON-encoded text body containing the response event state plus
        'completion_is_valid', 'task_name' and 'ensemble_result'.
    """

    bt.logging.info('chat()')
    # Check access key
    access_key = request.headers.get("api_key")
    if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
        bt.logging.error(f'Invalid access key: {access_key}')
        return Response(status=401, reason="Invalid access key")

    try:
        request_data = await request.json()
    except ValueError as err:
        # `request_data` is unbound when parsing fails, so log the parse error
        # instead of the (non-existent) payload.
        bt.logging.error(f'Invalid request data: {err}')
        return Response(status=400)

    # A valid JSON document can still be a list/string/number, which has no
    # `.get`; reject it here rather than crashing below.
    if not isinstance(request_data, dict):
        bt.logging.error(f'Invalid request data: {request_data}')
        return Response(status=400)

    bt.logging.info(f'Request data: {request_data}')
    if not request_data.get('messages') or not request_data.get('roles'):
        # messages[-1] is needed to guess the task, so both must be non-empty.
        bt.logging.error('Invalid request data: "roles" and "messages" are required')
        return Response(status=400, reason="Missing roles or messages")

    k = request_data.get('k', 10)
    exclude = request_data.get('exclude', [])
    timeout = request_data.get('timeout', 10)
    prefer = request_data.get('prefer', 'longest')
    try:
        # Guess the task name of current request
        task_name = guess_task_name(request_data['messages'][-1])

        # Get the list of uids to query for this step.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).to(validator.device)
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt.
        bt.logging.info('Calling dendrite')
        responses = await validator.dendrite(
            axons=axons,
            synapse=PromptingSynapse(roles=request_data['roles'], messages=request_data['messages']),
            timeout=timeout,
        )

        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        # Encapsulate the responses in a response event (dataclass)
        response_event = DendriteResponseEvent(responses, uids)

        # convert dict to json
        response = response_event.__state_dict__()

        # Flag each completion, then ensemble only the ones judged valid.
        response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]

        response['task_name'] = task_name
        response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {response}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))

    except Exception:
        # Top-level boundary: log the traceback and hide details from the client.
        bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
class ValidatorApplication(web.Application):
    """aiohttp application for the validator API; currently adds nothing to the base class."""

    def __init__(self, *args, **kwargs):
        # Forward everything to web.Application unchanged.
        super().__init__(*args, **kwargs)
        # TODO: Enable rewarding and other features
|
218 |
+
|
219 |
+
|
220 |
+
# Build the application at import time and expose chat() at POST /chat/.
validator_app = ValidatorApplication()
validator_app.router.add_post('/chat/', chat)

bt.logging.info("Starting validator application.")
bt.logging.info(validator_app)
|
225 |
+
|
226 |
+
|
227 |
+
def main(run_aio_app=True, test=False) -> None:
    """
    Serve `validator_app` on port 10000.

    Args:
        run_aio_app: When False, return without starting the server.
        test: Currently unused; kept so existing callers keep working.
    """
    # port = validator.metagraph.axons[validator.uid].port
    port = 10000
    if not run_aio_app:
        return

    # Only touch the event loop when it is needed. get_event_loop() is
    # deprecated without a current loop and raises RuntimeError on newer
    # Python versions, so fall back to creating and installing a fresh loop.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        web.run_app(validator_app, port=port, loop=loop)
    except KeyboardInterrupt:
        bt.logging.info("Keyboard interrupt detected. Exiting validator.")
|
240 |
+
|
241 |
+
if __name__ == "__main__":
    # Rebind the module-level `validator` (initialised to None above) so that
    # chat() can use the live instance, then start the HTTP server.
    validator = Validator()
    main()
|