File size: 5,434 Bytes
48f02b6
7ea7f29
 
32e1e2e
48f02b6
 
2a5b08d
48f02b6
fdc8fdb
 
 
 
 
 
 
 
 
 
 
 
 
48f02b6
 
fdc8fdb
 
48f02b6
 
fdc8fdb
48f02b6
 
 
 
 
 
 
fdc8fdb
 
 
 
48f02b6
 
 
 
fdc8fdb
48f02b6
 
 
 
 
 
 
 
 
 
fdc8fdb
48f02b6
 
 
fdc8fdb
48f02b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc8fdb
 
 
48f02b6
fdc8fdb
48f02b6
 
 
 
 
 
 
 
 
 
 
 
 
fdc8fdb
 
 
48f02b6
 
fdc8fdb
48f02b6
fdc8fdb
48f02b6
fdc8fdb
 
 
 
48f02b6
 
 
 
fdc8fdb
 
 
 
 
48f02b6
 
fdc8fdb
48f02b6
 
 
fdc8fdb
 
 
 
 
 
 
48f02b6
 
 
 
 
fdc8fdb
32e1e2e
 
2a5b08d
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea7f29
 
1b0bdb5
fdc8fdb
32e1e2e
2a5b08d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import re
import asyncio
import bittensor as bt
from aiohttp import web
from collections import Counter
from prompting.rewards import DateRewardModel, FloatDiffModel
from validators.streamer import AsyncResponseDataStreamer

# Phrases that mark a completion as a refusal or an expression of uncertainty.
# completion_is_valid() rejects any completion containing one of these as a
# whole-word, case-insensitive match.
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]

# Reward models keyed by task name, instantiated once at import time.
# ensemble_result() uses their parsing helpers (parse_dates_from_text for
# "date_qa", extract_number for "math") to pull comparable answers out of
# free-text completions.
reward_models = {
    "date_qa": DateRewardModel(),
    "math": FloatDiffModel(),
}


def completion_is_valid(completion: str, patterns=None):
    """
    Return whether a completion looks like a genuine, successful answer.

    A completion is rejected (``False``) when it is empty or whitespace-only,
    when it contains no word characters at all, or when it contains any of the
    refusal/uncertainty phrases as a whole-word, case-insensitive match.

    Args:
        completion: The model output to check.
        patterns: Optional iterable of literal phrases that signal an
            unsuccessful response. Defaults to ``UNSUCCESSFUL_RESPONSE_PATTERNS``.
            Phrases are matched literally (regex metacharacters are escaped);
            empty phrases are ignored.

    Returns:
        ``True`` if the completion passes every check, ``False`` otherwise.
    """
    if not completion.strip():
        return False

    # A completion made only of punctuation/symbols carries no answer.
    if not re.search(r"\w", completion):
        return False

    if patterns is None:
        patterns = UNSUCCESSFUL_RESPONSE_PATTERNS
    # Escape so phrases are matched literally, and drop empty phrases: an empty
    # alternative "(?:)" would match at every word boundary and reject everything.
    phrases = [re.escape(p) for p in patterns if p]
    if phrases:
        refusal = re.compile(r"\b(?:" + "|".join(phrases) + r")\b", re.IGNORECASE)
        if refusal.search(completion):
            return False
    return True


def _majority_support(indices: list, keys: list):
    """
    Return ``(most_common_key, count, supporting_indices)`` for parallel lists.

    ``keys[j]`` is the normalized answer extracted from the completion at
    position ``indices[j]`` of the original completion list; both lists must be
    non-empty and of equal length. Ties go to the key seen first. When no key
    occurs more than once there is no majority, so every index counts as support.
    """
    most_common, count = Counter(keys).most_common(1)[0]
    if count == 1:
        return most_common, count, list(indices)
    supporting = [i for i, key in zip(indices, keys) if key == most_common]
    return most_common, count, supporting


def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
    """
    Ensemble completions from multiple models into one preferred completion.

    For "qa" and "summarization" every completion supports the result. For
    "date_qa" and "math" an answer (a date or a number) is extracted from each
    completion with the matching reward model's parser. Completions without a
    parsable answer are discarded. If some answer occurs more than once, only
    the completions agreeing with the most common answer are kept as support.

    Args:
        completions: Candidate completion strings.
        task_name: One of "qa", "summarization", "date_qa", "math".
        prefer: How to pick among the supporting completions: "longest",
            "shortest" or "most_common".

    Returns:
        ``None`` if ``completions`` is empty or (for "date_qa"/"math") no
        completion contains a parsable answer. Otherwise a dict with the chosen
        "completion", the "accepted_answer" (``None`` for qa/summarization),
        the "support" count, the "support_indices" into ``completions`` and a
        human-readable "method".

    Raises:
        ValueError: for an unknown ``task_name``, or an unknown ``prefer`` when
            there is a result to choose from.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    if task_name in ("qa", "summarization"):
        # No special handling for QA or summarization
        support_indices = list(range(len(completions)))

    elif task_name == "date_qa":
        # Keep completions that contain a valid date; if several dates appear,
        # keep those agreeing with the most common one (when its support > 1).
        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {dates}")
        valid = [(i, d) for i, d in enumerate(dates) if d]
        if not valid:
            return None
        # Presumably d is (datetime, year) — TODO confirm against DateRewardModel.
        # Build the day without the "%-d" strftime directive, which is
        # glibc-specific and fails on Windows.
        keys = [f"{d[0].day} {d[0].strftime('%B')} {d[1]}" for _, d in valid]
        answer, _, support_indices = _majority_support([i for i, _ in valid], keys)

    elif task_name == "math":
        # Keep completions that contain a number; if several values appear,
        # keep those agreeing with the most common one (when its support > 1).
        # TODO: use the median instead of the most common value
        vals = list(map(reward_models[task_name].extract_number, completions))
        # Compare against None rather than truthiness so a genuine answer of 0
        # is not discarded (assumes extract_number returns None on failure).
        # Indices are kept alongside values so support maps back to `completions`.
        valid = [(i, v) for i, v in enumerate(vals) if v is not None]
        if not valid:
            return None
        most_common, count, support_indices = _majority_support(
            [i for i, _ in valid], [v for _, v in valid]
        )
        bt.logging.info(f"Most common value: {most_common}, count: {count}")
        answer = most_common

    else:
        raise ValueError(f"Unknown task name: {task_name}")

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")
    if prefer == "longest":
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == "shortest":
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == "most_common":
        # Counter breaks ties by first occurrence, so the choice is deterministic
        # (unlike max() over a set, whose order depends on string hashing).
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        "completion": preferred_completion,
        "accepted_answer": answer,
        "support": len(supporting_completions),
        # Track real positions: completions.index() would repeat the first
        # index for duplicate completions.
        "support_indices": support_indices,
        "method": f'Selected the {prefer.replace("_", " ")} completion',
    }


# Compiled once at import time. Categories are tried in insertion order, so an
# earlier category wins when a challenge matches several. Matching is
# case-insensitive so capitalized prompts ("Summarize…", "Solve…") are caught.
_TASK_NAME_PATTERNS = {
    "summarization": re.compile(r"summar|quick rundown|overview", re.IGNORECASE),
    "date_qa": re.compile(
        # Word boundaries stop "died?" from matching "diet"/"studied" and
        # "was born" from matching "was boring"; "died?" still accepts "die".
        r"exact date|tell me when|on what date|on what day|\bwas born\b|\bdied?\b",
        re.IGNORECASE,
    ),
    "math": re.compile(
        r"math|solve|solution| sum |problem|geometric|vector|calculate"
        r"|degrees|decimal|factorial",
        re.IGNORECASE,
    ),
}


def guess_task_name(challenge: str):
    """
    Guess the task category of a challenge prompt with keyword heuristics.

    Returns the first of "summarization", "date_qa" or "math" whose keyword
    pattern matches ``challenge``, or "qa" when none match.
    """
    # TODO: use a pre-trained classifier to guess the task name
    for task_name, patt in _TASK_NAME_PATTERNS.items():
        if patt.search(challenge):
            return task_name

    return "qa"


# Simulate the stream synapse for the echo endpoint
class EchoAsyncIterator:
    """
    Async-iterable stand-in for a streaming synapse.

    Iterating it yields every whitespace-separated word of ``message`` wrapped
    in a one-element list, repeating the whole message ``k`` times and awaiting
    ``asyncio.sleep(delay)`` after each yielded chunk.
    """

    def __init__(self, message: str, k: int, delay: float):
        self.message = message
        self.k = k
        self.delay = delay

    async def __aiter__(self):
        repetitions = range(self.k)
        for _repeat in repetitions:
            # Re-split on every pass, reading the current message each time.
            tokens = self.message.split()
            for token in tokens:
                chunk = [token]
                yield chunk
                await asyncio.sleep(self.delay)


async def echo_stream(request: web.Request) -> web.StreamResponse:
    """
    Stream the request's messages back word by word (echo endpoint).

    Reads the payload stored under ``request["data"]`` (presumably populated by
    upstream middleware — TODO confirm). It joins ``messages`` with blank lines
    and echoes the text ``k`` times (default 1) through
    ``AsyncResponseDataStreamer``, using a 0.3 s delay between chunks.

    Raises:
        web.HTTPBadRequest: if ``k`` is not an integer or ``messages`` is not a
            list of strings.
    """
    request_data = request["data"]
    k = request_data.get("k", 1)
    messages = request_data.get("messages")

    # Reject malformed client input with a 400 instead of letting range() or
    # str.join() raise inside the handler and surface as a 500.
    if not isinstance(k, int):
        raise web.HTTPBadRequest(text="'k' must be an integer")
    if not isinstance(messages, (list, tuple)) or not all(
        isinstance(m, str) for m in messages
    ):
        raise web.HTTPBadRequest(text="'messages' must be a list of strings")

    message = "\n\n".join(messages)

    echo_iterator = EchoAsyncIterator(message, k, delay=0.3)
    streamer = AsyncResponseDataStreamer(echo_iterator, selected_uid=0, delay=0.3)

    return await streamer.stream(request)