Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import re | |
import asyncio | |
import bittensor as bt | |
from aiohttp import web | |
from collections import Counter | |
from prompting.rewards import DateRewardModel, FloatDiffModel | |
from validators.streamer import AsyncResponseDataStreamer | |
# Literal phrases that mark a completion as a refusal or non-answer.
# completion_is_valid joins them into one case-insensitive, word-bounded regex.
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]
# Task name -> reward model instance. ensemble_result borrows the parsing
# helpers of these models (parse_dates_from_text / extract_number).
reward_models = {
    "date_qa": DateRewardModel(),
    "math": FloatDiffModel(),
}
def completion_is_valid(completion: str, patterns=None) -> bool:
    """
    Decide whether a single completion looks like a usable answer.

    A completion is rejected when it is empty or whitespace-only, contains no
    word characters at all, or contains one of the refusal phrases as a
    case-insensitive, word-bounded match.

    Args:
        completion: The raw completion text to check.
        patterns: Optional iterable of literal refusal phrases. When omitted,
            UNSUCCESSFUL_RESPONSE_PATTERNS is used.

    Returns:
        True if the completion passes every check, False otherwise.
    """
    # `not completion` also covers None, which would otherwise crash on .strip().
    if not completion or not completion.strip():
        return False

    # One \w hit is enough; there is no need to collect every word.
    if re.search(r"\w", completion) is None:
        return False

    if patterns is None:
        patterns = UNSUCCESSFUL_RESPONSE_PATTERNS

    # Escape each phrase so it is always matched literally, even if a future
    # phrase contains regex metacharacters.
    phrases = [re.escape(phrase) for phrase in patterns if phrase]
    if not phrases:
        # An empty alternation would match every word boundary and reject
        # every completion, so skip the refusal check instead.
        return True

    refusal = re.compile(r"\b(?:" + "|".join(phrases) + r")\b", re.IGNORECASE)
    return refusal.search(completion) is None
def _majority_support(candidates):
    """Pick the most frequent value among non-empty (index, value) pairs.

    Returns (value, count, indices). When the winning value occurs only once
    there is no real majority, so every candidate index is returned as
    support; otherwise only the indices carrying the winning value are.
    """
    winner, count = Counter(value for _, value in candidates).most_common(1)[0]
    if count == 1:
        indices = [index for index, _ in candidates]
    else:
        indices = [index for index, value in candidates if value == winner]
    return winner, count, indices


def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
    """
    Ensemble completions from multiple models.

    For "qa" and "summarization" every completion counts as support. For
    "date_qa" and "math" each completion is parsed with the matching entry of
    `reward_models`; completions that yield nothing are dropped and the most
    common parsed value becomes the accepted answer. If no value occurs more
    than once, all parseable completions count as support.

    Args:
        completions: Completion strings, one per responding model.
        task_name: One of "qa", "summarization", "date_qa", "math".
        prefer: How to pick the returned completion among the supporting
            ones: "longest", "shortest" or "most_common".

    Returns:
        None when `completions` is empty or, for "date_qa"/"math", when no
        completion could be parsed. Otherwise a dict with the keys
        "completion", "accepted_answer", "support", "support_indices"
        (positions in `completions`) and "method".

    Raises:
        ValueError: If `task_name` or `prefer` is not recognised.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    # Support is tracked as positions in `completions`, so duplicate
    # completion strings still map back to the right responders.
    if task_name in ("qa", "summarization"):
        # No special handling for QA or summarization
        support_indices = list(range(len(completions)))
    elif task_name == "date_qa":
        # Keep only completions that contain a valid date; if several dates
        # appear, select the most common one (with support > 1).
        parsed = [
            reward_models[task_name].parse_dates_from_text(text)
            for text in completions
        ]
        bt.logging.info(f"Unprocessed dates: {parsed}")
        # `.day` replaces strftime('%-d'), which is not available on every
        # platform. Assumes d[0] is datetime-like and d[1] is the year as
        # returned by parse_dates_from_text -- TODO confirm.
        candidates = [
            (index, f"{d[0].day} {d[0].strftime('%B')} {d[1]}")
            for index, d in enumerate(parsed)
            if d
        ]
        if not candidates:
            return None
        answer, _, support_indices = _majority_support(candidates)
    elif task_name == "math":
        # Keep only completions that contain a number; if several values
        # appear, select the most common one (with support > 1).
        # TODO: use the median instead of the most common value
        extracted = [
            reward_models[task_name].extract_number(text) for text in completions
        ]
        # Compare against None so that a legitimate answer of 0 is kept.
        # Assumes extract_number returns None when nothing is found -- TODO confirm.
        candidates = [
            (index, value)
            for index, value in enumerate(extracted)
            if value is not None
        ]
        if not candidates:
            return None
        answer, count, support_indices = _majority_support(candidates)
        bt.logging.info(f"Most common value: {answer}, count: {count}")
    else:
        raise ValueError(f"Unknown task name: {task_name}")

    supporting_completions = [completions[index] for index in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == "longest":
        # sorted(...)[-1] keeps the last of several equally long completions.
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == "shortest":
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == "most_common":
        # Counter breaks ties by first occurrence, which is deterministic
        # (iterating a set is not) and avoids the quadratic list.count scan.
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        "completion": preferred_completion,
        "accepted_answer": answer,
        "support": len(supporting_completions),
        "support_indices": support_indices,
        "method": f'Selected the {prefer.replace("_", " ")} completion',
    }
# Ordered (task name, pattern) pairs: the first pattern that matches wins.
# Compiled once at import time instead of on every call.
_TASK_PATTERNS = (
    (
        "summarization",
        re.compile(r"summar|quick rundown|overview", re.IGNORECASE),
    ),
    (
        "date_qa",
        re.compile(
            r"exact date|tell me when|on what date|on what day"
            # Word boundaries keep "die"/"died" from matching inside words
            # such as "soldier", "audience" or "studied".
            r"|\bwas born\b|\bdied?\b",
            re.IGNORECASE,
        ),
    ),
    (
        "math",
        re.compile(
            r"math|solve|solution| sum |problem|geometric|vector|calculate"
            r"|degrees|decimal|factorial",
            re.IGNORECASE,
        ),
    ),
)


def guess_task_name(challenge: str) -> str:
    """
    Guess the task type of a challenge from keyword patterns.

    The challenge is tested against the summarization, date_qa and math
    patterns in that order, case-insensitively; the first match decides.

    Args:
        challenge: The challenge / question text.

    Returns:
        "summarization", "date_qa" or "math" for the first matching pattern,
        otherwise "qa".
    """
    # TODO: use a pre-trained classifier to guess the task name
    for task_name, patt in _TASK_PATTERNS:
        if patt.search(challenge):
            return task_name
    return "qa"
# Simulate the stream synapse for the echo endpoint
class EchoAsyncIterator:
    """Async iterable that replays a message word by word, `k` times over.

    Each whitespace-separated word is yielded as a one-element list, followed
    by an `asyncio.sleep(delay)` pause.
    """

    def __init__(self, message: str, k: int, delay: float):
        self.message = message  # text to echo back
        self.k = k  # number of times the whole message is replayed
        self.delay = delay  # seconds to sleep after each yielded word

    async def __aiter__(self):
        # Written as an async generator, so each `async for` over the
        # instance starts a fresh replay.
        words = self.message.split()  # loop-invariant, so split only once
        for _ in range(self.k):
            for word in words:
                yield [word]
                await asyncio.sleep(self.delay)
async def echo_stream(request: web.Request) -> web.StreamResponse:
    """
    Stream the request's own messages back to the caller, word by word.

    Reads the payload from `request["data"]` (presumably stored there by
    upstream middleware -- TODO confirm), joins its "messages" with blank
    lines, and streams the text through AsyncResponseDataStreamer, repeated
    `k` times (default 1).

    Raises:
        KeyError: If the request has no "data" entry or the payload has no
            "messages" key.
    """
    payload = request["data"]
    repeats = payload.get("k", 1)
    text = "\n\n".join(payload["messages"])

    # The same pause is used by the simulated source and by the streamer.
    pause = 0.3
    source = EchoAsyncIterator(text, repeats, delay=pause)
    streamer = AsyncResponseDataStreamer(source, selected_uid=0, delay=pause)
    return await streamer.stream(request)