Spaces:
Sleeping
Sleeping
import asyncio
import json
import re
import time
from collections import Counter

import bittensor as bt
from aiohttp import web

from prompting.rewards import DateRewardModel, FloatDiffModel
# Phrases that mark a completion as a refusal or non-answer. They are joined
# into a single case-insensitive alternation by completion_is_valid, so the
# entries are plain text fragments rather than full sentences.
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    # apologies / inability
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    # uncertainty / lack of understanding
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]
# Task name -> reward model instance whose parsing helpers are used to pull a
# comparable answer (a date or a number) out of free-text completions.
reward_models = dict(
    date_qa=DateRewardModel(),
    math=FloatDiffModel(),
)
def completion_is_valid(completion: str):
    """
    Decide whether a single completion is a usable answer.

    A completion is rejected when it is empty or whitespace-only, contains no
    word characters at all, or contains any phrase listed in
    UNSUCCESSFUL_RESPONSE_PATTERNS (matched case-insensitively, on word
    boundaries).

    Args:
        completion: Text of one completion. ``None`` is treated as empty.

    Returns:
        True if the completion looks like a real answer, False otherwise.
    """
    if not completion or not completion.strip():
        return False

    # Escape every phrase so it is always matched literally, even if a phrase
    # containing regex metacharacters is added to the list later.
    refusal = re.compile(
        r"\b(?:" + "|".join(map(re.escape, UNSUCCESSFUL_RESPONSE_PATTERNS)) + r")\b",
        re.IGNORECASE,
    )

    # Punctuation-only output (no word characters) is not an answer either.
    has_word = re.search(r"\w", completion) is not None
    return has_word and refusal.search(completion) is None
def _majority_vote(keyed):
    """Return (most common key, its count, supporting indices) for (index, key) pairs."""
    winner, count = Counter(key for _, key in keyed).most_common(1)[0]
    if count == 1:
        # No two completions agree, so no subset can be singled out: every
        # completion that produced a key counts as support.
        return winner, count, [i for i, _ in keyed]
    return winner, count, [i for i, key in keyed if key == winner]


def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
    """
    Ensemble completions from multiple models into a single preferred answer.

    For "date_qa" and "math" the completions are reduced to a comparable answer
    (a date string or a number) and the most common answer wins; only the
    completions that yield that answer are kept as support. When no answer
    occurs more than once, every completion that yielded an answer is kept.
    Any other task name gets no special handling: all completions are support.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline

    Args:
        completions: Completion strings, one per responding model.
        task_name: Task category, e.g. "qa", "summarization", "date_qa", "math".
        prefer: How to pick one completion from the supporting set:
            "longest", "shortest" or "most_common".

    Returns:
        None if ``completions`` is empty, or if the task is "date_qa"/"math"
        and no completion contains a parseable date/number. Otherwise a dict
        with keys "completion", "accepted_answer", "support",
        "support_indices" (positions in ``completions``) and "method".

    Raises:
        ValueError: If ``prefer`` is not one of the supported values.
    """
    if not completions:
        return None

    answer = None
    if task_name == "date_qa":
        parse_dates = reward_models[task_name].parse_dates_from_text
        parsed = [parse_dates(c) for c in completions]
        bt.logging.info(f"Unprocessed dates: {parsed}")
        # Keep the original position next to each normalised date so support
        # indices stay correct. '%d' + lstrip('0') replaces the glibc-only
        # '%-d' directive and yields the same unpadded day.
        keyed = [
            (i, f"{d[0].strftime('%d %B').lstrip('0')} {d[1]}")
            for i, d in enumerate(parsed)
            if d
        ]
        if not keyed:
            return None
        answer, _, support_indices = _majority_vote(keyed)
    elif task_name == "math":
        # TODO: use the median instead of the most common value
        extract_number = reward_models[task_name].extract_number
        numbers = [extract_number(c) for c in completions]
        # Presumably extract_number returns None when no number is found
        # (verify against FloatDiffModel); a parsed 0 is a valid answer and
        # must not be discarded by a truthiness test.
        keyed = [(i, v) for i, v in enumerate(numbers) if v is not None]
        if not keyed:
            return None
        answer, count, support_indices = _majority_vote(keyed)
        bt.logging.info(f"Most common value: {answer}, count: {count}")
    else:
        # "qa", "summarization" and any unrecognised task: no special handling.
        support_indices = list(range(len(completions)))

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == "longest":
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == "shortest":
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == "most_common":
        # Counter breaks ties by first occurrence, so the pick is deterministic.
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        "completion": preferred_completion,
        "accepted_answer": answer,
        "support": len(supporting_completions),
        "support_indices": support_indices,
        "method": f'Selected the {prefer.replace("_", " ")} completion',
    }
# Keyword patterns per task, checked in insertion order: the first category
# whose pattern matches wins. Compiled once at import time. Matching is
# case-insensitive so sentence-initial keywords ("Summarize ...", "Solve ...")
# are recognised as well.
# NOTE(review): in "was born?" and "died?" the '?' is a regex quantifier that
# makes the last letter optional (so "die" matches); kept as written.
_TASK_PATTERNS = {
    "summarization": re.compile(r"summar|quick rundown|overview", re.IGNORECASE),
    "date_qa": re.compile(
        r"exact date|tell me when|on what date|on what day|was born?|died?",
        re.IGNORECASE,
    ),
    "math": re.compile(
        r"math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial",
        re.IGNORECASE,
    ),
}


def guess_task_name(challenge: str):
    """
    Guess the task category of a challenge from keywords it contains.

    Args:
        challenge: The challenge/prompt text.

    Returns:
        "summarization", "date_qa" or "math" for the first category whose
        keyword pattern occurs in the text, otherwise "qa".
    """
    # TODO: use a pre-trained classifier to guess the task name
    for task_name, patt in _TASK_PATTERNS.items():
        if patt.search(challenge):
            return task_name
    return "qa"
async def echo_stream(request_data: dict):
    """
    Stream the request's messages back word by word, followed by a JSON summary.

    Keys read from ``request_data``:
        "messages": list of strings; joined with blank lines to form the text
            to echo (required).
        "k": number of times the whole text is echoed (default 1).
        "timeout": seconds to pause after each word (default 0.2).
    Other keys (e.g. "exclude") are ignored.

    Returns:
        The aiohttp StreamResponse, after every word chunk and the trailing
        ``JSON_RESPONSE_BEGIN`` payload have been written and the stream closed.

    Raises:
        KeyError: If ``request_data`` has no "messages" entry.
    """
    k = request_data.get("k", 1)
    timeout = request_data.get("timeout", 0.2)
    message = "\n\n".join(request_data["messages"])

    # Create a StreamResponse
    response = web.StreamResponse(
        status=200, reason="OK", headers={"Content-Type": "text/plain"}
    )
    # NOTE(review): aiohttp documents this call as prepare(request). No request
    # object is available in this function, so the call is left unchanged —
    # confirm against the aiohttp version in use.
    await response.prepare()

    # Echo the message k times with a pause after each chunk
    words = message.split()
    echoed = []
    for _ in range(k):
        for word in words:
            chunk = f"{word} "
            await response.write(chunk.encode("utf-8"))
            echoed.append(chunk)
            # asyncio.sleep yields to the event loop; time.sleep would stall
            # every other request handled by this process while waiting.
            await asyncio.sleep(timeout)
            bt.logging.info(f"Echoed: {chunk}")
    completion = "".join(echoed).strip()

    # Prepare final JSON chunk
    json_chunk = json.dumps(
        {
            "uids": [0],
            "completion": completion,
            "completions": [completion],
            "timings": [0],
            "status_messages": ["Went well!"],
            "status_codes": [200],
            "completion_is_valid": [True],
            "task_name": "echo",
            "ensemble_result": {},
        }
    )

    # Send the final JSON as part of the stream
    await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode("utf-8"))

    # Finalize the response
    await response.write_eof()
    return response