import json
import torch
from dataclasses import dataclass
from transformers import HfArgumentParser


####################################
# SCRIPT ARGUMENTS
####################################
@dataclass
class ScriptArguments:
    """
    Arguments for the Bradley-Terry evaluation script.
    """
    old_generations_file: str
    new_generations_file: str
    output_file: str = 'bt_results.json'


####################################
# FUNCTIONS
####################################
def load_rewards(file_path):
    """
    Load the rewards from a JSON file.

    Args:
        file_path (str): Path to the JSON file containing model generations and rewards.

    Returns:
        list: List of dictionaries with prompts, outputs, and rewards.
    """
    with open(file_path, 'r') as f:
        return json.load(f)
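

# Each entry in the generations files is expected to look like
# (schema inferred from the fields accessed in bradley_terry_comparison below):
#   {"prompt": "...", "output": "...", "reward": 1.23}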


def bradley_terry_comparison(old_rewards, new_rewards):
    """
    Perform Bradley-Terry comparison between two sets of model generations.

    Args:
        old_rewards (list): List of dictionaries for the OLD model's generations and rewards.
        new_rewards (list): List of dictionaries for the NEW model's generations and rewards.

    Returns:
        list: Comparison results including preferred outputs and probabilities.
        dict: Metrics summary including percentage preferred and average probabilities.
    """
    assert len(old_rewards) == len(new_rewards), "ERROR: The two generations files must contain the same number of examples."

    results = []
    new_preferred_count = 0
    old_preferred_count = 0
    probabilities = []

    for ix in range(len(old_rewards)):
        old = old_rewards[ix]
        new = new_rewards[ix]

        # Ensure prompts match
        assert old['prompt'] == new['prompt'], f"ERROR: Prompts at index {ix} do not match."

        # Compute Bradley-Terry probability
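        # Under the Bradley-Terry model, the probability that the NEW output is
        # preferred over the OLD one given scalar rewards r_new and r_old is
        #   P(new > old) = exp(r_new) / (exp(r_new) + exp(r_old)) = sigmoid(r_new - r_old)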
        new_reward = torch.tensor(new['reward'], dtype=torch.float32)
        old_reward = torch.tensor(old['reward'], dtype=torch.float32)
        prob_new_preferred = torch.sigmoid(new_reward - old_reward).item()
        probabilities.append(prob_new_preferred)
        preferred_model = 'new' if prob_new_preferred > 0.5 else 'old'

        # Count preferences
        if preferred_model == 'new':
            new_preferred_count += 1
        else:
            old_preferred_count += 1

        # Log results
        bt_result = {
            'prompt': old['prompt'],
            'old_output': old['output'],
            'new_output': new['output'],
            'old_reward': old['reward'],
            'new_reward': new['reward'],
            'preferred': preferred_model,
            'prob_new_preferred': prob_new_preferred
        }
        results.append(bt_result)

    # Calculate metrics
    total_examples = len(old_rewards)
    metrics = {
        'total_examples': total_examples,
        'new_preferred_percentage': 100 * new_preferred_count / total_examples,
        'old_preferred_percentage': 100 * old_preferred_count / total_examples,
        'avg_probability_new_preferred': sum(probabilities) / total_examples
    }
    return results, metrics
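

# For example, with a single pair where the old reward is 0.0 and the new reward
# is 1.0, prob_new_preferred = sigmoid(1.0) ≈ 0.731, 'new' is preferred, and
# new_preferred_percentage comes out to 100.0 (illustrative numbers only).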


def save_results(results, output_path):
    """
    Save the comparison results to a JSON file.

    Args:
        results (list): List of comparison results.
        output_path (str): Path to the output JSON file.
    """
    with open(output_path, "w") as f:
        json.dump(results, f, indent=4)
    print(f"Results saved to {output_path}")


def print_metrics(metrics):
    """
    Print evaluation metrics.

    Args:
        metrics (dict): Dictionary containing evaluation metrics.
    """
    print("\nEVALUATION METRICS:")
    print(f"Total examples: {metrics['total_examples']}")
    print(f"Percentage preferred - new (KTO) model: {metrics['new_preferred_percentage']:.2f}%")
    print(f"Percentage preferred - old (SFT) model: {metrics['old_preferred_percentage']:.2f}%")
    print(f"Average probability of the new (KTO) model being preferred: {metrics['avg_probability_new_preferred']:.4f}")


####################################
# MAIN SCRIPT
####################################
def main():
    # Parse command-line arguments into the ScriptArguments dataclass
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    print("Loading data...")
    old_rewards = load_rewards(args.old_generations_file)
    new_rewards = load_rewards(args.new_generations_file)

    # Perform Bradley-Terry comparison
    print("Performing Bradley-Terry comparison...")
    results, metrics = bradley_terry_comparison(old_rewards, new_rewards)

    save_results(results, args.output_file)
    print_metrics(metrics)


if __name__ == "__main__":
    main()
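
# Example invocation (script and file names are illustrative):
#   python bt_eval.py \
#       --old_generations_file sft_generations.json \
#       --new_generations_file kto_generations.json \
#       --output_file bt_results.json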