File size: 2,087 Bytes
d7cdabb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import bittensor as bt
from neurons.validator import Validator
from prompting.utils.uids import get_random_uids
from prompting.protocol import PromptingSynapse
from prompting.dendrite import DendriteResponseEvent
from abc import ABC, abstractmethod
from typing import List
from dataclasses import dataclass

@dataclass
class QueryValidatorParams:
    """Settings for a single validator query.

    Instances are usually built from a request payload via ``from_dict``.
    """

    k_miners: int  # how many miner uids to sample
    exclude: List[str]  # values to leave out when sampling uids
    roles: List[str]  # role for each message in the prompt
    messages: List[str]  # prompt messages to send
    timeout: int  # time budget for the network call

    @staticmethod
    def from_dict(data: dict):
        """Create a ``QueryValidatorParams`` from a plain dictionary.

        Keys read from ``data``:
          - ``k``: becomes ``k_miners`` (default 10)
          - ``exclude``: default is a fresh empty list on each call
          - ``roles``, ``messages``: required; a missing key raises ``KeyError``
          - ``timeout``: default 10
        """
        k_miners = data.get('k', 10)
        exclude = data.get('exclude', [])
        roles, messages = data['roles'], data['messages']
        timeout = data.get('timeout', 10)
        # Positional order mirrors the dataclass field order above.
        return QueryValidatorParams(k_miners, exclude, roles, messages, timeout)

class ValidatorWrapper(ABC):
    """Abstract interface for wrappers that run a query described by
    ``QueryValidatorParams``."""

    @abstractmethod
    async def query_validator(self, params:QueryValidatorParams):
        """Run one query described by ``params``; subclasses must override.

        The base implementation does nothing and resolves to ``None``.
        """
    
    
class S1ValidatorWrapper(ValidatorWrapper):
    """``ValidatorWrapper`` that queries miners through a ``Validator``.

    A single ``Validator`` is constructed in ``__init__`` and reused for
    every query made by this wrapper.
    """

    def __init__(self):
        self.validator = Validator()

    async def query_validator(self, params: QueryValidatorParams) -> DendriteResponseEvent:
        """Query a random sample of miners and wrap their responses.

        Args:
            params: Query settings.
                ``k_miners`` and ``exclude`` drive uid sampling.
                ``roles`` and ``messages`` build the ``PromptingSynapse``.
                ``timeout`` is passed to the dendrite call.

        Returns:
            A ``DendriteResponseEvent`` built from the dendrite responses and
            the sampled uids.
        """
        # Get the list of uids to query for this step, moved onto the
        # validator's device.
        uids = get_random_uids(
            self.validator,
            k=params.k_miners,
            exclude=params.exclude,
        ).to(self.validator.device)

        axons = [self.validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt. QueryValidatorParams
        # exposes ``messages`` directly; it has no ``request_data`` attribute.
        bt.logging.info('Calling dendrite')
        responses = await self.validator.dendrite(
            axons=axons,
            synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
            timeout=params.timeout,
        )

        # Encapsulate the responses together with the uids they came from.
        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        return DendriteResponseEvent(responses, uids)
        
    
class MockValidator(ValidatorWrapper):
    """Stub ``ValidatorWrapper`` whose query always returns ``False``."""

    def query_validator(self, query: str) -> bool:
        """Ignore ``query`` and return ``False`` without contacting anything.

        NOTE(review): Unlike ``ValidatorWrapper.query_validator``, this is a
        plain (non-async) method that takes ``query: str`` rather than
        ``QueryValidatorParams``. Callers that ``await`` the result will get
        a ``TypeError``. Confirm how this mock is used before aligning the
        signature.
        """
        return False