import asyncio
import json
import re
import time
from collections import Counter

import bittensor as bt
from aiohttp import web

from prompting.rewards import DateRewardModel, FloatDiffModel

# Phrases whose presence marks a completion as a refusal / non-answer.
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]

# Compiled once at import time instead of on every call. The phrases are
# escaped so that a future entry containing regex metacharacters is still
# matched literally.
_UNSUCCESSFUL_RESPONSE_RE = re.compile(
    r'\b(?:' + '|'.join(map(re.escape, UNSUCCESSFUL_RESPONSE_PATTERNS)) + r')\b',
    re.IGNORECASE,
)

# Task-specific reward models; used here only for their answer parsers.
reward_models = {
    'date_qa': DateRewardModel(),
    'math': FloatDiffModel(),
}

# Keyword heuristics used by guess_task_name. Order matters: the first
# category whose pattern matches wins. Matching is case-sensitive.
_TASK_PATTERNS = {
    'summarization': re.compile('summar|quick rundown|overview'),
    'date_qa': re.compile('exact date|tell me when|on what date|on what day|was born?|died?'),
    'math': re.compile('math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
}


def completion_is_valid(completion: str) -> bool:
    """Return True if the completion looks like a real answer.

    A completion is rejected when it is empty or whitespace-only, contains no
    word characters, or contains any of UNSUCCESSFUL_RESPONSE_PATTERNS
    (case-insensitive, on word boundaries).
    """
    if not completion.strip():
        return False

    if re.search(r'\w', completion) is None:
        return False

    return _UNSUCCESSFUL_RESPONSE_RE.search(completion) is None


def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
    """Ensemble completions from multiple models into a single result.

    For 'date_qa' and 'math' the answer is parsed out of every completion and
    the most common parsed answer is accepted. The supporting completions are
    those that produced that answer, or every parseable completion when no
    answer occurs more than once. Any other task name gets no special
    handling: every completion counts as supporting.

    Args:
        completions: Completion strings to ensemble.
        task_name: Task the completions answer ('qa', 'summarization',
            'date_qa', 'math', ...).
        prefer: Which supporting completion to return: 'longest', 'shortest'
            or 'most_common'.

    Returns:
        None when there are no completions, or when no completion yields a
        parseable answer for 'date_qa' / 'math'. Otherwise a dict with keys
        'completion', 'accepted_answer', 'support', 'support_indices'
        (positions in `completions`) and 'method'.

    Raises:
        ValueError: If `prefer` is not one of the supported values.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    if task_name == 'date_qa':
        # Keep the completions that contain a valid date and, if there are
        # multiple dates, select the most common one (with support > 1).
        parsed = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {parsed}")

        valid_indices = [i for i, d in enumerate(parsed) if d]
        # Presumably each parsed entry is (datetime-like, year) - confirm
        # against DateRewardModel.parse_dates_from_text.
        # NOTE(review): '%-d' is a platform-specific strftime directive and is
        # not available on Windows.
        labels = [f"{parsed[i][0].strftime('%-d %B')} {parsed[i][1]}" for i in valid_indices]
        if not labels:
            return None

        most_common, count = Counter(labels).most_common(1)[0]
        answer = most_common
        if count == 1:
            supporting_indices = valid_indices
        else:
            supporting_indices = [i for i, label in zip(valid_indices, labels) if label == most_common]

    elif task_name == 'math':
        # Keep the completions that contain a valid number and, if there are
        # multiple values, select the most common one (with support > 1).
        # TODO: use the median instead of the most common value
        parsed = list(map(reward_models[task_name].extract_number, completions))

        # Remember which completion each value came from so that values and
        # completions stay aligned after filtering.
        # NOTE(review): the truthiness filter also drops a parsed value of 0;
        # confirm what extract_number returns on failure before relaxing it.
        valid_indices = [i for i, val in enumerate(parsed) if val]
        vals = [parsed[i] for i in valid_indices]
        if not vals:
            return None

        most_common, count = Counter(vals).most_common(1)[0]
        bt.logging.info(f"Most common value: {most_common}, count: {count}")
        answer = most_common
        if count == 1:
            supporting_indices = valid_indices
        else:
            supporting_indices = [i for i, val in zip(valid_indices, vals) if val == most_common]

    else:
        # No special handling for QA, summarization or any unrecognised task.
        supporting_indices = list(range(len(completions)))

    supporting_completions = [completions[i] for i in supporting_indices]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == 'longest':
        # A stable sort keeps the last of several equally long completions.
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == 'shortest':
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == 'most_common':
        # Counter breaks ties by first occurrence, so the result is deterministic.
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        'completion': preferred_completion,
        'accepted_answer': answer,
        'support': len(supporting_completions),
        'support_indices': supporting_indices,
        'method': f'Selected the {prefer.replace("_", " ")} completion',
    }


def guess_task_name(challenge: str) -> str:
    """Guess the task name of a challenge from keyword patterns.

    The categories in _TASK_PATTERNS are tried in order and the first one
    whose pattern is found anywhere in `challenge` is returned. When nothing
    matches, 'qa' is returned.
    """
    # TODO: use a pre-trained classifier to guess the task name
    for task_name, patt in _TASK_PATTERNS.items():
        if patt.search(challenge):
            return task_name

    return 'qa'


async def echo_stream(request_data: dict, request=None):
    """Stream the request's messages back word by word, then a JSON summary.

    Args:
        request_data: Must contain 'messages' (an iterable of strings). May
            contain 'k' (number of times to echo, default 1) and 'timeout'
            (seconds to wait after each word, default 0.2).
        request: The aiohttp request being answered. It is needed to prepare
            the streamed response.

    Returns:
        The finished web.StreamResponse.

    Raises:
        TypeError: If `request` is not provided.
    """
    if request is None:
        # StreamResponse.prepare() cannot send headers without the request.
        raise TypeError("echo_stream() requires the aiohttp request to prepare the response")

    k = request_data.get('k', 1)
    timeout = request_data.get('timeout', 0.2)
    message = '\n\n'.join(request_data['messages'])
    words = message.split()

    # Create a StreamResponse
    response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
    await response.prepare(request)

    # Echo the message k times with a timeout between each chunk
    chunks = []
    for _ in range(k):
        for word in words:
            chunk = f'{word} '
            await response.write(chunk.encode('utf-8'))
            chunks.append(chunk)
            # Yield to the event loop while waiting instead of blocking it.
            await asyncio.sleep(timeout)
            bt.logging.info(f"Echoed: {chunk}")

    completion = ''.join(chunks).strip()

    # Prepare final JSON chunk
    json_chunk = json.dumps({
        "uids": [0],
        "completion": completion,
        "completions": [completion],
        "timings": [0],
        "status_messages": ['Went well!'],
        "status_codes": [200],
        "completion_is_valid": [True],
        "task_name": 'echo',
        "ensemble_result": {},
    })

    # Send the final JSON as part of the stream
    await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))

    # Finalize the response
    await response.write_eof()
    return response