pedroferreira committed on
Commit
32e1e2e
·
1 Parent(s): b073cce

add middlewares + api refactoring

Browse files
Files changed (7) hide show
  1. api.py +0 -107
  2. forward.py +244 -244
  3. middlewares.py +34 -0
  4. requirements.txt +2 -1
  5. server.py +20 -127
  6. utils.py +47 -4
  7. validator_wrapper.py +0 -70
api.py DELETED
@@ -1,107 +0,0 @@
1
-
2
- import json
3
- import asyncio
4
-
5
- import traceback
6
- import bittensor as bt
7
-
8
- import utils
9
-
10
- from typing import List
11
- from neurons.validator import Validator
12
- from prompting.forward import handle_response
13
- from prompting.dendrite import DendriteResponseEvent
14
- from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
15
- from prompting.utils.uids import get_random_uids
16
-
17
- from aiohttp import web
18
-
19
- from aiohttp.web_response import Response
20
-
21
-
22
- async def single_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> Response:
23
-
24
- try:
25
- # Guess the task name of current request
26
- task_name = utils.guess_task_name(messages[-1])
27
-
28
- # Get the list of uids to query for this step.
29
- uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
30
- axons = [validator.metagraph.axons[uid] for uid in uids]
31
-
32
- # Make calls to the network with the prompt.
33
- bt.logging.info(f'Calling dendrite')
34
- responses = await validator.dendrite(
35
- axons=axons,
36
- synapse=PromptingSynapse(roles=roles, messages=messages),
37
- timeout=timeout,
38
- )
39
-
40
- bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
41
- # Encapsulate the responses in a response event (dataclass)
42
- response_event = DendriteResponseEvent(responses, uids)
43
-
44
- # convert dict to json
45
- response = response_event.__state_dict__()
46
-
47
- response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
48
- valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
49
-
50
- response['task_name'] = task_name
51
- response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
52
-
53
- bt.logging.info(f"Response:\n {response}")
54
- return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
55
-
56
- except Exception:
57
- bt.logging.error(f'Encountered in {single_response.__name__}:\n{traceback.format_exc()}')
58
- return Response(status=500, reason="Internal error")
59
-
60
-
61
- async def stream_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> web.StreamResponse:
62
-
63
- try:
64
- # Guess the task name of current request
65
- task_name = utils.guess_task_name(messages[-1])
66
-
67
- # Get the list of uids to query for this step.
68
- uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
69
- axons = [validator.metagraph.axons[uid] for uid in uids]
70
-
71
- # Make calls to the network with the prompt.
72
- bt.logging.info(f'Calling dendrite')
73
- streams_responses = await validator.dendrite(
74
- axons=axons,
75
- synapse=StreamPromptingSynapse(roles=roles, messages=messages),
76
- timeout=timeout,
77
- deserialize=False,
78
- streaming=True,
79
- )
80
-
81
- # Prepare the task for handling stream responses
82
- handle_stream_responses_task = asyncio.create_task(
83
- handle_response(responses=dict(zip(uids, streams_responses)))
84
- )
85
-
86
- stream_results = await handle_stream_responses_task
87
-
88
- responses = [stream_result.synapse for stream_result in stream_results]
89
- bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
90
- # Encapsulate the responses in a response event (dataclass)
91
- response_event = DendriteResponseEvent(responses, uids)
92
-
93
- # convert dict to json
94
- response = response_event.__state_dict__()
95
-
96
- response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
97
- valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
98
-
99
- response['task_name'] = task_name
100
- response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
101
-
102
- bt.logging.info(f"Response:\n {response}")
103
- return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
104
-
105
- except Exception:
106
- bt.logging.error(f'Encountered in {single_response.__name__}:\n{traceback.format_exc()}')
107
- return Response(status=500, reason="Internal error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
forward.py CHANGED
@@ -1,244 +1,244 @@
1
- import time
2
- import sys
3
- import asyncio
4
- import numpy as np
5
- import bittensor as bt
6
- import traceback
7
- from typing import List, Dict, Awaitable
8
- from prompting.agent import HumanAgent
9
- from prompting.dendrite import DendriteResponseEvent
10
- from prompting.conversation import create_task
11
- from prompting.protocol import StreamPromptingSynapse
12
- from prompting.rewards import RewardResult
13
- from prompting.utils.uids import get_random_uids
14
- from prompting.utils.logging import log_event
15
- from prompting.utils.misc import async_log, serialize_exception_to_string
16
- from dataclasses import dataclass
17
-
18
- @async_log
19
- async def generate_reference(agent):
20
- loop = asyncio.get_running_loop()
21
- result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
22
- return result
23
-
24
- @async_log
25
- async def execute_dendrite_call(dendrite_call):
26
- responses = await dendrite_call
27
- return responses
28
-
29
- @dataclass
30
- class StreamResult:
31
- synapse: StreamPromptingSynapse = None
32
- exception: BaseException = None
33
- uid: int = None
34
-
35
-
36
- async def process_response(uid: int, async_generator: Awaitable):
37
- """Process a single response asynchronously."""
38
- try:
39
- chunk = None # Initialize chunk with a default value
40
- async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
41
- bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
42
-
43
- if chunk is not None:
44
- synapse = chunk # last object yielded is the synapse itself with completion filled
45
-
46
- # Assuming chunk holds the last value yielded which should be a synapse
47
- if isinstance(synapse, StreamPromptingSynapse):
48
- return synapse
49
-
50
- bt.logging.debug(
51
- f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
52
- )
53
- except Exception as e:
54
- # bt.logging.error(f"Error in generating reference or handling responses: {e}", exc_info=True)
55
- traceback_details = traceback.format_exc()
56
- bt.logging.error(
57
- f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
58
- )
59
-
60
- failed_synapse = StreamPromptingSynapse(
61
- roles=["user"], messages=["failure"], completion=""
62
- )
63
-
64
- return failed_synapse
65
-
66
-
67
- @async_log
68
- async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
69
- """The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
70
- and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.
71
-
72
- Args:
73
- responses (Dict[int, Awaitable]): Responses contains awaitables that are used to acquire streamed miner chunks.
74
-
75
- Raises:
76
- ValueError
77
-
78
- Returns:
79
- List[StreamResult]: DataClass containing the synapse, exception, and uid
80
- """
81
- tasks_with_uid = [
82
- (uid, responses[uid]) for uid, _ in responses.items()
83
- ] # Pair UIDs with their tasks
84
-
85
- # Start tasks, preserving order and their associated UIDs
86
- tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
87
-
88
- results = await asyncio.gather(*tasks, return_exceptions=True)
89
-
90
- mapped_results = []
91
- # Pair each result with its original uid
92
- for (uid, _), result in zip(tasks_with_uid, results):
93
- # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
94
- if isinstance(result, StreamPromptingSynapse):
95
- mapped_results.append(StreamResult(synapse=result, uid=uid))
96
-
97
- # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
98
- elif isinstance(result, BaseException):
99
- failed_synapse = StreamPromptingSynapse(
100
- roles=["user"], messages=["failure"], completion=""
101
- )
102
- mapped_results.append(
103
- StreamResult(synapse=failed_synapse, exception=result, uid=uid)
104
- )
105
-
106
- # If the result is neither an error or a StreamSynapse, log the error and raise a ValueError
107
- else:
108
- bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
109
- raise ValueError(f"Unexpected result type for UID {uid}: {result}")
110
-
111
- return mapped_results
112
-
113
-
114
- @async_log
115
- async def generate_reference(agent: HumanAgent):
116
- loop = asyncio.get_running_loop()
117
- result = await loop.run_in_executor(
118
- None, agent.task.generate_reference, agent.llm_pipeline
119
- )
120
- return result
121
-
122
-
123
- def log_stream_results(stream_results: List[StreamResult]):
124
- failed_responses = [
125
- response for response in stream_results if response.exception is not None
126
- ]
127
- empty_responses = [
128
- response
129
- for response in stream_results
130
- if response.exception is None and response.synapse.completion == ""
131
- ]
132
- non_empty_responses = [
133
- response
134
- for response in stream_results
135
- if response.exception is None and response.synapse.completion != ""
136
- ]
137
-
138
- bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
139
- bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
140
- bt.logging.info(
141
- f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
142
- )
143
-
144
- for failed_response in failed_responses:
145
- formatted_exception = serialize_exception_to_string(failed_response.exception)
146
- bt.logging.error(
147
- f"Failed response for uid {failed_response.uid}: {formatted_exception}"
148
- )
149
-
150
-
151
- async def run_step(
152
- self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
153
- ):
154
- """Executes a single step of the agent, which consists of:
155
- - Getting a list of uids to query
156
- - Querying the network
157
- - Rewarding the network
158
- - Updating the scores
159
- - Logging the event
160
-
161
- Args:
162
- agent (HumanAgent): The agent to run the step for.
163
- k (int): The number of uids to query.
164
- timeout (float): The timeout for the queries.
165
- exclude (list, optional): The list of uids to exclude from the query. Defaults to [].
166
- """
167
-
168
- bt.logging.debug("run_step", agent.task.name)
169
-
170
- # Record event start time.
171
- start_time = time.time()
172
- # Get the list of uids to query for this step.
173
- uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
174
- uids_cpu = uids.cpu().tolist()
175
-
176
- axons = [self.metagraph.axons[uid] for uid in uids]
177
-
178
- # Directly call dendrite and process responses in parallel
179
- streams_responses = await self.dendrite(
180
- axons=axons,
181
- synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
182
- timeout=timeout,
183
- deserialize=False,
184
- streaming=True,
185
- )
186
-
187
- # Prepare the task for handling stream responses
188
- handle_stream_responses_task = asyncio.create_task(
189
- handle_response(responses=dict(zip(uids_cpu, streams_responses)))
190
- )
191
-
192
- if not agent.task.static_reference:
193
- reference_generation_task = generate_reference(agent)
194
- _, stream_results = await asyncio.gather(
195
- reference_generation_task, handle_stream_responses_task
196
- )
197
- else:
198
- stream_results = await handle_stream_responses_task
199
-
200
- log_stream_results(stream_results)
201
-
202
- all_synapses_results = [stream_result.synapse for stream_result in stream_results]
203
-
204
- # Encapsulate the responses in a response event (dataclass)
205
- response_event = DendriteResponseEvent(
206
- responses=all_synapses_results, uids=uids, timeout=timeout
207
- )
208
-
209
- bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
210
- # Reward the responses and get the reward result (dataclass)
211
- # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
212
- reward_result = RewardResult(
213
- self.reward_pipeline,
214
- agent=agent,
215
- response_event=response_event,
216
- device=self.device,
217
- )
218
- bt.logging.info(f"Created RewardResult:\n {reward_result}")
219
-
220
- # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. reward critera is met, such as ROUGE>threshold)
221
- agent.update_progress(
222
- top_reward=reward_result.rewards.max(),
223
- top_response=response_event.completions[reward_result.rewards.argmax()],
224
- )
225
-
226
- self.update_scores(reward_result.rewards, uids)
227
-
228
- stream_results_uids = [stream_result.uid for stream_result in stream_results]
229
- stream_results_exceptions = [
230
- serialize_exception_to_string(stream_result.exception)
231
- for stream_result in stream_results
232
- ]
233
- # Log the step event.
234
- event = {
235
- "block": self.block,
236
- "step_time": time.time() - start_time,
237
- "stream_results_uids": stream_results_uids,
238
- "stream_results_exceptions": stream_results_exceptions,
239
- **agent.__state_dict__(full=self.config.neuron.log_full),
240
- **reward_result.__state_dict__(full=self.config.neuron.log_full),
241
- **response_event.__state_dict__(),
242
- }
243
-
244
- return event
 
1
+ # import time
2
+ # import sys
3
+ # import asyncio
4
+ # import numpy as np
5
+ # import bittensor as bt
6
+ # import traceback
7
+ # from typing import List, Dict, Awaitable
8
+ # from prompting.agent import HumanAgent
9
+ # from prompting.dendrite import DendriteResponseEvent
10
+ # from prompting.conversation import create_task
11
+ # from prompting.protocol import StreamPromptingSynapse
12
+ # from prompting.rewards import RewardResult
13
+ # from prompting.utils.uids import get_random_uids
14
+ # from prompting.utils.logging import log_event
15
+ # from prompting.utils.misc import async_log, serialize_exception_to_string
16
+ # from dataclasses import dataclass
17
+
18
+ # @async_log
19
+ # async def generate_reference(agent):
20
+ # loop = asyncio.get_running_loop()
21
+ # result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
22
+ # return result
23
+
24
+ # @async_log
25
+ # async def execute_dendrite_call(dendrite_call):
26
+ # responses = await dendrite_call
27
+ # return responses
28
+
29
+ # @dataclass
30
+ # class StreamResult:
31
+ # synapse: StreamPromptingSynapse = None
32
+ # exception: BaseException = None
33
+ # uid: int = None
34
+
35
+
36
+ # async def process_response(uid: int, async_generator: Awaitable):
37
+ # """Process a single response asynchronously."""
38
+ # try:
39
+ # chunk = None # Initialize chunk with a default value
40
+ # async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
41
+ # bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
42
+
43
+ # if chunk is not None:
44
+ # synapse = chunk # last object yielded is the synapse itself with completion filled
45
+
46
+ # # Assuming chunk holds the last value yielded which should be a synapse
47
+ # if isinstance(synapse, StreamPromptingSynapse):
48
+ # return synapse
49
+
50
+ # bt.logging.debug(
51
+ # f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
52
+ # )
53
+ # except Exception as e:
54
+ # # bt.logging.error(f"Error in generating reference or handling responses: {e}", exc_info=True)
55
+ # traceback_details = traceback.format_exc()
56
+ # bt.logging.error(
57
+ # f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
58
+ # )
59
+
60
+ # failed_synapse = StreamPromptingSynapse(
61
+ # roles=["user"], messages=["failure"], completion=""
62
+ # )
63
+
64
+ # return failed_synapse
65
+
66
+
67
+ # @async_log
68
+ # async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
69
+ # """The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
70
+ # and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.
71
+
72
+ # Args:
73
+ # responses (Dict[int, Awaitable]): Responses contains awaitables that are used to acquire streamed miner chunks.
74
+
75
+ # Raises:
76
+ # ValueError
77
+
78
+ # Returns:
79
+ # List[StreamResult]: DataClass containing the synapse, exception, and uid
80
+ # """
81
+ # tasks_with_uid = [
82
+ # (uid, responses[uid]) for uid, _ in responses.items()
83
+ # ] # Pair UIDs with their tasks
84
+
85
+ # # Start tasks, preserving order and their associated UIDs
86
+ # tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
87
+
88
+ # results = await asyncio.gather(*tasks, return_exceptions=True)
89
+
90
+ # mapped_results = []
91
+ # # Pair each result with its original uid
92
+ # for (uid, _), result in zip(tasks_with_uid, results):
93
+ # # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
94
+ # if isinstance(result, StreamPromptingSynapse):
95
+ # mapped_results.append(StreamResult(synapse=result, uid=uid))
96
+
97
+ # # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
98
+ # elif isinstance(result, BaseException):
99
+ # failed_synapse = StreamPromptingSynapse(
100
+ # roles=["user"], messages=["failure"], completion=""
101
+ # )
102
+ # mapped_results.append(
103
+ # StreamResult(synapse=failed_synapse, exception=result, uid=uid)
104
+ # )
105
+
106
+ # # If the result is neither an error or a StreamSynapse, log the error and raise a ValueError
107
+ # else:
108
+ # bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
109
+ # raise ValueError(f"Unexpected result type for UID {uid}: {result}")
110
+
111
+ # return mapped_results
112
+
113
+
114
+ # @async_log
115
+ # async def generate_reference(agent: HumanAgent):
116
+ # loop = asyncio.get_running_loop()
117
+ # result = await loop.run_in_executor(
118
+ # None, agent.task.generate_reference, agent.llm_pipeline
119
+ # )
120
+ # return result
121
+
122
+
123
+ # def log_stream_results(stream_results: List[StreamResult]):
124
+ # failed_responses = [
125
+ # response for response in stream_results if response.exception is not None
126
+ # ]
127
+ # empty_responses = [
128
+ # response
129
+ # for response in stream_results
130
+ # if response.exception is None and response.synapse.completion == ""
131
+ # ]
132
+ # non_empty_responses = [
133
+ # response
134
+ # for response in stream_results
135
+ # if response.exception is None and response.synapse.completion != ""
136
+ # ]
137
+
138
+ # bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
139
+ # bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
140
+ # bt.logging.info(
141
+ # f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
142
+ # )
143
+
144
+ # for failed_response in failed_responses:
145
+ # formatted_exception = serialize_exception_to_string(failed_response.exception)
146
+ # bt.logging.error(
147
+ # f"Failed response for uid {failed_response.uid}: {formatted_exception}"
148
+ # )
149
+
150
+
151
+ # async def run_step(
152
+ # self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
153
+ # ):
154
+ # """Executes a single step of the agent, which consists of:
155
+ # - Getting a list of uids to query
156
+ # - Querying the network
157
+ # - Rewarding the network
158
+ # - Updating the scores
159
+ # - Logging the event
160
+
161
+ # Args:
162
+ # agent (HumanAgent): The agent to run the step for.
163
+ # k (int): The number of uids to query.
164
+ # timeout (float): The timeout for the queries.
165
+ # exclude (list, optional): The list of uids to exclude from the query. Defaults to [].
166
+ # """
167
+
168
+ # bt.logging.debug("run_step", agent.task.name)
169
+
170
+ # # Record event start time.
171
+ # start_time = time.time()
172
+ # # Get the list of uids to query for this step.
173
+ # uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
174
+ # uids_cpu = uids.cpu().tolist()
175
+
176
+ # axons = [self.metagraph.axons[uid] for uid in uids]
177
+
178
+ # # Directly call dendrite and process responses in parallel
179
+ # streams_responses = await self.dendrite(
180
+ # axons=axons,
181
+ # synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
182
+ # timeout=timeout,
183
+ # deserialize=False,
184
+ # streaming=True,
185
+ # )
186
+
187
+ # # Prepare the task for handling stream responses
188
+ # handle_stream_responses_task = asyncio.create_task(
189
+ # handle_response(responses=dict(zip(uids_cpu, streams_responses)))
190
+ # )
191
+
192
+ # if not agent.task.static_reference:
193
+ # reference_generation_task = generate_reference(agent)
194
+ # _, stream_results = await asyncio.gather(
195
+ # reference_generation_task, handle_stream_responses_task
196
+ # )
197
+ # else:
198
+ # stream_results = await handle_stream_responses_task
199
+
200
+ # log_stream_results(stream_results)
201
+
202
+ # all_synapses_results = [stream_result.synapse for stream_result in stream_results]
203
+
204
+ # # Encapsulate the responses in a response event (dataclass)
205
+ # response_event = DendriteResponseEvent(
206
+ # responses=all_synapses_results, uids=uids, timeout=timeout
207
+ # )
208
+
209
+ # bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
210
+ # # Reward the responses and get the reward result (dataclass)
211
+ # # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
212
+ # reward_result = RewardResult(
213
+ # self.reward_pipeline,
214
+ # agent=agent,
215
+ # response_event=response_event,
216
+ # device=self.device,
217
+ # )
218
+ # bt.logging.info(f"Created RewardResult:\n {reward_result}")
219
+
220
+ # # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. reward critera is met, such as ROUGE>threshold)
221
+ # agent.update_progress(
222
+ # top_reward=reward_result.rewards.max(),
223
+ # top_response=response_event.completions[reward_result.rewards.argmax()],
224
+ # )
225
+
226
+ # self.update_scores(reward_result.rewards, uids)
227
+
228
+ # stream_results_uids = [stream_result.uid for stream_result in stream_results]
229
+ # stream_results_exceptions = [
230
+ # serialize_exception_to_string(stream_result.exception)
231
+ # for stream_result in stream_results
232
+ # ]
233
+ # # Log the step event.
234
+ # event = {
235
+ # "block": self.block,
236
+ # "step_time": time.time() - start_time,
237
+ # "stream_results_uids": stream_results_uids,
238
+ # "stream_results_exceptions": stream_results_exceptions,
239
+ # **agent.__state_dict__(full=self.config.neuron.log_full),
240
+ # **reward_result.__state_dict__(full=self.config.neuron.log_full),
241
+ # **response_event.__state_dict__(),
242
+ # }
243
+
244
+ # return event
middlewares.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import bittensor as bt
4
+ from aiohttp.web import Response
5
+
6
+ EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
7
+
8
+ async def api_key_middleware(app, handler):
9
+ async def middleware_handler(request):
10
+ # Logging the request
11
+ bt.logging.info(f"Handling {request.method} request to {request.path}")
12
+
13
+ # Check access key
14
+ access_key = request.headers.get("api_key")
15
+ if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
16
+ bt.logging.error(f'Invalid access key: {access_key}')
17
+ return Response(status=401, reason="Invalid access key")
18
+
19
+ # Continue to the next handler if the API key is valid
20
+ return await handler(request)
21
+ return middleware_handler
22
+
23
+ async def json_parsing_middleware(app, handler):
24
+ async def middleware_handler(request):
25
+ try:
26
+ # Parsing JSON data from the request
27
+ request['data'] = await request.json()
28
+ except json.JSONDecodeError as e:
29
+ bt.logging.error(f'Invalid JSON data: {str(e)}')
30
+ return Response(status=400, text="Invalid JSON")
31
+
32
+ # Continue to the next handler if JSON is successfully parsed
33
+ return await handler(request)
34
+ return middleware_handler
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  git+https://github.com/opentensor/prompting.git@features/move-validator-into-prompting
2
- aiohttp
 
 
1
  git+https://github.com/opentensor/prompting.git@features/move-validator-into-prompting
2
+ aiohttp
3
+ deprecated
server.py CHANGED
@@ -1,16 +1,10 @@
1
-
2
-
3
-
4
- import os
5
- import time
6
  import asyncio
7
- import json
8
  import bittensor as bt
9
- from collections import Counter
10
- from validator_wrapper import QueryValidatorParams, S1ValidatorWrapper
11
- from prompting.rewards import DateRewardModel, FloatDiffModel
12
  from aiohttp import web
13
  from aiohttp.web_response import Response
 
 
14
 
15
  """
16
  # test
@@ -41,135 +35,34 @@ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.n
41
  add --mock to test the echo stream
42
  """
43
 
44
- EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
45
-
46
  validator = None
47
 
48
-
 
49
  async def chat(request: web.Request) -> Response:
50
  """
51
  Chat endpoint for the validator.
52
-
53
- Required headers:
54
- - api_key: The access key for the validator.
55
-
56
- Required body:
57
- - roles: The list of roles to query.
58
- - messages: The list of messages to query.
59
- Optional body:
60
- - k: The number of nodes to query.
61
- - exclude: The list of nodes to exclude from the query.
62
- - timeout: The timeout for the query.
63
- """
64
-
65
- bt.logging.info(f'chat()')
66
- # Check access key
67
- access_key = request.headers.get("api_key")
68
- if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
69
- bt.logging.error(f'Invalid access key: {access_key}')
70
- return Response(status=401, reason="Invalid access key")
71
-
72
- try:
73
- request_data = await request.json()
74
- except ValueError:
75
- bt.logging.error(f'Invalid request data: {request_data}')
76
- return Response(status=400)
77
-
78
- # try:
79
- # # Guess the task name of current request
80
- # task_name = guess_task_name(request_data['messages'][-1])
81
-
82
- # # Get the list of uids to query for this step.
83
- # params = QueryValidatorParams.from_dict(request_data)
84
- # response_event = await validator.query_validator(params)
85
-
86
- # # convert dict to json
87
- # response = response_event.__state_dict__()
88
-
89
- # response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
90
- # valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
91
-
92
- # response['task_name'] = task_name
93
- # prefer = request_data.get('prefer', 'longest')
94
- # response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
95
-
96
- # bt.logging.info(f"Response:\n {response}")
97
- # return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
98
- # except Exception:
99
- # bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
100
- # return Response(status=500, reason="Internal error")
101
- bt.logging.info(f'Request data: {request_data}')
102
-
103
- stream = request_data.get('stream', False)
104
- if stream:
105
- return stream_response(**request_data)
106
- else:
107
- return single_response(**request_data)
108
-
109
-
110
-
111
-
112
- async def echo_stream(request):
113
-
114
- bt.logging.info(f'echo_stream()')
115
- # Check access key
116
- access_key = request.headers.get("api_key")
117
- if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
118
- bt.logging.error(f'Invalid access key: {access_key}')
119
- return Response(status=401, reason="Invalid access key")
120
-
121
- try:
122
- request_data = await request.json()
123
- except ValueError:
124
- bt.logging.error(f'Invalid request data: {request_data}')
125
- return Response(status=400)
126
-
127
- bt.logging.info(f'Request data: {request_data}')
128
- k = request_data.get('k', 1)
129
- exclude = request_data.get('exclude', [])
130
- timeout = request_data.get('timeout', 0.2)
131
- message = '\n\n'.join(request_data['messages'])
132
-
133
- # Create a StreamResponse
134
- response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
135
- await response.prepare(request)
136
-
137
- completion = ''
138
- # Echo the message k times with a timeout between each chunk
139
- for _ in range(k):
140
- for word in message.split():
141
- chunk = f'{word} '
142
- await response.write(chunk.encode('utf-8'))
143
- completion += chunk
144
- time.sleep(timeout)
145
- bt.logging.info(f"Echoed: {chunk}")
146
-
147
- completion = completion.strip()
148
-
149
- # Prepare final JSON chunk
150
- json_chunk = json.dumps({
151
- "uids": [0],
152
- "completion": completion,
153
- "completions": [completion.strip()],
154
- "timings": [0],
155
- "status_messages": ['Went well!'],
156
- "status_codes": [200],
157
- "completion_is_valid": [True],
158
- "task_name": 'echo',
159
- "ensemble_result": {}
160
- })
161
-
162
- # Send the final JSON as part of the stream
163
- await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))
164
-
165
- # Finalize the response
166
- await response.write_eof()
167
  return response
168
 
169
 
 
 
 
 
 
 
170
  class ValidatorApplication(web.Application):
171
  def __init__(self, *a, **kw):
172
  super().__init__(*a, **kw)
 
 
 
173
  # TODO: Enable rewarding and other features
174
 
175
 
 
 
 
 
 
 
1
  import asyncio
2
+ import utils
3
  import bittensor as bt
 
 
 
4
  from aiohttp import web
5
  from aiohttp.web_response import Response
6
+ from validators import S1ValidatorWrapper, QueryValidatorParams
7
+ from middlewares import api_key_middleware, json_parsing_middleware
8
 
9
  """
10
  # test
 
35
  add --mock to test the echo stream
36
  """
37
 
 
 
38
  validator = None
39
 
40
+ @api_key_middleware
41
+ @json_parsing_middleware
42
  async def chat(request: web.Request) -> Response:
43
  """
44
  Chat endpoint for the validator.
45
+ """
46
+ request_data = request['data']
47
+ params = QueryValidatorParams.from_dict(request_data)
48
+ # TODO: SET STREAM AS DEFAULT
49
+ stream = request_data.get('stream', False)
50
+ response = await validator.query_validator(params, stream=stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  return response
52
 
53
 
54
+ @api_key_middleware
55
+ @json_parsing_middleware
56
+ async def echo_stream(request, request_data):
57
+ request_data = request['data']
58
+ return await utils.echo_stream(request_data)
59
+
60
  class ValidatorApplication(web.Application):
61
  def __init__(self, *a, **kw):
62
  super().__init__(*a, **kw)
63
+ self.middlewares.append(api_key_middleware)
64
+ self.middlewares.append(json_parsing_middleware)
65
+
66
  # TODO: Enable rewarding and other features
67
 
68
 
utils.py CHANGED
@@ -1,8 +1,9 @@
1
  import re
2
  import bittensor as bt
3
-
 
 
4
  from collections import Counter
5
-
6
  from prompting.rewards import DateRewardModel, FloatDiffModel
7
 
8
  UNSUCCESSFUL_RESPONSE_PATTERNS = ["I'm sorry", "unable to", "I cannot", "I can't", "I am unable", "I am sorry", "I can not", "don't know", "not sure", "don't understand", "not capable"]
@@ -95,7 +96,6 @@ def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
95
  }
96
 
97
  def guess_task_name(challenge: str):
98
-
99
  # TODO: use a pre-trained classifier to guess the task name
100
  categories = {
101
  'summarization': re.compile('summar|quick rundown|overview'),
@@ -106,4 +106,47 @@ def guess_task_name(challenge: str):
106
  if patt.search(challenge):
107
  return task_name
108
 
109
- return 'qa'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import bittensor as bt
3
+ import time
4
+ import json
5
+ from aiohttp import web
6
  from collections import Counter
 
7
  from prompting.rewards import DateRewardModel, FloatDiffModel
8
 
9
  UNSUCCESSFUL_RESPONSE_PATTERNS = ["I'm sorry", "unable to", "I cannot", "I can't", "I am unable", "I am sorry", "I can not", "don't know", "not sure", "don't understand", "not capable"]
 
96
  }
97
 
98
  def guess_task_name(challenge: str):
 
99
  # TODO: use a pre-trained classifier to guess the task name
100
  categories = {
101
  'summarization': re.compile('summar|quick rundown|overview'),
 
106
  if patt.search(challenge):
107
  return task_name
108
 
109
+ return 'qa'
110
+
111
+
112
async def echo_stream(request_data: dict, request=None):
    """
    Stream the concatenated request messages back to the caller word by word,
    then append a JSON summary chunk mimicking the validator response format.

    Args:
        request_data: parsed JSON body. Uses 'messages' (required, list of
            strings), 'k' (echo repetitions, default 1), 'exclude' (read but
            unused here), 'timeout' (delay between words in seconds, default 0.2).
        request: the originating web.Request. aiohttp's
            StreamResponse.prepare() requires it; calling without one (the old
            behaviour) raises at runtime, so callers should pass it.

    Returns:
        The finalized web.StreamResponse.
    """
    import asyncio  # local import: needed for the non-blocking sleep below

    k = request_data.get('k', 1)
    exclude = request_data.get('exclude', [])
    timeout = request_data.get('timeout', 0.2)
    message = '\n\n'.join(request_data['messages'])

    # Create a StreamResponse
    response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
    # BUG FIX: prepare() requires the originating request; previously it was
    # called with no arguments, which raises TypeError.
    await response.prepare(request)

    completion = ''
    # Echo the message k times with a timeout between each chunk
    for _ in range(k):
        for word in message.split():
            chunk = f'{word} '
            await response.write(chunk.encode('utf-8'))
            completion += chunk
            # BUG FIX: time.sleep() blocked the whole event loop between
            # chunks; asyncio.sleep lets other requests keep being served.
            await asyncio.sleep(timeout)
            bt.logging.info(f"Echoed: {chunk}")

    completion = completion.strip()

    # Prepare final JSON chunk
    json_chunk = json.dumps({
        "uids": [0],
        "completion": completion,
        "completions": [completion],  # already stripped above
        "timings": [0],
        "status_messages": ['Went well!'],
        "status_codes": [200],
        "completion_is_valid": [True],
        "task_name": 'echo',
        "ensemble_result": {}
    })

    # Send the final JSON as part of the stream
    await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))

    # Finalize the response
    await response.write_eof()
    return response
validator_wrapper.py DELETED
@@ -1,70 +0,0 @@
1
- import bittensor as bt
2
- from prompting.validator import Validator
3
- from prompting.utils.uids import get_random_uids
4
- from prompting.protocol import PromptingSynapse
5
- from prompting.dendrite import DendriteResponseEvent
6
- from abc import ABC, abstractmethod
7
- from typing import List
8
- from dataclasses import dataclass
9
-
10
@dataclass
class QueryValidatorParams:
    """Parameters describing a single query against the validator network.

    Attributes:
        k_miners: number of miner UIDs to sample for the query.
        exclude: miner UIDs to skip when sampling. Fixed type hint: UIDs are
            integers (see get_random_uids / the api.py exclude usage), not
            strings.
        roles: chat roles, parallel to `messages`.
        messages: chat messages to forward to the miners.
        timeout: per-call dendrite timeout in seconds.
    """
    k_miners: int
    exclude: List[int]
    roles: List[str]
    messages: List[str]
    timeout: int

    @staticmethod
    def from_dict(data: dict):
        """Build params from a request payload.

        'roles' and 'messages' are required (KeyError if absent); all other
        fields fall back to defaults: k=10 miners, no exclusions, 10s timeout.
        """
        return QueryValidatorParams(
            k_miners=data.get('k', 10),
            exclude=data.get('exclude', []),
            roles=data['roles'],
            messages=data['messages'],
            timeout=data.get('timeout', 10)
        )
27
-
28
class ValidatorWrapper(ABC):
    """Common interface for objects that forward queries to a validator."""

    @abstractmethod
    async def query_validator(self, params: QueryValidatorParams):
        """Run `params` against the validator network and return the result."""
        ...
32
-
33
-
34
class S1ValidatorWrapper(ValidatorWrapper):
    """Wraps a live prompting Validator and exposes query_validator()."""

    def __init__(self):
        # Owns a full validator instance (metagraph, dendrite, device, ...).
        self.validator = Validator()

    async def query_validator(self, params: QueryValidatorParams) -> DendriteResponseEvent:
        """Query a random sample of miners and wrap their replies in an event."""
        # Sample miner UIDs for this step, skipping any excluded ones.
        sampled_uids = get_random_uids(
            self.validator,
            k=params.k_miners,
            exclude=params.exclude,
        ).to(self.validator.device)

        target_axons = [self.validator.metagraph.axons[uid] for uid in sampled_uids]

        # Fan the prompt out to the sampled miners over the network.
        bt.logging.info(f'Calling dendrite')
        replies = await self.validator.dendrite(
            axons=target_axons,
            synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
            timeout=params.timeout,
        )

        # Bundle the raw replies into a response event (dataclass).
        bt.logging.info(f"Creating DendriteResponseEvent:\n {replies}")
        return DendriteResponseEvent(replies, sampled_uids, params.timeout)
59
-
60
-
61
class MockValidator(ValidatorWrapper):
    """Stub wrapper for local testing; performs no network calls."""

    async def query_validator(self, params: QueryValidatorParams) -> DendriteResponseEvent:
        # Intentionally a no-op: yields None instead of a real response event.
        return None
64
-
65
-
66
-
67
-
68
-
69
-
70
-