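"""Validator forward-pass helpers for the prompting subnet.

`run_step` queries a random set of miner uids with a StreamPromptingSynapse, gathers and
processes the streamed responses, scores them with the reward pipeline, updates the miner
scores, and returns the step event for logging.
"""
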
import time
import sys
import asyncio
import numpy as np
import bittensor as bt
import traceback
from typing import List, Dict, Awaitable
from prompting.agent import HumanAgent
from prompting.dendrite import DendriteResponseEvent
from prompting.conversation import create_task
from prompting.protocol import StreamPromptingSynapse
from prompting.rewards import RewardResult
from prompting.utils.uids import get_random_uids
from prompting.utils.logging import log_event
from prompting.utils.misc import async_log, serialize_exception_to_string
from dataclasses import dataclass


@async_log
async def execute_dendrite_call(dendrite_call):
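    """Await the dendrite call and return the raw streaming responses from the miners."""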
    responses = await dendrite_call
    return responses


@dataclass
class StreamResult:
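    """Result of streaming from a single miner: the final synapse, any exception raised, and the miner uid."""
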
    synapse: StreamPromptingSynapse = None
    exception: BaseException = None
    uid: int = None


async def process_response(uid: int, async_generator: Awaitable):
    """Process a single response asynchronously."""
    try:
        chunk = None  # Initialize chunk with a default value
        async for chunk in async_generator:  # most important loop, as this is where we acquire the final synapse.
            bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

        if chunk is not None:
            synapse = chunk  # The last object yielded by the stream is the synapse itself, with the completion filled in.

            if isinstance(synapse, StreamPromptingSynapse):
                return synapse

        # The stream ended without yielding a StreamPromptingSynapse; return a synapse with an
        # empty completion so handle_response still receives a result for this uid.
        bt.logging.debug(
            f"Last chunk is not a StreamPromptingSynapse. Miner uid {uid} completion set to ''"
        )
        return StreamPromptingSynapse(roles=["user"], messages=["failure"], completion="")
    except Exception as e:
        traceback_details = traceback.format_exc()
        bt.logging.error(
            f"Error while processing stream response for uid {uid}: {e}\n{traceback_details}"
        )

        failed_synapse = StreamPromptingSynapse(
            roles=["user"], messages=["failure"], completion=""
        )

        return failed_synapse


@async_log
async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
    """The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
    and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.

    Args:
        responses (Dict[int, Awaitable]): Mapping of miner uid to the awaitable yielding that miner's streamed chunks.

    Raises:
        ValueError: If a gathered result is neither a StreamPromptingSynapse nor an exception.

    Returns:
        List[StreamResult]: One StreamResult per uid, holding the final synapse, the exception (if any), and the uid.
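
    Example (mirroring the call in run_step below):
        stream_results = await handle_response(
            responses=dict(zip(uids_cpu, streams_responses))
        )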
    """
    tasks_with_uid = list(responses.items())  # Pair each uid with its awaitable stream

    # Start tasks, preserving order and their associated UIDs
    tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]

    results = await asyncio.gather(*tasks, return_exceptions=True)

    mapped_results = []
    # Pair each result with its original uid
    for (uid, _), result in zip(tasks_with_uid, results):
        # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
        if isinstance(result, StreamPromptingSynapse):
            mapped_results.append(StreamResult(synapse=result, uid=uid))

        # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
        elif isinstance(result, BaseException):
            failed_synapse = StreamPromptingSynapse(
                roles=["user"], messages=["failure"], completion=""
            )
            mapped_results.append(
                StreamResult(synapse=failed_synapse, exception=result, uid=uid)
            )

        # If the result is neither a StreamPromptingSynapse nor an exception, log it and raise a ValueError
        else:
            bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
            raise ValueError(f"Unexpected result type for UID {uid}: {result}")

    return mapped_results


@async_log
async def generate_reference(agent: HumanAgent):
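    """Generate the task reference in a thread executor so the blocking LLM call does not stall the event loop."""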
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None, agent.task.generate_reference, agent.llm_pipeline
    )
    return result


def log_stream_results(stream_results: List[StreamResult]):
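    """Log counts of non-empty, empty, and failed stream results, plus the exception for each failure."""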
    failed_responses = [
        response for response in stream_results if response.exception is not None
    ]
    empty_responses = [
        response
        for response in stream_results
        if response.exception is None and response.synapse.completion == ""
    ]
    non_empty_responses = [
        response
        for response in stream_results
        if response.exception is None and response.synapse.completion != ""
    ]

    bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
    bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
    bt.logging.info(
        f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
    )

    for failed_response in failed_responses:
        formatted_exception = serialize_exception_to_string(failed_response.exception)
        bt.logging.error(
            f"Failed response for uid {failed_response.uid}: {formatted_exception}"
        )


async def run_step(
    self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
):
    """Executes a single step of the agent, which consists of:
    - Getting a list of uids to query
    - Querying the network
    - Rewarding the network
    - Updating the scores
    - Logging the event

    Args:
        agent (HumanAgent): The agent to run the step for.
        k (int): The number of uids to query.
        timeout (float): The timeout for the queries.
        exclude (list, optional): The list of uids to exclude from the query. Defaults to None (no exclusions).
    """

    bt.logging.debug("run_step", agent.task.name)

    # Record event start time.
    start_time = time.time()
    # Get the list of uids to query for this step.
    uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
    uids_cpu = uids.cpu().tolist()

    axons = [self.metagraph.axons[uid] for uid in uids]

    # Directly call dendrite and process responses in parallel
    streams_responses = await self.dendrite(
        axons=axons,
        synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
        timeout=timeout,
        deserialize=False,
        streaming=True,
    )

    # Prepare the task for handling stream responses
    handle_stream_responses_task = asyncio.create_task(
        handle_response(responses=dict(zip(uids_cpu, streams_responses)))
    )

    if not agent.task.static_reference:
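        # Generate the reference concurrently with handling the streamed miner responses.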
        reference_generation_task = generate_reference(agent)
        _, stream_results = await asyncio.gather(
            reference_generation_task, handle_stream_responses_task
        )
    else:
        stream_results = await handle_stream_responses_task

    log_stream_results(stream_results)

    all_synapses_results = [stream_result.synapse for stream_result in stream_results]

    # Encapsulate the responses in a response event (dataclass)
    response_event = DendriteResponseEvent(
        responses=all_synapses_results, uids=uids, timeout=timeout
    )

    bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
    # Reward the responses and get the reward result (dataclass)
    # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
    reward_result = RewardResult(
        self.reward_pipeline,
        agent=agent,
        response_event=response_event,
        device=self.device,
    )
    bt.logging.info(f"Created RewardResult:\n {reward_result}")

    # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. a reward criterion is met, such as ROUGE > threshold)
    agent.update_progress(
        top_reward=reward_result.rewards.max(),
        top_response=response_event.completions[reward_result.rewards.argmax()],
    )

    self.update_scores(reward_result.rewards, uids)

    stream_results_uids = [stream_result.uid for stream_result in stream_results]
    stream_results_exceptions = [
        serialize_exception_to_string(stream_result.exception)
        for stream_result in stream_results
    ]
    # Log the step event.
    event = {
        "block": self.block,
        "step_time": time.time() - start_time,
        "stream_results_uids": stream_results_uids,
        "stream_results_exceptions": stream_results_exceptions,
        **agent.__state_dict__(full=self.config.neuron.log_full),
        **reward_result.__state_dict__(full=self.config.neuron.log_full),
        **response_event.__state_dict__(),
    }

    return event