Spaces:
Sleeping
Sleeping
File size: 4,365 Bytes
48f02b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import re
import bittensor as bt
from collections import Counter
from prompting.rewards import DateRewardModel, FloatDiffModel
# Phrases that mark a completion as a refusal or non-answer; completion_is_valid
# rejects any completion containing one of them (case-insensitive, whole words).
UNSUCCESSFUL_RESPONSE_PATTERNS = [
    "I'm sorry",
    "unable to",
    "I cannot",
    "I can't",
    "I am unable",
    "I am sorry",
    "I can not",
    "don't know",
    "not sure",
    "don't understand",
    "not capable",
]
# Task-specific reward models, keyed by task name. ensemble_result uses them to
# pull comparable answers (dates / numbers) out of free-text completions.
reward_models = dict(
    date_qa=DateRewardModel(),
    math=FloatDiffModel(),
)
def completion_is_valid(completion: str, *, patterns=None) -> bool:
    """Return True if *completion* looks like a genuine (non-refusal) answer.

    A completion is rejected when it is empty/whitespace-only, contains no word
    characters at all, or contains any refusal phrase as a whole word
    (case-insensitive).

    Args:
        completion: The model output to check. ``None`` is treated as empty.
        patterns: Optional iterable of refusal phrases. Defaults to
            ``UNSUCCESSFUL_RESPONSE_PATTERNS``. Phrases are matched literally.

    Returns:
        ``False`` for empty, wordless or refusal-like completions, else ``True``.
    """
    if not completion or not completion.strip():
        return False
    # Pure punctuation/symbols carries no answer.
    if re.search(r'\w', completion) is None:
        return False

    phrases = list(UNSUCCESSFUL_RESPONSE_PATTERNS if patterns is None else patterns)
    if not phrases:
        # An empty alternation would match at every word boundary; nothing to reject.
        return True
    # Escape each phrase so it is matched literally rather than as regex syntax.
    refusal = re.compile(
        r'\b(?:' + '|'.join(map(re.escape, phrases)) + r')\b',
        re.IGNORECASE,
    )
    return refusal.search(completion) is None
def _majority_vote(keys):
    """Return the most common non-``None`` key and the positions that support it.

    Args:
        keys: One extracted (hashable) answer per completion, or ``None`` where
            no answer could be extracted.

    Returns:
        A tuple ``(answer, count, indices)``. ``answer`` is the most common key
        (ties go to the key seen first) and ``count`` its number of occurrences.
        If some key repeats, ``indices`` are the positions holding ``answer``;
        if no key repeats there is no agreement, so ``indices`` are all
        positions with a non-``None`` key. ``(None, 0, [])`` when every key is
        ``None``.
    """
    valid = [i for i, key in enumerate(keys) if key is not None]
    if not valid:
        return None, 0, []
    answer, count = Counter(keys[i] for i in valid).most_common(1)[0]
    if count == 1:
        return answer, count, valid
    return answer, count, [i for i in valid if keys[i] == answer]


def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
    """Ensemble completions from multiple models into a single preferred one.

    For ``'qa'`` and ``'summarization'`` every completion supports the result.
    For ``'date_qa'`` and ``'math'`` an answer (date or number) is extracted
    from each completion with the task's reward model; the most common answer
    is accepted and the completions agreeing with it form the support. If no
    two completions agree, every completion with an extractable answer is kept
    as support and the first answer seen is reported.

    Args:
        completions: Completion strings, one per model.
        task_name: One of ``'qa'``, ``'summarization'``, ``'date_qa'``, ``'math'``.
        prefer: How to pick among the supporting completions: ``'longest'``,
            ``'shortest'`` or ``'most_common'``.

    Returns:
        ``None`` if ``completions`` is empty, or (for ``'date_qa'``/``'math'``)
        if no completion contains an extractable answer. Otherwise a dict with
        keys ``completion``, ``accepted_answer``, ``support``,
        ``support_indices`` (positions in ``completions``) and ``method``.

    Raises:
        ValueError: If ``task_name`` or ``prefer`` is not recognised.
    """
    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    if not completions:
        return None

    answer = None
    if task_name in ('qa', 'summarization'):
        # No special handling for QA or summarization: every completion supports the result.
        support = list(range(len(completions)))
    elif task_name == 'date_qa':
        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {dates}")
        # Render each parsed date as "<day> <month name> <d[1]>" so equal dates
        # compare equal. '%d' + lstrip('0') yields an unpadded day portably;
        # the '%-d' directive is platform-specific (unsupported on Windows).
        keys = [f"{d[0].strftime('%d %B').lstrip('0')} {d[1]}" if d else None for d in dates]
        answer, _, support = _majority_vote(keys)
        if not support:
            return None
    elif task_name == 'math':
        # TODO: use the median instead of the most common value
        # NOTE(review): assumes extract_number returns None when no number is
        # found; a truthiness filter would wrongly discard a genuine answer of 0.
        vals = list(map(reward_models[task_name].extract_number, completions))
        answer, count, support = _majority_vote(vals)
        if not support:
            return None
        bt.logging.info(f"Most common value: {answer}, count: {count}")
    else:
        raise ValueError(f"Unknown task name: {task_name}")

    supporting_completions = [completions[i] for i in support]
    bt.logging.info(f"Supporting completions: {supporting_completions}")

    if prefer == 'longest':
        # sorted() is stable, so among equally long completions the last one wins.
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == 'shortest':
        preferred_completion = min(supporting_completions, key=len)
    elif prefer == 'most_common':
        # Counter keeps first-seen order, so ties break deterministically
        # (set() iteration order depends on per-process string hashing).
        preferred_completion = Counter(supporting_completions).most_common(1)[0][0]
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        'completion': preferred_completion,
        'accepted_answer': answer,
        'support': len(support),
        # Real positions in `completions`; list.index() would map duplicate
        # completions to their first occurrence.
        'support_indices': support,
        'method': f'Selected the {prefer.replace("_", " ")} completion',
    }
# Ordered keyword heuristics: the first category whose pattern matches wins.
# Compiled once at import and matched case-insensitively so sentence-initial
# phrasing such as "Summarize ..." or "On what date ..." is recognised.
# NOTE(review): in 'was born?|died?' the '?' makes the preceding letter
# optional, so 'died?' matches any "die" substring (e.g. "diet") — confirm intent.
_TASK_NAME_PATTERNS = {
    'summarization': re.compile('summar|quick rundown|overview', re.IGNORECASE),
    'date_qa': re.compile(
        'exact date|tell me when|on what date|on what day|was born?|died?',
        re.IGNORECASE,
    ),
    'math': re.compile(
        'math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial',
        re.IGNORECASE,
    ),
}


def guess_task_name(challenge: str):
    """Guess the task category of *challenge* from keyword heuristics.

    Categories are tried in order (summarization, date_qa, math) and the first
    whose pattern occurs anywhere in the text, ignoring case, is returned.

    Returns:
        ``'summarization'``, ``'date_qa'``, ``'math'``, or ``'qa'`` when no
        pattern matches.
    """
    # TODO: use a pre-trained classifier to guess the task name
    for task_name, pattern in _TASK_NAME_PATTERNS.items():
        if pattern.search(challenge):
            return task_name
    return 'qa'