File size: 6,023 Bytes
48f02b6
32e1e2e
 
7ea7f29
 
32e1e2e
7ea7f29
48f02b6
 
 
fdc8fdb
 
 
 
 
 
 
 
 
 
 
 
 
48f02b6
 
fdc8fdb
 
48f02b6
 
fdc8fdb
48f02b6
 
 
 
 
 
 
fdc8fdb
 
 
 
48f02b6
 
 
 
fdc8fdb
48f02b6
 
 
 
 
 
 
 
 
 
fdc8fdb
48f02b6
 
 
fdc8fdb
48f02b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc8fdb
 
 
48f02b6
fdc8fdb
48f02b6
 
 
 
 
 
 
 
 
 
 
 
 
fdc8fdb
 
 
48f02b6
 
fdc8fdb
48f02b6
fdc8fdb
48f02b6
fdc8fdb
 
 
 
48f02b6
 
 
 
fdc8fdb
 
 
 
 
48f02b6
 
fdc8fdb
48f02b6
 
 
fdc8fdb
 
 
 
 
 
 
48f02b6
 
 
 
 
fdc8fdb
32e1e2e
 
7ea7f29
 
1b0bdb5
fdc8fdb
32e1e2e
 
fdc8fdb
7ea7f29
fdc8fdb
7ea7f29
32e1e2e
fdc8fdb
7ea7f29
 
 
32e1e2e
 
 
1b0bdb5
fdc8fdb
32e1e2e
1b0bdb5
32e1e2e
1b0bdb5
7ea7f29
 
1b0bdb5
32e1e2e
 
1b0bdb5
7ea7f29
1b0bdb5
7ea7f29
 
1b0bdb5
 
32e1e2e
 
7ea7f29
32e1e2e
 
 
fdc8fdb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import re
import time
import json
import asyncio
import bittensor as bt
from aiohttp import web
from responses import TextStreamResponse
from collections import Counter
from prompting.rewards import DateRewardModel, FloatDiffModel

# Phrases that mark a completion as a refusal or non-answer.
# `completion_is_valid` joins them into one case-insensitive regex wrapped in
# word boundaries and rejects any completion that matches.
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]

# Reward model instances keyed by task name. `ensemble_result` looks them up by
# task name to parse dates ("date_qa") and extract numbers ("math") from
# completions.
reward_models = {
    "date_qa": DateRewardModel(),
    "math": FloatDiffModel(),
}


def completion_is_valid(completion: str) -> bool:
    """
    Check whether a completion looks like a genuine answer.

    A completion is rejected when it contains no word characters at all
    (empty, whitespace-only or punctuation-only text), or when it contains
    any phrase from UNSUCCESSFUL_RESPONSE_PATTERNS. Phrases are matched
    literally, case-insensitively, and only on word boundaries.

    Args:
        completion: The completion text to check.

    Returns:
        True if the completion passes both checks, False otherwise.
    """
    # One search for a word character covers the empty, whitespace-only and
    # punctuation-only cases in a single test.
    if not re.search(r"\w", completion):
        return False

    # Escape every phrase so it is always matched literally, even if a phrase
    # containing a regex metacharacter is added to the list later.
    refusal_patt = re.compile(
        r"\b(?:"
        + "|".join(map(re.escape, UNSUCCESSFUL_RESPONSE_PATTERNS))
        + r")\b",
        re.IGNORECASE,
    )
    return refusal_patt.search(completion) is None


def _majority_support(indices: list, keys: list):
    """
    Find the most common key and the completions that back it.

    Args:
        indices: Positions, in the original completions list, of the
            completions that produced a usable answer.
        keys: The normalised answer for each entry of `indices`, in the
            same order.

    Returns:
        A tuple `(most_common_key, count, support_indices)`. When no key
        occurs more than once there is no agreement to filter on, so every
        index in `indices` is returned as support.
    """
    most_common, count = Counter(keys).most_common(1)[0]
    if count == 1:
        return most_common, count, list(indices)
    supporters = [i for i, key in zip(indices, keys) if key == most_common]
    return most_common, count, supporters


def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
    """
    Ensemble completions from multiple models.
    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline

    For "qa" and "summarization" every completion counts as support. For
    "date_qa" and "math" an answer is extracted from each completion with the
    matching reward model, completions without a usable answer are dropped,
    and the most common answer is accepted; if it occurs more than once only
    the completions that agree with it count as support.

    Args:
        completions: Completion strings to combine.
        task_name: One of "qa", "summarization", "date_qa" or "math".
        prefer: How to choose the returned completion among the supporting
            ones: "longest", "shortest" or "most_common".

    Returns:
        None if `completions` is empty or no completion yields a usable
        answer for the task. Otherwise a dict with the keys "completion",
        "accepted_answer" (None for "qa"/"summarization"), "support",
        "support_indices" (positions in `completions`) and "method".

    Raises:
        ValueError: If `task_name` or `prefer` is not recognised.
    """
    if not completions:
        return None

    answer = None
    if task_name in ("qa", "summarization"):
        # No special handling for QA or summarization
        support_indices = list(range(len(completions)))

    elif task_name == "date_qa":
        # Keep only completions that contain a valid date, then vote on the date.
        parsed_dates = [
            reward_models[task_name].parse_dates_from_text(c) for c in completions
        ]
        bt.logging.info(f"Unprocessed dates: {parsed_dates}")
        valid_indices = [i for i, d in enumerate(parsed_dates) if d]
        if not valid_indices:
            return None

        # NOTE(review): '%-d' is a platform-specific strftime extension and is
        # not supported on every platform (e.g. Windows) - confirm deployment target.
        dates = [
            f"{d[0].strftime('%-d %B')} {d[1]}"
            for d in (parsed_dates[i] for i in valid_indices)
        ]
        answer, _, support_indices = _majority_support(valid_indices, dates)

    elif task_name == "math":
        # Keep only completions that contain a valid number, then vote on the value.
        # TODO: use the median instead of the most common value
        extracted = [reward_models[task_name].extract_number(c) for c in completions]
        # NOTE(review): the truthiness filter also drops a legitimate answer of 0;
        # confirm what extract_number returns on failure before tightening it.
        valid_indices = [i for i, val in enumerate(extracted) if val]
        if not valid_indices:
            return None

        vals = [extracted[i] for i in valid_indices]
        answer, count, support_indices = _majority_support(valid_indices, vals)
        bt.logging.info(f"Most common value: {answer}, count: {count}")

    else:
        raise ValueError(f"Unknown task name: {task_name}")

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")
    if prefer == "longest":
        # Iterate in reverse so that ties resolve to the last candidate, exactly
        # as a stable sort followed by [-1] would.
        preferred_completion = max(reversed(supporting_completions), key=len)
    elif prefer == "shortest":
        preferred_completion = min(supporting_completions, key=len)
    elif prefer == "most_common":
        # Counter breaks ties by first occurrence, which keeps the result deterministic.
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        "completion": preferred_completion,
        "accepted_answer": answer,
        "support": len(supporting_completions),
        # Tracked positions stay correct even when completions are duplicated.
        "support_indices": support_indices,
        "method": f'Selected the {prefer.replace("_", " ")} completion',
    }


# Task name -> keyword pattern, compiled once at import time. Patterns are
# tried in insertion order, so the first matching task wins.
# NOTE: in "was born?|died?" the trailing `?` is a regex quantifier, so those
# alternatives match "was bor"/"was born" and "die"/"died" as substrings.
_TASK_PATTERNS = {
    "summarization": re.compile("summar|quick rundown|overview", re.IGNORECASE),
    "date_qa": re.compile(
        "exact date|tell me when|on what date|on what day|was born?|died?",
        re.IGNORECASE,
    ),
    "math": re.compile(
        "math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial",
        re.IGNORECASE,
    ),
}


def guess_task_name(challenge: str) -> str:
    """
    Guess which task a challenge belongs to from keywords in its text.

    Matching is case-insensitive so that keywords at the start of a sentence
    ("Summarize ...", "Calculate ...") are recognised too.

    Args:
        challenge: The challenge / prompt text.

    Returns:
        "summarization", "date_qa" or "math" for the first pattern that
        matches, otherwise "qa".
    """
    # TODO: use a pre-trained classifier to guess the task name
    for task_name, patt in _TASK_PATTERNS.items():
        if patt.search(challenge):
            return task_name

    return "qa"


async def echo_stream(request: web.Request) -> web.StreamResponse:
    """
    Stream the request's messages back to the client one word at a time.

    Reads `request["data"]`, joins its "messages" with blank lines and writes
    the whitespace-separated words back `k` times (default 1), pausing 0.3
    seconds after each word. Once all words are sent, a JSON-encoded
    TextStreamResponse summary is written as the last chunk of the stream.

    Args:
        request: Incoming aiohttp request; `request["data"]` must already hold
            the parsed payload (presumably set by upstream middleware - confirm).

    Returns:
        The finished StreamResponse.
    """
    payload = request["data"]
    repeats = payload.get("k", 1)
    text = "\n\n".join(payload["messages"])

    # Open the stream before any chunk is produced.
    stream = web.StreamResponse(
        status=200, reason="OK", headers={"Content-Type": "application/json"}
    )
    await stream.prepare(request)

    sent_chunks = []
    sent_timings = []
    started = time.time()
    # Each word is written, then followed by a pause, a log line and its
    # bookkeeping entry (elapsed time since the stream started).
    for _ in range(repeats):
        for token in text.split():
            piece = f"{token} "
            await stream.write(piece.encode("utf-8"))
            await asyncio.sleep(0.3)
            bt.logging.info(f"Echoed: {piece}")

            sent_chunks.append(piece)
            sent_timings.append(time.time() - started)

    # The full completion is every chunk concatenated, minus the trailing space.
    summary = TextStreamResponse(
        streamed_chunks=sent_chunks,
        streamed_chunks_timings=sent_timings,
        completion="".join(sent_chunks).strip(),
        timing=time.time() - started,
    )
    final_payload = json.dumps(summary.to_dict())

    # The JSON summary travels in the same stream as the echoed words.
    await stream.write(final_payload.encode("utf-8"))

    await stream.write_eof()
    return stream