Spaces:
Sleeping
Sleeping
import re | |
import bittensor as bt | |
import time | |
import json | |
from aiohttp import web | |
from collections import Counter | |
from prompting.rewards import DateRewardModel, FloatDiffModel | |
# Phrases that mark a completion as a refusal / non-answer. They are joined
# into a single case-insensitive alternation by completion_is_valid.
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]
# Task name -> reward model instance. ensemble_result looks a model up by task
# name and uses its parsing helper (parse_dates_from_text / extract_number).
reward_models = dict(
    date_qa=DateRewardModel(),
    math=FloatDiffModel(),
)
def completion_is_valid(completion: str):
    """
    Decide whether a single completion looks like a usable answer.

    A completion is rejected when it is empty or whitespace-only, when it
    contains no word characters, or when it contains any phrase listed in
    UNSUCCESSFUL_RESPONSE_PATTERNS (matched case-insensitively, between word
    boundaries). Returns True otherwise.
    """
    if completion.strip():
        alternatives = '|'.join(UNSUCCESSFUL_RESPONSE_PATTERNS)
        refusal_re = re.compile(r'\b(?:' + alternatives + r')\b', flags=re.IGNORECASE)
        contains_word = re.search(r'\w+', completion) is not None
        # Valid only if there is at least one word and no refusal phrase.
        return contains_word and refusal_re.search(completion) is None
    return False
def _vote_on_answers(indexed_answers):
    """
    Pick the most frequent answer among (completion_index, answer) pairs.

    Returns a tuple (answer, count, support_indices). When every answer is
    unique (count == 1) there is no consensus, so all given indices are kept
    as support; otherwise only the indices whose answer equals the winner.
    """
    answer, count = Counter(a for _, a in indexed_answers).most_common(1)[0]
    if count == 1:
        support_indices = [i for i, _ in indexed_answers]
    else:
        support_indices = [i for i, a in indexed_answers if a == answer]
    return answer, count, support_indices


def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
    """
    Ensemble completions from multiple models into one preferred completion.

    Args:
        completions: Completion strings to choose from.
        task_name: 'date_qa' and 'math' vote on the date / number parsed from
            each completion (completions without a parsable answer are
            dropped). Any other task name, including 'qa' and
            'summarization', gets no special handling and every completion
            counts as support.
        prefer: How to pick among the supporting completions: 'longest',
            'shortest' or 'most_common'.

    Returns:
        None when `completions` is empty, or when a 'date_qa' / 'math' task
        yields no parsable answer. Otherwise a dict with the keys
        'completion', 'accepted_answer' (None unless a vote took place),
        'support', 'support_indices' (positions in `completions`) and
        'method'.

    Raises:
        ValueError: If `prefer` is not one of the supported values.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    if task_name == 'date_qa':
        # Keep the completions that contain a valid date and, if several dates
        # appear, select the most common one (with support > 1).
        parsed = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {parsed}")
        # NOTE(review): '%-d' (day without zero padding) is a platform-specific
        # strftime extension - confirm the deployment platform supports it.
        indexed = [
            (i, f"{d[0].strftime('%-d %B')} {d[1]}")
            for i, d in enumerate(parsed)
            if d
        ]
        if not indexed:
            return None
        answer, _, support_indices = _vote_on_answers(indexed)
    elif task_name == 'math':
        # Keep the completions that contain a valid number and, if several
        # values appear, select the most common one (with support > 1).
        # TODO: use the median instead of the most common value
        numbers = map(reward_models[task_name].extract_number, completions)
        # Pair each value with its original position so the later lookups stay
        # aligned with `completions`, and keep 0 as a legitimate answer.
        # Assumes extract_number returns None when no number is found - TODO confirm.
        indexed = [(i, v) for i, v in enumerate(numbers) if v is not None]
        if not indexed:
            return None
        answer, count, support_indices = _vote_on_answers(indexed)
        bt.logging.info(f"Most common value: {answer}, count: {count}")
    else:
        # No special handling for QA, summarization or unrecognised tasks.
        support_indices = list(range(len(completions)))

    supporting_completions = [completions[i] for i in support_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == 'longest':
        # Stable sort: among equally long completions the last one wins.
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == 'shortest':
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == 'most_common':
        # Counter breaks ties by first appearance, so the choice is deterministic.
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        'completion': preferred_completion,
        'accepted_answer': answer,
        'support': len(supporting_completions),
        'support_indices': support_indices,
        'method': f'Selected the {prefer.replace("_", " ")} completion',
    }
def guess_task_name(challenge: str):
    """
    Guess which task a challenge belongs to from keywords in its text.

    The categories are tried in order (summarization, date_qa, math) and the
    first one whose pattern is found anywhere in the challenge wins. Matching
    is case-sensitive. Falls back to 'qa' when no pattern matches.
    """
    # TODO: use a pre-trained classifier to guess the task name
    keyword_patterns = (
        ('summarization', 'summar|quick rundown|overview'),
        ('date_qa', 'exact date|tell me when|on what date|on what day|was born?|died?'),
        ('math', 'math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
    )
    hits = (
        name
        for name, pattern in keyword_patterns
        if re.search(pattern, challenge) is not None
    )
    return next(hits, 'qa')
async def echo_stream(request_data: dict):
    """
    Stream the request's messages back word by word, then a final JSON summary.

    Args:
        request_data: Dict with a required 'messages' list of strings, plus
            optional 'k' (number of times to echo the joined message,
            default 1) and 'timeout' (seconds to pause after each word,
            default 0.2).

    Returns:
        The aiohttp StreamResponse after the echoed words and the trailing
        "JSON_RESPONSE_BEGIN" payload have been written and the stream closed.

    Raises:
        KeyError: If 'messages' is missing from `request_data`.
    """
    # Imported here so the block does not depend on the file's import header.
    import asyncio

    k = request_data.get('k', 1)
    timeout = request_data.get('timeout', 0.2)
    message = '\n\n'.join(request_data['messages'])

    # Create a StreamResponse
    response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
    # NOTE(review): aiohttp documents StreamResponse.prepare(request) as taking
    # the incoming request; none is available in this signature - confirm how
    # this handler is wired up.
    await response.prepare()

    words = message.split()
    echoed = []
    # Echo the message k times with a timeout between each chunk
    for _ in range(k):
        for word in words:
            chunk = f'{word} '
            await response.write(chunk.encode('utf-8'))
            echoed.append(chunk)
            # Yield to the event loop while pausing; a blocking time.sleep here
            # would stall every other request served by the same loop.
            await asyncio.sleep(timeout)
            bt.logging.info(f"Echoed: {chunk}")

    completion = ''.join(echoed).strip()

    # Prepare final JSON chunk
    json_chunk = json.dumps({
        "uids": [0],
        "completion": completion,
        "completions": [completion.strip()],
        "timings": [0],
        "status_messages": ['Went well!'],
        "status_codes": [200],
        "completion_is_valid": [True],
        "task_name": 'echo',
        "ensemble_result": {}
    })

    # Send the final JSON as part of the stream
    await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))

    # Finalize the response
    await response.write_eof()
    return response