Spaces:
Sleeping
Sleeping
File size: 5,434 Bytes
48f02b6 7ea7f29 32e1e2e 48f02b6 2a5b08d 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 48f02b6 fdc8fdb 32e1e2e 2a5b08d 7ea7f29 1b0bdb5 fdc8fdb 32e1e2e 2a5b08d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import re
import asyncio
import bittensor as bt
from aiohttp import web
from collections import Counter
from prompting.rewards import DateRewardModel, FloatDiffModel
from validators.streamer import AsyncResponseDataStreamer
# Phrases that mark a refusal or a non-answer. completion_is_valid() joins
# them into one case-insensitive, word-bounded regex and rejects any
# completion that contains one of them.
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    # apologies / refusals
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    # admissions of ignorance or incapacity
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]
# Task name -> reward model. ensemble_result() uses these only for their
# answer parsers (parse_dates_from_text for "date_qa", extract_number for
# "math").
reward_models = dict(
    date_qa=DateRewardModel(),
    math=FloatDiffModel(),
)
def completion_is_valid(completion: str):
    """
    Return True if a single completion looks like a usable answer.

    A completion is rejected when it is empty/``None``, contains no word
    characters (e.g. whitespace- or punctuation-only output), or contains any
    phrase from ``UNSUCCESSFUL_RESPONSE_PATTERNS``, matched case-insensitively
    as a whole phrase (word boundaries on both ends).

    Args:
        completion: One model completion.

    Returns:
        bool: ``True`` if the completion should be kept, ``False`` otherwise.
    """
    # Whitespace-only text has no \w either, so a single search covers both
    # the "blank" and the "no words at all" cases.
    if not completion or re.search(r"\w", completion) is None:
        return False
    # Escape every phrase so an entry containing regex metacharacters
    # (".", "?", "(" ...) is matched literally instead of breaking the pattern.
    # The list is read on every call, so runtime edits to it still take effect;
    # the re module caches the compiled pattern.
    refusal = re.compile(
        r"\b(?:"
        + "|".join(map(re.escape, UNSUCCESSFUL_RESPONSE_PATTERNS))
        + r")\b",
        re.IGNORECASE,
    )
    return refusal.search(completion) is None
def _majority_vote(answers: list):
    """Majority-vote over the answers parsed from each completion.

    Args:
        answers: ``answers[i]`` is the answer parsed from completion ``i``, or
            ``None`` if nothing could be parsed from it. Non-``None`` values
            must be hashable.

    Returns:
        ``None`` if every entry is ``None``. Otherwise
        ``(answer, count, supporting_indices)``: the most common answer (ties
        go to the one seen first), how often it occurs, and the indices of
        the completions that support it. When every answer occurs exactly once
        there is no agreement to exploit, so all parseable completions count
        as supporting.
    """
    valid = [i for i, a in enumerate(answers) if a is not None]
    if not valid:
        return None
    answer, count = Counter(answers[i] for i in valid).most_common(1)[0]
    if count == 1:
        return answer, count, valid
    return answer, count, [i for i in valid if answers[i] == answer]


def _select_completion(candidates: list, prefer: str):
    """Pick one completion out of a non-empty ``candidates`` list.

    Raises:
        ValueError: if ``prefer`` is not "longest", "shortest" or "most_common".
    """
    if prefer == "longest":
        # Stable sort + [-1]: among equally long candidates the last one wins.
        return sorted(candidates, key=len)[-1]
    if prefer == "shortest":
        return sorted(candidates, key=len)[0]
    if prefer == "most_common":
        # Counter keeps first-seen order among ties, so the pick is
        # deterministic (iterating a set of strings is not, because of hash
        # randomization).
        return Counter(candidates).most_common(1)[0][0]
    raise ValueError(f"Unknown ensemble preference: {prefer}")


def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
    """
    Ensemble completions from multiple models.

    For "date_qa" and "math", an answer is parsed from every completion with
    the task's reward model and majority-voted; only completions agreeing with
    the winning answer are kept (all parseable ones if nobody agrees). Any
    other task ("qa", "summarization", ...) has no answer parser, so every
    completion is kept. One kept completion is then chosen by ``prefer``.

    Args:
        completions: Completion strings, one per responding model.
        task_name: Task category, e.g. as returned by ``guess_task_name``.
        prefer: "longest", "shortest" or "most_common".

    Returns:
        ``None`` if ``completions`` is empty or, for date/math tasks, no
        answer could be parsed from any completion. Otherwise a dict with the
        chosen ``completion``, the ``accepted_answer`` (``None`` for tasks
        without a parser), the ``support`` count, the ``support_indices`` of
        the supporting completions within ``completions``, and a ``method``
        description.

    Raises:
        ValueError: if ``prefer`` is not a known preference.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    if task_name == "date_qa":
        parsed = [
            reward_models[task_name].parse_dates_from_text(c) for c in completions
        ]
        bt.logging.info(f"Unprocessed dates: {parsed}")
        # Build one comparable string key per parse result (day, month name,
        # then the parser's second element), None when no date was found.
        # The day comes from .day instead of the glibc-only "%-d" directive,
        # which raises on Windows.
        keys = [
            f"{d[0].day} {d[0].strftime('%B')} {d[1]}" if d else None
            for d in parsed
        ]
        vote = _majority_vote(keys)
        if vote is None:
            return None
        answer, _, support_indices = vote
    elif task_name == "math":
        # TODO: use the median instead of the most common value
        vals = [reward_models[task_name].extract_number(c) for c in completions]
        # NOTE(review): assumes extract_number returns None when it finds no
        # number; a truthiness filter would also discard a legitimate 0.
        vote = _majority_vote(vals)
        if vote is None:
            return None
        answer, count, support_indices = vote
        bt.logging.info(f"Most common value: {answer}, count: {count}")
    else:
        # No answer parser for this task (e.g. "qa", "summarization"), so
        # every completion supports the result.
        support_indices = list(range(len(completions)))

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")
    preferred_completion = _select_completion(supporting_completions, prefer)
    return {
        "completion": preferred_completion,
        "accepted_answer": answer,
        "support": len(support_indices),
        # Positions are tracked directly: completions.index(c) would map
        # duplicate completions onto the same (first) index.
        "support_indices": support_indices,
        "method": f'Selected the {prefer.replace("_", " ")} completion',
    }
# Keyword heuristics per task, checked in insertion order; the first match
# wins. Case-insensitive so "Summarize ..." or "On what date ..." match.
# The death/birth alternatives are word-bounded: the old "died?" / "was born?"
# also matched inside words such as "studied", "soldier" or "was bored".
_TASK_PATTERNS = {
    "summarization": re.compile(r"summar|quick rundown|overview", re.IGNORECASE),
    "date_qa": re.compile(
        r"exact date|tell me when|on what date|on what day|\bwas born\b|\bdied?\b",
        re.IGNORECASE,
    ),
    "math": re.compile(
        r"math|solve|solution| sum |problem|geometric|vector|calculate|degrees"
        r"|decimal|factorial",
        re.IGNORECASE,
    ),
}


def guess_task_name(challenge: str):
    """Guess the task category of a challenge prompt from keywords.

    Args:
        challenge: The prompt text.

    Returns:
        str: "summarization", "date_qa" or "math" for the first category whose
        keywords occur in ``challenge``, otherwise "qa".
    """
    # TODO: use a pre-trained classifier to guess the task name
    for task_name, pattern in _TASK_PATTERNS.items():
        if pattern.search(challenge):
            return task_name
    return "qa"
# Simulate the stream synapse for the echo endpoint
class EchoAsyncIterator:
    """Async iterable that replays a message word by word.

    The message is split on whitespace and replayed ``k`` times. Each word is
    yielded as a one-element list, followed by ``asyncio.sleep(delay)``.
    """

    def __init__(self, message: str, k: int, delay: float):
        self.message, self.k, self.delay = message, k, delay

    async def __aiter__(self):
        repetitions = range(self.k)
        for _ in repetitions:
            tokens = self.message.split()
            for token in tokens:
                yield [token]
                await asyncio.sleep(self.delay)
async def echo_stream(request: web.Request) -> web.StreamResponse:
    """Stream the request's own messages back to the caller.

    Reads ``request["data"]``: ``messages`` (joined with blank lines) and an
    optional ``k`` (number of replays, default 1). The text is fed through
    ``EchoAsyncIterator`` into ``AsyncResponseDataStreamer`` with
    ``selected_uid=0`` and a 0.3 s delay, and the streamer's response is
    returned.
    """
    payload = request["data"]
    repeats = payload.get("k", 1)
    text = "\n\n".join(payload["messages"])
    pause = 0.3
    source = EchoAsyncIterator(text, repeats, delay=pause)
    streamer = AsyncResponseDataStreamer(source, selected_uid=0, delay=pause)
    return await streamer.stream(request)
|