import re
import asyncio

import bittensor as bt
from aiohttp import web
from collections import Counter

from prompting.rewards import DateRewardModel, FloatDiffModel
from validators.streamer import AsyncResponseDataStreamer

UNSUCCESSFUL_RESPONSE_PATTERNS = [
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]

reward_models = {
    "date_qa": DateRewardModel(),
    "math": FloatDiffModel(),
}


def completion_is_valid(completion: str):
    """
    Check whether a completion is non-empty and does not match any of the
    known unsuccessful-response (refusal) patterns.
    """
    if not completion.strip():
        return False

    patt = re.compile(
        r"\b(?:" + "|".join(UNSUCCESSFUL_RESPONSE_PATTERNS) + r")\b", re.IGNORECASE
    )
    if not len(re.findall(r"\w+", completion)) or patt.search(completion):
        return False
    return True


def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
    """
    Ensemble completions from multiple models.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    if task_name in ("qa", "summarization"):
        # No special handling for QA or summarization
        supporting_completions = completions

    elif task_name == "date_qa":
        # Filter the completions to those that contain valid dates and, if there are
        # multiple dates, select the most common one (with support > 1).
        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {dates}")
        valid_date_indices = [i for i, d in enumerate(dates) if d]
        valid_completions = [completions[i] for i in valid_date_indices]
        valid_dates = [dates[i] for i in valid_date_indices]
        dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
        if not dates:
            return None

        counter = Counter(dates)
        most_common, count = counter.most_common()[0]
        answer = most_common
        if count == 1:
            supporting_completions = valid_completions
        else:
            supporting_completions = [
                c for i, c in enumerate(valid_completions) if dates[i] == most_common
            ]

    elif task_name == "math":
        # Filter the completions to those that contain valid numbers and, if there are
        # multiple values, select the most common one (with support > 1).
        # TODO: use the median instead of the most common value
        vals = list(map(reward_models[task_name].extract_number, completions))
        valid_val_indices = [i for i, v in enumerate(vals) if v]
        valid_completions = [completions[i] for i in valid_val_indices]
        valid_vals = [vals[i] for i in valid_val_indices]
        if not valid_vals:
            return None

        most_common, count = Counter(valid_vals).most_common()[0]
        bt.logging.info(f"Most common value: {most_common}, count: {count}")
        answer = most_common
        if count == 1:
            supporting_completions = valid_completions
        else:
            supporting_completions = [
                c for i, c in enumerate(valid_completions) if valid_vals[i] == most_common
            ]

    bt.logging.info(f"Supporting completions: {supporting_completions}")
    if prefer == "longest":
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == "shortest":
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == "most_common":
        preferred_completion = max(
            set(supporting_completions), key=supporting_completions.count
        )
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        "completion": preferred_completion,
        "accepted_answer": answer,
        "support": len(supporting_completions),
        "support_indices": [completions.index(c) for c in supporting_completions],
        "method": f'Selected the {prefer.replace("_", " ")} completion',
    }


def guess_task_name(challenge: str):
    # TODO: use a pre-trained classifier to guess the task name
    categories = {
        "summarization": re.compile("summar|quick rundown|overview"),
        "date_qa": re.compile(
            "exact date|tell me when|on what date|on what day|was born?|died?"
        ),
        "math": re.compile(
            "math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial"
        ),
    }
    for task_name, patt in categories.items():
        if patt.search(challenge):
            return task_name
    return "qa"


# Simulate the stream synapse for the echo endpoint
class EchoAsyncIterator:
    def __init__(self, message: str, k: int, delay: float):
        self.message = message
        self.k = k
        self.delay = delay

    async def __aiter__(self):
        for _ in range(self.k):
            for word in self.message.split():
                yield [word]
                await asyncio.sleep(self.delay)


async def echo_stream(request: web.Request) -> web.StreamResponse:
    request_data = request["data"]
    k = request_data.get("k", 1)
    message = "\n\n".join(request_data["messages"])

    echo_iterator = EchoAsyncIterator(message, k, delay=0.3)
    streamer = AsyncResponseDataStreamer(echo_iterator, selected_uid=0, delay=0.3)
    return await streamer.stream(request)