File size: 11,237 Bytes
82b3169
 
 
 
 
728d41d
82b3169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728d41d
 
 
82b3169
 
 
 
 
 
 
728d41d
82b3169
 
 
 
728d41d
 
 
 
 
 
82b3169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728d41d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82b3169
 
 
 
 
 
 
 
728d41d
 
 
 
82b3169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313



import os
import re
import time
import asyncio
import json
import traceback
import bittensor as bt

from collections import Counter

from neurons.validator import Validator
from prompting.dendrite import DendriteResponseEvent
from prompting.protocol import PromptingSynapse
from prompting.utils.uids import get_random_uids
from prompting.rewards import DateRewardModel, FloatDiffModel
from aiohttp import web
from aiohttp.web_response import Response

"""
# test
```
curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hello" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["hello world"]}'

curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hey-michal" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["on what exact date did the 21st century begin?"]}'

# stream
curl --no-buffer -X POST http://129.146.127.82:10000/echo/ -H "api_key: hey-michal" -d '{"k": 3, "timeout": 0.2, "roles": ["user"], "messages": ["i need to tell you something important but first"]}'
```

TROUBLESHOOT
Check whether the port is open; if it is not, allow it through the firewall:
```
sudo ufw status
sudo ufw allow 10000/tcp
```
# run
```
EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name app -- --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
```

basic testing
```
EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu 
```
add --mock to test the echo stream
"""

# Shared secret that clients must send in the `api_key` header. It is read
# once at import time; when the variable is unset this is None and the
# handlers below skip the access-key check.
EXPECTED_ACCESS_KEY = os.getenv('EXPECTED_ACCESS_KEY')

# Validator instance used by the request handlers. It stays None on import
# and is assigned in the `__main__` guard at the bottom of the file.
validator = None

# Reward models keyed by task name; ensemble_result() uses them to pull
# dates / numbers out of completions.
reward_models = dict(
    date_qa=DateRewardModel(),
    math=FloatDiffModel(),
)

# Phrases that mark a completion as a refusal / non-answer. Compiled once at
# import time and matched case-insensitively so that sentence-initial variants
# such as "Unable to ..." or "Not sure ..." are caught as well.
_REFUSAL_PATTERN = re.compile(
    r"I'm sorry|unable to|I cannot|I can't|I am unable|I am sorry|I can not"
    r"|don't know|not sure|don't understand",
    re.IGNORECASE,
)


def completion_is_valid(completion: str):
    """
    Decide whether a single completion is usable as an answer.

    Args:
        completion: The text of one completion.

    Returns:
        False when `completion` is not a string, contains no word characters
        (empty / whitespace / punctuation only), or contains one of the
        refusal phrases in `_REFUSAL_PATTERN`; True otherwise.
    """
    # A missing completion (e.g. None) is treated as invalid instead of
    # raising TypeError inside the regex call.
    if not isinstance(completion, str):
        return False
    if re.search(r'\w', completion) is None:
        return False
    return _REFUSAL_PATTERN.search(completion) is None


def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
    """
    Ensemble completions from multiple models into one preferred completion.

    For the 'date_qa' and 'math' tasks an answer (a formatted date or a
    number) is extracted from every completion with the matching entry of
    `reward_models`. Completions without an extractable answer are dropped.
    The most frequent answer becomes the accepted answer; if it occurs more
    than once only the completions agreeing with it count as support,
    otherwise every completion with an extractable answer does. For any other
    task name there is no answer extraction and all completions are support.

    Args:
        completions: Completion strings to ensemble.
        task_name: Task the completions answer ('qa', 'summarization',
            'date_qa', 'math'; anything else is handled like 'qa').
        prefer: How to pick the returned completion among the supporting
            ones: 'longest', 'shortest' or 'most_common'.

    Returns:
        None when `completions` is empty or no answer could be extracted for
        a 'date_qa' / 'math' task. Otherwise a dict with the keys
        'completion', 'accepted_answer' (None for tasks without extraction),
        'support', 'support_indices' (positions in the `completions` list
        that was passed in) and 'method'.

    Raises:
        ValueError: If `prefer` is not one of the supported preferences.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    # TODO: use the median instead of the most common value for 'math'
    """
    if not completions:
        return None

    # (index into `completions`, extracted answer) pairs. None means the task
    # has no answer extraction. Keeping the index with each answer avoids
    # misaligned lookups once unparsable completions are filtered out.
    extracted = None

    if task_name == 'date_qa':
        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {dates}")
        # NOTE: '%-d' (day without zero padding) is a platform-specific
        # strftime extension and is not available on Windows.
        extracted = [
            (i, f"{d[0].strftime('%-d %B')} {d[1]}")
            for i, d in enumerate(dates)
            if d
        ]

    elif task_name == 'math':
        vals = map(reward_models[task_name].extract_number, completions)
        # Compare against None rather than truthiness so that a legitimate
        # answer of 0 is kept.
        # Assumes extract_number returns None when no number is found - TODO confirm.
        extracted = [(i, v) for i, v in enumerate(vals) if v is not None]

    answer = None
    if extracted is None:
        # No special handling (QA, summarization, unknown tasks).
        support_indices = list(range(len(completions)))
    else:
        if not extracted:
            return None

        most_common, count = Counter(v for _, v in extracted).most_common(1)[0]
        bt.logging.info(f"Most common value: {most_common}, count: {count}")
        answer = most_common
        # With no agreement (count == 1) every parsable completion is support.
        support_indices = [
            i for i, v in extracted if count == 1 or v == most_common
        ]

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == 'longest':
        # sorted() is stable, so among equally long completions the last wins.
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == 'shortest':
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == 'most_common':
        # Counter keeps first-seen order, making ties deterministic.
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        'completion': preferred_completion,
        'accepted_answer': answer,
        'support': len(supporting_completions),
        'support_indices': support_indices,
        'method': f'Selected the {prefer.replace("_", " ")} completion'
    }
    
# Keyword heuristics per task, checked in this order; the first hit wins.
# Compiled once at import time and matched case-insensitively so that
# sentence-initial keywords ("Summarize ...", "Solve ...") are recognised.
# NOTE: 'born?' and 'died?' make the final letter optional, so they also
# match the shorter substrings "was bor" and "die".
_TASK_PATTERNS = {
    'summarization': re.compile(r'summar|quick rundown|overview', re.IGNORECASE),
    'date_qa': re.compile(
        r'exact date|tell me when|on what date|on what day|was born?|died?',
        re.IGNORECASE,
    ),
    'math': re.compile(
        r'math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial',
        re.IGNORECASE,
    ),
}


def guess_task_name(challenge: str):
    """
    Guess which task a challenge belongs to from keywords in its text.

    Args:
        challenge: The text to classify.

    Returns:
        'summarization', 'date_qa' or 'math' for the first category whose
        pattern is found in `challenge`, or 'qa' when none matches.
    """
    for task_name, patt in _TASK_PATTERNS.items():
        if patt.search(challenge):
            return task_name

    return 'qa'

async def chat(request: web.Request) -> Response:
    """
    Chat endpoint for the validator.

    Required headers:
    - api_key: The access key for the validator (only checked when
      EXPECTED_ACCESS_KEY is set).

    Required body (JSON object):
    - roles: The list of roles to query.
    - messages: The list of messages to query (must not be empty; the last
      message is used to guess the task).
    Optional body:
    - k: The number of nodes to query (default 10).
    - exclude: The list of nodes to exclude from the query.
    - timeout: The timeout for the query (default 10).
    - prefer: Ensemble preference passed to ensemble_result (default 'longest').

    Returns:
    - 401 when the access key does not match.
    - 400 when the body is not valid JSON or lacks the required fields.
    - 500 when querying the network or building the response fails.
    - 200 with the JSON-encoded response event, extended with
      'completion_is_valid', 'task_name' and 'ensemble_result'.
    """

    bt.logging.info('chat()')
    # Check access key
    access_key = request.headers.get("api_key")
    if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
        bt.logging.error(f'Invalid access key: {access_key}')
        return Response(status=401, reason="Invalid access key")

    try:
        request_data = await request.json()
    except ValueError:
        # Parsing failed, so there is no request_data to log here; referencing
        # it would raise UnboundLocalError and mask the 400 with a 500.
        bt.logging.error('Invalid request data: body is not valid JSON')
        return Response(status=400)

    bt.logging.info(f'Request data: {request_data}')

    # Reject malformed client input up front instead of letting it surface
    # as an internal error further down.
    if not isinstance(request_data, dict):
        bt.logging.error(f'Invalid request data: {request_data}')
        return Response(status=400)
    roles = request_data.get('roles')
    messages = request_data.get('messages')
    if not isinstance(roles, list) or not isinstance(messages, list) or not messages:
        bt.logging.error(f'Invalid request data: {request_data}')
        return Response(status=400)

    k = request_data.get('k', 10)
    exclude = request_data.get('exclude', [])
    timeout = request_data.get('timeout', 10)
    prefer = request_data.get('prefer', 'longest')
    try:
        # Guess the task name of current request
        task_name = guess_task_name(messages[-1])

        # Get the list of uids to query for this step.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).to(validator.device)
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt.
        bt.logging.info('Calling dendrite')
        responses = await validator.dendrite(
            axons=axons,
            synapse=PromptingSynapse(roles=roles, messages=messages),
            timeout=timeout,
        )

        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        # Encapsulate the responses in a response event (dataclass)
        response_event = DendriteResponseEvent(responses, uids)

        # convert dict to json
        response = response_event.__state_dict__()

        # Only completions that pass the validity filter are ensembled.
        response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]

        response['task_name'] = task_name
        response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {response}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))

    except Exception:
        # Top-level boundary of the handler: log the traceback, answer 500.
        bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")



async def echo_stream(request):
    """
    Streaming test endpoint that echoes the request messages word by word.

    Required headers:
    - api_key: The access key for the validator (only checked when
      EXPECTED_ACCESS_KEY is set).

    Required body (JSON object):
    - messages: The list of message strings to echo; they are joined with
      blank lines and split on whitespace.
    Optional body:
    - k: How many times the whole message is echoed (default 1).
    - timeout: Pause in seconds after each word (default 0.2).

    Returns:
    - 401 when the access key does not match.
    - 400 when the body is not valid JSON or 'messages' is not a list of
      strings.
    - Otherwise a 200 text/plain stream of the echoed words, followed by a
      "JSON_RESPONSE_BEGIN:" marker and a JSON summary of the echo.
    """

    bt.logging.info('echo_stream()')
    # Check access key
    access_key = request.headers.get("api_key")
    if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
        bt.logging.error(f'Invalid access key: {access_key}')
        return Response(status=401, reason="Invalid access key")

    try:
        request_data = await request.json()
    except ValueError:
        # Parsing failed, so there is no request_data to log here; referencing
        # it would raise UnboundLocalError and mask the 400 with a 500.
        bt.logging.error('Invalid request data: body is not valid JSON')
        return Response(status=400)

    bt.logging.info(f'Request data: {request_data}')

    # Validate before the stream is prepared: once headers are sent the
    # status code can no longer be changed.
    messages = request_data.get('messages') if isinstance(request_data, dict) else None
    if not isinstance(messages, list) or not all(isinstance(m, str) for m in messages):
        bt.logging.error(f'Invalid request data: {request_data}')
        return Response(status=400)

    k = request_data.get('k', 1)
    timeout = request_data.get('timeout', 0.2)
    message = '\n\n'.join(messages)

    # Create a StreamResponse
    response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
    await response.prepare(request)

    words = message.split()
    echoed = []
    # Echo the message k times with a timeout between each chunk
    for _ in range(k):
        for word in words:
            chunk = f'{word} '
            await response.write(chunk.encode('utf-8'))
            echoed.append(chunk)
            # asyncio.sleep yields to the event loop; a blocking time.sleep
            # here would stall every other request served by this process.
            await asyncio.sleep(timeout)
            bt.logging.info(f"Echoed: {chunk}")

    completion = ''.join(echoed).strip()

    # Prepare final JSON chunk
    json_chunk = json.dumps({
        "uids": [0],
        "completion": completion,
        "completions": [completion],
        "timings": [0],
        "status_messages": ['Went well!'],
        "status_codes": [200],
        "completion_is_valid": [True],
        "task_name": 'echo',
        "ensemble_result": {}
    })

    # Send the final JSON as part of the stream
    await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))

    # Finalize the response
    await response.write_eof()
    return response

class ValidatorApplication(web.Application):
    """aiohttp application subclass for the validator API; currently adds nothing."""

    # NOTE(review): aiohttp appears to discourage subclassing web.Application
    # (DeprecationWarning at class creation) - confirm against the installed version.
    def __init__(self, *args, **kwargs):
        # TODO: Enable rewarding and other features
        super(ValidatorApplication, self).__init__(*args, **kwargs)
        # TODO: Enable rewarding and other features


# Application object served by main(); created at import time.
validator_app = ValidatorApplication()
# Both endpoints accept POST only; note the trailing slash in the paths.
validator_app.add_routes([
    web.post('/chat/', chat),
    web.post('/echo/', echo_stream)
    ])

# Logged at import time, i.e. before main() actually starts serving.
bt.logging.info("Starting validator application.")
bt.logging.info(validator_app)


def main(run_aio_app=True, test=False, port=10000) -> None:
    """
    Serve `validator_app` over HTTP until it is interrupted.

    Args:
        run_aio_app: When False nothing is started and the function returns
            immediately.
        test: Currently unused; kept for backward compatibility.
        port: TCP port to listen on. Defaults to 10000, the value that was
            previously hard-coded (the validator's axon port from the
            metagraph was considered as an alternative).
    """
    if not run_aio_app:
        return

    # Reuse the current event loop when there is one, so objects created
    # before main() and the web app share a loop. Newer Python versions raise
    # RuntimeError instead of implicitly creating a loop, so make one then.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        web.run_app(validator_app, port=port, loop=loop)
    except KeyboardInterrupt:
        bt.logging.info("Keyboard interrupt detected. Exiting validator.")

if __name__ == "__main__":
    # Rebind the module-level `validator` (None on import) that chat() reads,
    # then start serving.
    validator = Validator()
    main()