steffenc committed on
Commit
82b3169
·
1 Parent(s): 2e2d80e

Add server

Browse files
Files changed (1) hide show
  1. server.py +243 -0
server.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ import os
5
+ import re
6
+ import asyncio
7
+ import json
8
+ import traceback
9
+ import bittensor as bt
10
+
11
+ from collections import Counter
12
+
13
+ from neurons.validator import Validator
14
+ from prompting.dendrite import DendriteResponseEvent
15
+ from prompting.protocol import PromptingSynapse
16
+ from prompting.utils.uids import get_random_uids
17
+ from prompting.rewards import DateRewardModel, FloatDiffModel
18
+ from aiohttp import web
19
+ from aiohttp.web_response import Response
20
+
21
+ """
22
+ # test
23
+ ```
24
+ curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hello" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["hello world"]}'
25
+
26
+ curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hey-michal" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["on what exact date did the 21st century begin?"]}'
27
+ ```
28
+
29
+ TROUBLESHOOT
30
+ check if port is open
31
+ ```
32
+ sudo ufw allow 10000/tcp
33
+ sudo ufw allow 10000/tcp
34
+ ```
35
+ # run
36
+ ```
37
+ EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name app -- --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
38
+ ```
39
+ """
40
+
41
# Access key that requests must present in the `api_key` header.
# When the environment variable is unset this is None and the chat
# endpoint skips the access-key check.
EXPECTED_ACCESS_KEY = os.getenv('EXPECTED_ACCESS_KEY')

# Module-level validator handle; it stays None until the script entry
# point assigns a Validator instance to it.
validator = None

# Task name -> reward model. ensemble_result() uses these models to parse
# dates ('date_qa') and numbers ('math') out of completions.
reward_models = dict(
    date_qa=DateRewardModel(),
    math=FloatDiffModel(),
)
48
+
49
# Phrases that mark a completion as a refusal / non-answer. Compiled once at
# import time instead of on every call. Matching is case-sensitive.
_REFUSAL_PATTERN = re.compile(
    r'I\'m sorry|unable to|I cannot|I can\'t|I am unable|I am sorry|I can not|don\'t know|not sure|don\'t understand'
)
# A completion must contain at least one word character to count as an answer.
_WORD_PATTERN = re.compile(r'\w+')


def completion_is_valid(completion: str):
    """
    Decide whether a single completion looks like a usable answer.

    Args:
        completion: The completion text to check.

    Returns:
        False when the completion is not a string, contains no word
        characters, or contains one of the known refusal phrases
        (e.g. "I'm sorry", "I cannot", "not sure"); True otherwise.
    """
    # Anything that is not text (e.g. None) cannot be a valid answer;
    # reject it here rather than letting the regex calls raise TypeError.
    if not isinstance(completion, str):
        return False

    if _WORD_PATTERN.search(completion) is None:
        return False

    return _REFUSAL_PATTERN.search(completion) is None
57
+
58
+
59
def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
    """
    Ensemble completions from multiple models into one preferred completion.

    For 'date_qa' and 'math' tasks the completions are first reduced to the
    ones from which the matching reward model can extract a date / number.
    If one extracted value occurs more than once, only the completions that
    agree on that value are kept as support. Any other task name (including
    'qa' and 'summarization') gets no special handling: every completion
    counts as support.

    Args:
        completions: Completion strings to ensemble.
        task_name: Task the completions answer, e.g. 'qa', 'summarization',
            'date_qa' or 'math'.
        prefer: How to pick the returned completion among the supporting
            ones: 'longest', 'shortest' or 'most_common'.

    Returns:
        None when `completions` is empty, or when no completion yields a
        date / number for a 'date_qa' / 'math' task. Otherwise a dict with:
            - 'completion': the preferred supporting completion
            - 'accepted_answer': the most common extracted value
              (None for tasks without special handling)
            - 'support': number of supporting completions
            - 'support_indices': positions of the supporting completions
              within `completions`
            - 'method': human-readable description of the selection

    Raises:
        ValueError: If `prefer` is not a known preference.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    # Positions (within `completions`) of the completions that back the
    # answer. Tracking positions instead of strings keeps duplicates distinct
    # and keeps extracted values aligned with the completion they came from.
    support_indices = list(range(len(completions)))

    if task_name == 'date_qa':
        # Keep only completions with a parseable date; if one date appears
        # more than once, keep only the completions that agree on it.
        parsed = [reward_models[task_name].parse_dates_from_text(c) for c in completions]
        bt.logging.info(f"Unprocessed dates: {parsed}")
        support_indices = [i for i, d in enumerate(parsed) if d]
        if not support_indices:
            return None

        # NOTE(review): '%-d' (day without zero padding) is a platform-specific
        # strftime extension and is not available on every OS - confirm the
        # deployment target. Presumably each parsed entry is a (date, year) pair.
        labels = [f"{parsed[i][0].strftime('%-d %B')} {parsed[i][1]}" for i in support_indices]
        answer, count = Counter(labels).most_common(1)[0]
        if count > 1:
            support_indices = [i for i, label in zip(support_indices, labels) if label == answer]

    elif task_name == 'math':
        # Keep only completions with an extractable number; if one value
        # appears more than once, keep only the completions that agree on it.
        # TODO: use the median instead of the most common value
        extracted = [reward_models[task_name].extract_number(c) for c in completions]
        # Compare against None rather than truthiness so an answer of 0 survives.
        # NOTE(review): assumes extract_number returns None when no number is
        # found - confirm against FloatDiffModel.
        support_indices = [i for i, v in enumerate(extracted) if v is not None]
        if not support_indices:
            return None

        vals = [extracted[i] for i in support_indices]
        answer, count = Counter(vals).most_common(1)[0]
        bt.logging.info(f"Most common value: {answer}, count: {count}")
        if count > 1:
            support_indices = [i for i, v in zip(support_indices, vals) if v == answer]
        # NOTE(review): the chat endpoint passes 'accepted_answer' through
        # json.dumps; verify that extract_number's return type is serializable.

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == 'longest':
        # Stable sort: among equally long completions the last one wins.
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == 'shortest':
        # Stable sort: among equally short completions the first one wins.
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == 'most_common':
        # Counter breaks ties by first occurrence, so the choice is deterministic.
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        'completion': preferred_completion,
        'accepted_answer': answer,
        'support': len(supporting_completions),
        'support_indices': support_indices,
        'method': f'Selected the {prefer.replace("_", " ")} completion'
    }
128
+
129
# Keyword patterns per task, checked in insertion order (first match wins).
# Compiled once at import time and matched case-insensitively, so prompts
# such as "Summarize ..." or "Calculate ..." are classified like their
# lowercase forms.
# NOTE(review): in 'was born?|died?' the '?' is a regex quantifier, so the
# pattern also matches "was bor" / "die" (e.g. inside "diet") - confirm
# whether a literal question mark was intended.
_TASK_PATTERNS = {
    'summarization': re.compile(r'summar|quick rundown|overview', re.IGNORECASE),
    'date_qa': re.compile(r'exact date|tell me when|on what date|on what day|was born?|died?', re.IGNORECASE),
    'math': re.compile(r'math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial', re.IGNORECASE),
}


def guess_task_name(challenge: str):
    """
    Guess which task a challenge prompt belongs to from its keywords.

    Args:
        challenge: The prompt text to classify.

    Returns:
        'summarization', 'date_qa' or 'math' for the first category whose
        keyword pattern occurs in `challenge`; 'qa' when none matches.
    """
    for task_name, patt in _TASK_PATTERNS.items():
        if patt.search(challenge):
            return task_name

    return 'qa'
140
+
141
async def chat(request: web.Request) -> Response:
    """
    Chat endpoint for the validator.

    Required headers:
    - api_key: The access key for the validator.

    Required body:
    - roles: The list of roles to query.
    - messages: The list of messages to query.
    Optional body:
    - k: The number of nodes to query.
    - exclude: The list of nodes to exclude from the query.
    - timeout: The timeout for the query.
    - prefer: Which supporting completion to return
      ('longest', 'shortest' or 'most_common').

    Responses:
    - 200: JSON text of the response event state, extended with
      'completion_is_valid', 'task_name' and 'ensemble_result'.
    - 400: The body is not valid JSON, is not a JSON object, or lacks a
      non-empty 'roles' / 'messages' list.
    - 401: The access key does not match EXPECTED_ACCESS_KEY.
    - 500: Any error raised while querying the network or building the response.
    """

    bt.logging.info('chat()')
    # Check access key (skipped entirely when no expected key is configured)
    # NOTE(review): the rejected key is written to the log verbatim; consider
    # whether near-miss credentials should end up in log files.
    access_key = request.headers.get("api_key")
    if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
        bt.logging.error(f'Invalid access key: {access_key}')
        return Response(status=401, reason="Invalid access key")

    try:
        request_data = await request.json()
    except ValueError:
        # request_data is not bound when parsing fails, so it must not be
        # referenced in this handler.
        bt.logging.error('Invalid request data: body is not valid JSON')
        return Response(status=400)

    bt.logging.info(f'Request data: {request_data}')

    # Reject malformed payloads up front with a client error instead of
    # letting them surface as AttributeError/KeyError/IndexError below.
    if not isinstance(request_data, dict):
        bt.logging.error('Invalid request data: body is not a JSON object')
        return Response(status=400, reason="Invalid request data")

    roles = request_data.get('roles')
    messages = request_data.get('messages')
    if not isinstance(roles, list) or not isinstance(messages, list) or not roles or not messages:
        bt.logging.error("Invalid request data: 'roles' and 'messages' must be non-empty lists")
        return Response(status=400, reason="Invalid request data")

    k = request_data.get('k', 10)
    exclude = request_data.get('exclude', [])
    timeout = request_data.get('timeout', 10)
    prefer = request_data.get('prefer', 'longest')
    try:
        # Guess the task name of current request from the latest message
        task_name = guess_task_name(messages[-1])

        # Get the list of uids to query for this step.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).to(validator.device)
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt.
        bt.logging.info('Calling dendrite')
        responses = await validator.dendrite(
            axons=axons,
            synapse=PromptingSynapse(roles=roles, messages=messages),
            timeout=timeout,
        )

        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        # Encapsulate the responses in a response event (dataclass)
        response_event = DendriteResponseEvent(responses, uids)

        # convert dict to json
        response = response_event.__state_dict__()

        # Flag each completion, then ensemble only the valid ones.
        response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]

        response['task_name'] = task_name
        response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {response}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))

    except Exception:
        # Top-level boundary: log the full traceback and report a generic
        # server error to the client.
        bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")
210
+
211
+
212
+
213
+
214
class ValidatorApplication(web.Application):
    """aiohttp application for the validator; currently adds nothing to web.Application."""

    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to web.Application."""
        # TODO: Enable rewarding and other features
        super().__init__(*args, **kwargs)
218
+
219
+
220
# Application instance served by main(); POST /chat/ is its only route.
validator_app = ValidatorApplication()
_chat_routes = [web.post('/chat/', chat)]
validator_app.add_routes(_chat_routes)

# Emitted when this module is loaded, before the server is actually started.
bt.logging.info("Starting validator application.")
bt.logging.info(validator_app)
225
+
226
+
227
def main(run_aio_app=True, test=False, port=10000) -> None:
    """
    Serve the validator application over HTTP.

    Args:
        run_aio_app: When False, return immediately without starting the server.
        test: Currently unused; kept so existing callers keep working.
        port: TCP port to listen on. Defaults to 10000.
    """
    if not run_aio_app:
        return

    # The event loop is only needed when the server actually runs, so it is
    # fetched here rather than at the top of the function.
    loop = asyncio.get_event_loop()
    try:
        web.run_app(validator_app, port=port, loop=loop)
    except KeyboardInterrupt:
        bt.logging.info("Keyboard interrupt detected. Exiting validator.")


if __name__ == "__main__":
    # Bind the module-level validator that the chat handler reads, then serve.
    validator = Validator()
    main()