import json

import torch
from dataclasses import dataclass
from transformers import HfArgumentParser

####################################
# SCRIPT ARGUMENTS
####################################
@dataclass
class ScriptArguments:
    """
    Arguments for the Bradley-Terry evaluation script.
    """

    old_generations_file: str
    new_generations_file: str
    output_file: str = 'bt_results.json'

####################################
# FUNCTIONS
####################################
def load_rewards(file_path):
    """
    Load the rewards from a JSON file.

    Args:
        file_path (str): Path to the JSON file containing model generations and rewards.

    Returns:
        list: List of dictionaries with prompts, outputs, and rewards.
    """
    with open(file_path, 'r') as f:
        return json.load(f)

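# NOTE: illustrative sketch of the input format this script assumes. Each
# generations file is a JSON list of records with 'prompt', 'output', and
# 'reward' keys (the example values below are made up):
#
# [
#     {
#         "prompt": "Explain the Bradley-Terry model.",
#         "output": "The Bradley-Terry model estimates ...",
#         "reward": 1.37
#     },
#     ...
# ]
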
def bradley_terry_comparison(old_rewards, new_rewards):
    """
    Perform a Bradley-Terry comparison between two sets of model generations.

    Args:
        old_rewards (list): List of dictionaries for the OLD model's generations and rewards.
        new_rewards (list): List of dictionaries for the NEW model's generations and rewards.

    Returns:
        list: Comparison results including preferred outputs and probabilities.
        dict: Metrics summary including percentage preferred and average probabilities.
    """
    results = []
    new_preferred_count = 0
    old_preferred_count = 0
    probabilities = []

    for ix in range(len(old_rewards)):
        old = old_rewards[ix]
        new = new_rewards[ix]

        # Ensure both files score the same prompt at the same index
        assert old['prompt'] == new['prompt'], f"ERROR: Prompts at index {ix} do not match."

        # Bradley-Terry probability that the NEW output is preferred:
        # P(new > old) = sigmoid(reward_new - reward_old)
        new_reward = torch.tensor(new['reward'], dtype=torch.float32)
        old_reward = torch.tensor(old['reward'], dtype=torch.float32)
        prob_new_preferred = torch.sigmoid(new_reward - old_reward).item()

        probabilities.append(prob_new_preferred)
        preferred_model = 'new' if prob_new_preferred > 0.5 else 'old'

        # Count preferences
        if preferred_model == 'new':
            new_preferred_count += 1
        else:
            old_preferred_count += 1

        # Log results
        bt_result = {
            'prompt': old['prompt'],
            'old_output': old['output'],
            'new_output': new['output'],
            'old_reward': old['reward'],
            'new_reward': new['reward'],
            'preferred': preferred_model,
            'prob_new_preferred': prob_new_preferred
        }
        results.append(bt_result)

    # Calculate metrics
    total_examples = len(old_rewards)
    metrics = {
        'total_examples': total_examples,
        'new_preferred_percentage': 100 * new_preferred_count / total_examples,
        'old_preferred_percentage': 100 * old_preferred_count / total_examples,
        'avg_probability_new_preferred': sum(probabilities) / total_examples
    }

    return results, metrics

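# Minimal sketch of the comparison logic on toy, made-up data:
#
#     old = [{'prompt': 'p', 'output': 'a', 'reward': 0.1}]
#     new = [{'prompt': 'p', 'output': 'b', 'reward': 0.9}]
#     _, metrics = bradley_terry_comparison(old, new)
#     # sigmoid(0.9 - 0.1) ~= 0.69 > 0.5, so the new output is preferred and
#     # metrics['new_preferred_percentage'] == 100.0
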
def save_results(results, output_path):
    """
    Save the comparison results to a JSON file.

    Args:
        results (list): List of comparison results.
        output_path (str): Path to the output JSON file.
    """
    with open(output_path, "w") as f:
        json.dump(results, f, indent=4)
    print(f"Results saved to {output_path}")

def print_metrics(metrics):
    """
    Print evaluation metrics.

    Args:
        metrics (dict): Dictionary containing evaluation metrics.
    """
    print("\nEVALUATION METRICS:")
    print(f"Total examples: {metrics['total_examples']}")
    print(f"Percentage preferred - KTO model: {metrics['new_preferred_percentage']:.2f}%")
    print(f"Percentage preferred - SFT model: {metrics['old_preferred_percentage']:.2f}%")
    print(f"Average probability of KTO model being preferred: {metrics['avg_probability_new_preferred']:.4f}")

####################################
# MAIN SCRIPT
####################################
def main():
    # Parse command-line arguments into the ScriptArguments dataclass
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    # Load the scored generations of both models
    print("Loading data...")
    old_rewards = load_rewards(args.old_generations_file)
    new_rewards = load_rewards(args.new_generations_file)

    # Perform Bradley-Terry comparison
    print("Performing Bradley-Terry comparison...")
    results, metrics = bradley_terry_comparison(old_rewards, new_rewards)

    # Save and report
    save_results(results, args.output_file)
    print_metrics(metrics)


if __name__ == "__main__":
    main()
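# Example invocation (the script and input file names are placeholders for
# whatever the generation/scoring step produced):
#
#     python bt_eval.py \
#         --old_generations_file sft_generations.json \
#         --new_generations_file kto_generations.json \
#         --output_file bt_results.json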