import re
import asyncio
import bittensor as bt
from aiohttp import web
from collections import Counter
from prompting.rewards import DateRewardModel, FloatDiffModel
from validators.streamer import AsyncResponseDataStreamer
UNSUCCESSFUL_RESPONSE_PATTERNS = [
"I'm sorry",
"unable to",
"I cannot",
"I can't",
"I am unable",
"I am sorry",
"I can not",
"don't know",
"not sure",
"don't understand",
"not capable",
]
reward_models = {
"date_qa": DateRewardModel(),
"math": FloatDiffModel(),
}
def completion_is_valid(completion: str) -> bool:
    """
    Check whether a completion is valid: it must be non-empty, contain at least
    one word, and not match any of the unsuccessful response patterns.
    """
if not completion.strip():
return False
patt = re.compile(
r"\b(?:" + "|".join(UNSUCCESSFUL_RESPONSE_PATTERNS) + r")\b", re.IGNORECASE
)
if not len(re.findall(r"\w+", completion)) or patt.search(completion):
return False
return True
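

# Illustrative usage of completion_is_valid (hypothetical inputs, not exercised by this module):
#   completion_is_valid("The capital of France is Paris.")     -> True
#   completion_is_valid("I'm sorry, I can't help with that.")  -> False (refusal pattern)
#   completion_is_valid("   ")                                  -> False (empty)
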
def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
"""
Ensemble completions from multiple models.
# TODO: Measure agreement
# TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
# TODO: Reward pipeline
"""
if not completions:
return None
answer = None
if task_name in ("qa", "summarization"):
# No special handling for QA or summarization
supporting_completions = completions
elif task_name == "date_qa":
# filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
bt.logging.info(f"Unprocessed dates: {dates}")
valid_date_indices = [i for i, d in enumerate(dates) if d]
valid_completions = [completions[i] for i in valid_date_indices]
valid_dates = [dates[i] for i in valid_date_indices]
dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
if not dates:
return None
counter = Counter(dates)
most_common, count = counter.most_common()[0]
answer = most_common
if count == 1:
supporting_completions = valid_completions
else:
supporting_completions = [
c for i, c in enumerate(valid_completions) if dates[i] == most_common
]
elif task_name == "math":
# filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
# TODO: use the median instead of the most common value
vals = list(map(reward_models[task_name].extract_number, completions))
vals = [val for val in vals if val]
if not vals:
return None
most_common, count = Counter(dates).most_common()[0]
bt.logging.info(f"Most common value: {most_common}, count: {count}")
answer = most_common
if count == 1:
supporting_completions = completions
else:
supporting_completions = [
c for i, c in enumerate(completions) if vals[i] == most_common
]
bt.logging.info(f"Supporting completions: {supporting_completions}")
if prefer == "longest":
preferred_completion = sorted(supporting_completions, key=len)[-1]
elif prefer == "shortest":
preferred_completion = sorted(supporting_completions, key=len)[0]
elif prefer == "most_common":
preferred_completion = max(
set(supporting_completions), key=supporting_completions.count
)
else:
raise ValueError(f"Unknown ensemble preference: {prefer}")
return {
"completion": preferred_completion,
"accepted_answer": answer,
"support": len(supporting_completions),
"support_indices": [completions.index(c) for c in supporting_completions],
"method": f'Selected the {prefer.replace("_", " ")} completion',
}
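

# Illustrative call to ensemble_result (hypothetical completions; for "qa" no
# filtering is applied, so the default preference simply picks the longest one):
#   ensemble_result(["Paris.", "The capital of France is Paris."], task_name="qa")
#   -> {
#          "completion": "The capital of France is Paris.",
#          "accepted_answer": None,
#          "support": 2,
#          "support_indices": [0, 1],
#          "method": "Selected the longest completion",
#      }
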
def guess_task_name(challenge: str):
# TODO: use a pre-trained classifier to guess the task name
    categories = {
        "summarization": re.compile("summar|quick rundown|overview", re.IGNORECASE),
        "date_qa": re.compile(
            "exact date|tell me when|on what date|on what day|was born?|died?",
            re.IGNORECASE,
        ),
        "math": re.compile(
            "math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial",
            re.IGNORECASE,
        ),
    }
for task_name, patt in categories.items():
if patt.search(challenge):
return task_name
return "qa"
# Simulate the stream synapse for the echo endpoint
class EchoAsyncIterator:
    """Async iterator that yields the words of `message` one at a time, repeated `k` times."""

def __init__(self, message: str, k: int, delay: float):
self.message = message
self.k = k
self.delay = delay
async def __aiter__(self):
for _ in range(self.k):
for word in self.message.split():
yield [word]
await asyncio.sleep(self.delay)
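

# Minimal consumption sketch for EchoAsyncIterator (hypothetical; for local
# experimentation only, not used by the endpoint below):
#   async def _demo():
#       async for chunk in EchoAsyncIterator("hello world", k=1, delay=0.0):
#           print(chunk)  # -> ["hello"], then ["world"]
#   asyncio.run(_demo())
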
async def echo_stream(request: web.Request) -> web.StreamResponse:
request_data = request["data"]
k = request_data.get("k", 1)
message = "\n\n".join(request_data["messages"])
echo_iterator = EchoAsyncIterator(message, k, delay=0.3)
streamer = AsyncResponseDataStreamer(echo_iterator, selected_uid=0, delay=0.3)
return await streamer.stream(request)
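

# Hedged wiring sketch (assumption: upstream middleware has already parsed the JSON
# request body into request["data"] as {"messages": [...], "k": ...}; the route and
# port below are hypothetical):
#   app = web.Application()
#   app.add_routes([web.post("/echo/", echo_stream)])
#   web.run_app(app, port=10000)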