pedroferreira commited on
Commit
b073cce
·
1 Parent(s): a34ad94

reallocates and refactors validator code to internal package

Browse files
validators/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from base import QueryValidatorParams, ValidatorWrapper, MockValidator
2
+ from sn1_validator_wrapper import S1ValidatorWrapper
validators/base.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List
3
+ from dataclasses import dataclass
4
+ from aiohttp.web_response import Response
5
+
6
+ @dataclass
7
+ class QueryValidatorParams:
8
+ k_miners: int
9
+ exclude: List[str]
10
+ roles: List[str]
11
+ messages: List[str]
12
+ timeout: int
13
+ prefer: str
14
+
15
+ @staticmethod
16
+ def from_dict(data: dict):
17
+ return QueryValidatorParams(
18
+ k_miners=data.get('k', 10),
19
+ exclude=data.get('exclude', []),
20
+ roles=data['roles'],
21
+ messages=data['messages'],
22
+ timeout=data.get('timeout', 10),
23
+ prefer=data.get('prefer', 'longest')
24
+ )
25
+
26
+ class ValidatorWrapper(ABC):
27
+ @abstractmethod
28
+ async def query_validator(self, params:QueryValidatorParams) -> Response:
29
+ pass
30
+
31
+
32
+ class MockValidator(ValidatorWrapper):
33
+ async def query_validator(self, params:QueryValidatorParams) -> Response:
34
+ ...
35
+
validators/sn1_validator_wrapper.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import utils
3
+ import traceback
4
+ import bittensor as bt
5
+ import asyncio
6
+ from prompting.forward import handle_response
7
+ from prompting.validator import Validator
8
+ from prompting.utils.uids import get_random_uids
9
+ from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
10
+ from prompting.dendrite import DendriteResponseEvent
11
+ from base import QueryValidatorParams, ValidatorWrapper
12
+ from aiohttp.web_response import Response, StreamResponse
13
+ from deprecated import deprecated
14
+
15
+ class S1ValidatorWrapper(ValidatorWrapper):
16
+ def __init__(self):
17
+ self.validator = Validator()
18
+
19
+
20
+ @deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
21
+ async def get_response(self, params:QueryValidatorParams) -> Response:
22
+ try:
23
+ # Guess the task name of current request
24
+ task_name = utils.guess_task_name(params.messages[-1])
25
+
26
+ # Get the list of uids to query for this step.
27
+ uids = get_random_uids(self.validator, k=params.k_miners, exclude=params.exclude or []).tolist()
28
+ axons = [self.validator.metagraph.axons[uid] for uid in uids]
29
+
30
+ # Make calls to the network with the prompt.
31
+ bt.logging.info(f'Calling dendrite')
32
+ responses = await self.validator.dendrite(
33
+ axons=axons,
34
+ synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
35
+ timeout=params.timeout,
36
+ )
37
+
38
+ bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
39
+ # Encapsulate the responses in a response event (dataclass)
40
+ response_event = DendriteResponseEvent(responses, uids)
41
+
42
+ # convert dict to json
43
+ response = response_event.__state_dict__()
44
+
45
+ response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
46
+ valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
47
+
48
+ response['task_name'] = task_name
49
+ response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=params.prefer)
50
+
51
+ bt.logging.info(f"Response:\n {response}")
52
+ return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
53
+
54
+ except Exception:
55
+ bt.logging.error(f'Encountered in {self.__class__.__name__}:get_response:\n{traceback.format_exc()}')
56
+ return Response(status=500, reason="Internal error")
57
+
58
+
59
+ async def get_stream_response(self, params:QueryValidatorParams) -> StreamResponse:
60
+ response = StreamResponse(status=200, reason="OK")
61
+ response.headers['Content-Type'] = 'application/json'
62
+
63
+ await response.prepare() # Prepare and send the headers
64
+
65
+ try:
66
+ # Guess the task name of current request
67
+ task_name = utils.guess_task_name(params.messages[-1])
68
+
69
+ # Get the list of uids to query for this step.
70
+ uids = get_random_uids(self.validator, k=params.k_miners, exclude=params.exclude or []).tolist()
71
+ axons = [self.validator.metagraph.axons[uid] for uid in uids]
72
+
73
+ # Make calls to the network with the prompt.
74
+ bt.logging.info(f'Calling dendrite')
75
+ streams_responses = await self.validator.dendrite(
76
+ axons=axons,
77
+ synapse=StreamPromptingSynapse(roles=params.roles, messages=params.messages),
78
+ timeout=params.timeout,
79
+ deserialize=False,
80
+ streaming=True,
81
+ )
82
+
83
+ # Asynchronous iteration over streaming responses
84
+ async for stream_result in streams_responses:
85
+ if stream_result is not None:
86
+ # Convert stream result to JSON and write to the response stream
87
+ json_data = json.dumps(stream_result)
88
+ await response.write(json_data.encode('utf-8'))
89
+
90
+ except Exception as e:
91
+ bt.logging.error(f'Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}')
92
+ response.set_status(500, reason="Internal error")
93
+ await response.write(json.dumps({'error': str(e)}).encode('utf-8'))
94
+ finally:
95
+ await response.write_eof() # Ensure to close the response properly
96
+
97
+ return response
98
+
99
+
100
+ async def query_validator(self, params:QueryValidatorParams, stream: bool = True) -> Response:
101
+ if stream:
102
+ return await self.get_stream_response(params)
103
+ else:
104
+ # DEPRECATED
105
+ return await self.get_response(params)
106
+