|
import torch |
|
import random |
|
import numpy as np |
|
from collections import defaultdict |
|
from torch.distributions import Categorical |
|
from torch.nn.utils.rnn import pad_sequence |
|
from scrl.model import labels_to_summary |
|
from nltk import word_tokenize |
|
from pprint import pprint |
|
|
|
|
|
def sample_from_policy(
    input_ids,
    probs,
    device="cuda",
    force_diff=True,
    diff_trials=1000,
):
    """Draw one label sequence per item from the policy's distribution.

    Args:
        input_ids: token ids of the batch (currently not read here; kept for
            interface compatibility with callers).
        probs: per-position label probabilities, shape (batch, seq, n_labels).
        device: unused here — kept for interface compatibility.
        force_diff: when True, resample while the drawn labels are identical
            to the greedy (argmax) labels, so the sample adds exploration.
        diff_trials: upper bound on the number of resampling attempts; after
            that many draws the (possibly still greedy) sample is kept.

    Returns:
        A tuple ``(log_probs, labels)`` where ``log_probs`` has shape
        (batch, seq) with the log-probability of each sampled label, and
        ``labels`` holds the sampled label indices, shape (batch, seq).
    """
    dist = Categorical(probs)
    greedy_labels = torch.argmax(probs, dim=2)
    sampled_labels = dist.sample()

    if force_diff:
        remaining = diff_trials
        # Resample (up to `diff_trials` times) while the draw coincides
        # with the greedy labels everywhere.
        while remaining > 0 and bool((greedy_labels == sampled_labels).all()):
            sampled_labels = dist.sample()
            remaining -= 1

    return dist.log_prob(sampled_labels), sampled_labels
|
|
|
|
|
def best_of_k_samples(
    args,
    manager,
    tokenizer,
    reward_generator,
    input_ids,
    batch,
    probs,
    k_samples=50,
    return_all=False
):
    """Draw ``k_samples`` label sequences per item and keep the best-rewarded one.

    For each of ``k_samples`` rounds, a label sequence is sampled from the
    policy, decoded into a summary, and scored by ``reward_generator``.  For
    every batch item the sample with the highest reward is then selected.

    Args:
        args: config object; only ``args.device`` is read here.
        manager: unused in this function — kept for interface compatibility.
            NOTE(review): confirm whether callers still need it.
        tokenizer: tokenizer passed through to ``labels_to_summary``.
        reward_generator: callable ``(documents, summaries) -> (rewards, details)``
            where ``rewards`` is indexable per batch item and ``details`` maps
            metric names to per-item value lists.
        input_ids: token ids of the batch, forwarded to sampling/decoding.
        batch: dict with at least a ``"document"`` entry (list of source texts).
        probs: policy output probabilities, shape (batch, seq, n_labels).
        k_samples: number of sampling rounds.
        return_all: unused — ``sample_data`` is always returned.
            NOTE(review): either honor this flag or drop it from callers.

    Returns:
        ``(sample_probs, sample_summaries, sample_rewards, sample_details,
        sample_labels, sample_data)`` — the per-item best sample's log-probs,
        summaries, rewards, detail dicts and labels, plus ``sample_data``
        holding all k raw sample batches.
    """
    batch_size = probs.size(0)

    prob_batches = []
    summary_batches = []
    reward_batches = []
    detail_batches = []
    label_batches = []
    for _ in range(k_samples):
        sample_probs, sample_labels = sample_from_policy(
            input_ids,
            probs,
            device=args.device
        )
        sample_summaries = labels_to_summary(
            input_ids, sample_labels, tokenizer
        )
        sample_rewards, sample_details = reward_generator(
            batch["document"], sample_summaries
        )

        prob_batches.append(sample_probs)
        summary_batches.append(sample_summaries)
        reward_batches.append(sample_rewards)
        detail_batches.append(sample_details)
        label_batches.append(sample_labels)

    # Index of the highest-reward sample for each batch item.  max() is O(k)
    # (the original sorted all k rewards just to read the top entry) and, like
    # a stable descending sort, breaks ties in favour of the earliest sample.
    best_indices = [
        max(range(k_samples), key=lambda j: reward_batches[j][i])
        for i in range(batch_size)
    ]

    sample_probs = torch.stack([prob_batches[j][i] for i, j in enumerate(best_indices)])
    sample_summaries = [summary_batches[j][i] for i, j in enumerate(best_indices)]
    sample_rewards = [reward_batches[j][i] for i, j in enumerate(best_indices)]
    sample_labels = torch.stack([label_batches[j][i] for i, j in enumerate(best_indices)])

    # Hoisted out of the per-item loop: the detail keys are taken from
    # detail_batches[0] and do not change between iterations (the original
    # recomputed this sort once per batch item).
    detail_keys = sorted(detail_batches[0].keys())
    sample_details = []
    for i, j in enumerate(best_indices):
        details = defaultdict(list)
        for key in detail_keys:
            details[key].append(detail_batches[j][key][i])
        sample_details.append(details)

    sample_data = {
        "probs": prob_batches,
        "rewards": reward_batches,
        "summaries": summary_batches,
        "details": detail_batches,
        "labels": label_batches,
    }
    return sample_probs, sample_summaries, sample_rewards, sample_details, sample_labels, sample_data
|
|