CRSArena / src /model /UNICRS.py
Nol00's picture
Update src/model/UNICRS.py
f57c0ab verified
import json
import logging
import sys
from collections import defaultdict
from typing import Any, Dict, List, Tuple
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoModel, AutoTokenizer
sys.path.append("..")
from src.model.unicrs.config import get_special_tokens_dict
from src.model.unicrs.kg_unicrs import KGForUniCRS
from src.model.unicrs.model_gpt2 import PromptGPT2forCRS
from src.model.unicrs.model_prompt import KGPrompt
from src.model.utils import padded_tensor
class UNICRS:
def __init__(
self,
seed,
kg_dataset,
debug,
tokenizer_path,
context_max_length,
entity_max_length,
resp_max_length,
text_tokenizer_path,
model,
text_encoder,
num_bases,
rec_model,
conv_model,
):
if seed is not None:
set_seed(seed)
self.debug = debug
self.accelerator = Accelerator(
device_placement=False, mixed_precision="fp16"
)
self.device = self.accelerator.device
self.context_max_length = context_max_length
self.entity_max_length = entity_max_length
self.resp_max_length = resp_max_length
self.padding = "max_length"
self.pad_to_multiple_of = 8
self.tokenizer_path = tokenizer_path
self.text_tokenizer_path = text_tokenizer_path
self.text_encoder = text_encoder
self.model_path = model
self.rec_model_path = rec_model
self.conv_model_path = conv_model
# config
gpt2_special_tokens_dict, prompt_special_tokens_dict = (
get_special_tokens_dict(kg_dataset)
)
# backbone
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
self.tokenizer.add_special_tokens(gpt2_special_tokens_dict)
self.tokenizer.padding_side = "left"
self.model = PromptGPT2forCRS.from_pretrained(self.model_path)
self.model.resize_token_embeddings(len(self.tokenizer))
self.model.config.pad_token_id = self.tokenizer.pad_token_id
self.model = self.model.to(self.device)
# text prompt encoder
self.prompt_tokenizer = AutoTokenizer.from_pretrained(
self.text_tokenizer_path
)
self.prompt_tokenizer.add_special_tokens(prompt_special_tokens_dict)
self.text_encoder = AutoModel.from_pretrained(self.text_encoder)
self.text_encoder.resize_token_embeddings(len(self.prompt_tokenizer))
self.text_encoder = self.text_encoder.to(self.device)
# kg prompt
self.kg_dataset = kg_dataset
self.kg = KGForUniCRS(
kg=self.kg_dataset, debug=self.debug
).get_kg_info()
self.item_ids = torch.as_tensor(
self.kg["item_ids"], device=self.device
)
self.kg_dataset_path = f"data/{self.kg_dataset}"
with open(
f"{self.kg_dataset_path}/entity2id.json", "r", encoding="utf-8"
) as f:
self.entity2id = json.load(f)
self.entity_pad_id = self.kg["pad_entity_id"]
self.num_bases = num_bases
# prompt for rec
self.rec_prompt_encoder = KGPrompt(
self.model.config.n_embd,
self.text_encoder.config.hidden_size,
self.model.config.n_head,
self.model.config.n_layer,
2,
n_entity=self.kg["num_entities"],
num_relations=self.kg["num_relations"],
num_bases=self.num_bases,
edge_index=self.kg["edge_index"],
edge_type=self.kg["edge_type"],
)
if rec_model is not None:
self.rec_prompt_encoder.load(self.rec_model_path)
self.rec_prompt_encoder = self.rec_prompt_encoder.to(self.device)
self.rec_prompt_encoder = self.accelerator.prepare(
self.rec_prompt_encoder
)
# prompt for conv
self.conv_prompt_encoder = KGPrompt(
self.model.config.n_embd,
self.text_encoder.config.hidden_size,
self.model.config.n_head,
self.model.config.n_layer,
2,
n_entity=self.kg["num_entities"],
num_relations=self.kg["num_relations"],
num_bases=self.num_bases,
edge_index=self.kg["edge_index"],
edge_type=self.kg["edge_type"],
)
if conv_model is not None:
self.conv_prompt_encoder.load(self.conv_model_path)
self.conv_prompt_encoder = self.conv_prompt_encoder.to(self.device)
self.conv_prompt_encoder = self.accelerator.prepare(
self.conv_prompt_encoder
)
def get_rec(self, conv_dict):
text_list = []
turn_idx = 0
for utt in conv_dict["context"]:
if utt != "":
text = ""
if turn_idx % 2 == 0:
text += "User: "
else:
text += "System: "
text += utt
text_list.append(text)
turn_idx += 1
context = f"{self.tokenizer.eos_token}".join(text_list)
context += f"{self.tokenizer.eos_token}"
prompt_context = f"{self.prompt_tokenizer.sep_token}".join(text_list)
self.tokenizer.truncation_side = "left"
context_ids = self.tokenizer.encode(
context, truncation=True, max_length=self.context_max_length
)
self.prompt_tokenizer.truncation_side = "left"
prompt_ids = self.prompt_tokenizer.encode(
prompt_context, truncation=True, max_length=self.context_max_length
)
self.data_list = []
if "rec" not in conv_dict.keys() or not conv_dict["rec"]:
# Interactive mode: the ground truth is not provided
data_dict = {
"context": context_ids,
"prompt": prompt_ids,
"entity": [
self.entity2id[ent]
for ent in conv_dict["entity"][-self.entity_max_length :]
if ent in self.entity2id
],
}
self.data_list.append(data_dict)
else:
for rec in conv_dict["rec"]:
if rec in self.entity2id:
data_dict = {
"context": context_ids,
"prompt": prompt_ids,
"entity": [
self.entity2id[ent]
for ent in conv_dict["entity"][
-self.entity_max_length :
]
if ent in self.entity2id
],
"rec": self.entity2id[rec],
}
self.data_list.append(data_dict)
context_dict = defaultdict(list)
prompt_dict = defaultdict(list)
entity_list = []
label_list = []
for data in self.data_list:
context_dict["input_ids"].append(data["context"])
prompt_dict["input_ids"].append(data["prompt"])
entity_list.append(data["entity"])
if "rec" in data.keys():
label_list.append(data["rec"])
context_dict = self.tokenizer.pad(
context_dict,
max_length=self.context_max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
)
if len(label_list) > 0:
context_dict["rec_labels"] = label_list
for k, v in context_dict.items():
if not isinstance(v, torch.Tensor):
context_dict[k] = torch.as_tensor(v, device=self.device)
position_ids = context_dict["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(context_dict["attention_mask"] == 0, 1)
context_dict["position_ids"] = position_ids
input_batch = {} # for model
input_batch["context"] = context_dict
prompt_dict = self.prompt_tokenizer.pad(
prompt_dict,
max_length=self.context_max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
)
for k, v in prompt_dict.items():
if not isinstance(v, torch.Tensor):
prompt_dict[k] = torch.as_tensor(v, device=self.device)
input_batch["prompt"] = prompt_dict
entity_list = padded_tensor(
entity_list,
pad_id=self.entity_pad_id,
pad_tail=True,
device=self.device,
debug=self.debug,
max_length=self.entity_max_length,
)
input_batch["entity"] = entity_list
# infer
token_embeds = self.text_encoder(
**input_batch["prompt"]
).last_hidden_state
prompt_embeds = self.rec_prompt_encoder(
entity_ids=input_batch["entity"],
token_embeds=token_embeds,
output_entity=True,
)
input_batch["context"]["prompt_embeds"] = prompt_embeds
input_batch["context"][
"entity_embeds"
] = self.rec_prompt_encoder.get_entity_embeds()
outputs = self.model(**input_batch["context"], rec=True)
logits = outputs.rec_logits[:, self.item_ids]
ranks = torch.topk(logits, k=50, dim=-1).indices
preds = self.item_ids[ranks].tolist()
if "rec_labels" in input_batch["context"]:
labels = input_batch["context"]["rec_labels"].tolist()
else:
labels = None
return preds, labels
def get_conv(self, conv_dict):
# dataset
text_list = []
turn_idx = 0
for utt in conv_dict["context"]:
if utt != "" and len(utt) > 0:
text = ""
if turn_idx % 2 == 0:
text += "User: "
else:
text += "System: "
text += utt
text_list.append(text)
turn_idx += 1
context = f"{self.tokenizer.eos_token}".join(text_list)
context += f"{self.tokenizer.eos_token}"
prompt_context = f"{self.prompt_tokenizer.sep_token}".join(text_list)
self.tokenizer.truncation_side = "left"
context_ids = self.tokenizer.encode(
context, truncation=True, max_length=self.context_max_length
)
self.prompt_tokenizer.truncation_side = "left"
prompt_ids = self.prompt_tokenizer.encode(
prompt_context, truncation=True, max_length=self.context_max_length
)
self.tokenizer.truncation_side = "right"
if turn_idx % 2 == 0:
user_str = "User: "
else:
user_str = "System: "
resp = user_str + conv_dict["resp"]
resp_ids = self.tokenizer.encode(
resp, truncation=True, max_length=self.resp_max_length
)
resp_ids.append(self.tokenizer.eos_token_id)
entity_list = [
self.entity2id[ent]
for ent in conv_dict["entity"][-self.entity_max_length :]
if ent in self.entity2id
]
data_dict = {
"context": context_ids,
"prompt": prompt_ids,
"entity": entity_list,
}
# dataloader
context_dict = defaultdict(list)
context_len_list = []
prompt_dict = defaultdict(list)
entity_list = []
label_dict = defaultdict(list)
bot_prompt = self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize("System:")
)
context = data_dict["context"] + bot_prompt
context_len_list.append((len(data_dict["context"])))
context_dict["input_ids"] = context
prompt_dict["input_ids"] = data_dict["prompt"]
entity_list.append(data_dict["entity"])
context_max_length = self.context_max_length + len(bot_prompt)
context_dict = self.tokenizer.pad(
context_dict,
max_length=context_max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
)
for k, v in context_dict.items():
if not isinstance(v, torch.Tensor):
context_dict[k] = torch.as_tensor(
v, device=self.device
).unsqueeze(0)
input_batch = {}
position_ids = context_dict["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(context_dict["attention_mask"] == 0, 1)
context_dict["position_ids"] = position_ids
input_batch["conv_labels"] = label_dict["input_ids"]
input_batch["context_len"] = context_len_list
input_batch["context"] = context_dict
prompt_dict = self.prompt_tokenizer.pad(
prompt_dict,
max_length=self.context_max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
)
for k, v in prompt_dict.items():
if not isinstance(v, torch.Tensor):
prompt_dict[k] = torch.as_tensor(
v, device=self.device
).unsqueeze(0)
input_batch["prompt"] = prompt_dict
entity_list = padded_tensor(
entity_list,
pad_id=self.entity_pad_id,
pad_tail=True,
device=self.device,
debug=self.debug,
max_length=self.entity_max_length,
)
input_batch["entity"] = entity_list
# infer
self.conv_prompt_encoder.eval()
token_embeds = self.text_encoder(
**input_batch["prompt"]
).last_hidden_state
prompt_embeds = self.conv_prompt_encoder(
entity_ids=input_batch["entity"],
token_embeds=token_embeds,
output_entity=False,
use_conv_prefix=True,
)
input_batch["context"]["prompt_embeds"] = prompt_embeds
gen_args = {
"max_new_tokens": self.resp_max_length,
"no_repeat_ngram_size": 3,
}
gen_seqs = self.model.generate(**input_batch["context"], **gen_args)
gen_str = self.tokenizer.decode(gen_seqs[0], skip_special_tokens=False)
return input_batch, gen_str
def get_choice(self, gen_inputs, options, state, conv_dict=None):
state = torch.as_tensor(state, device=self.device)
outputs = self.accelerator.unwrap_model(self.model).generate(
**gen_inputs["context"],
min_new_tokens=1,
max_new_tokens=1,
return_dict_in_generate=True,
output_scores=True,
)
option_token_ids = [
self.tokenizer.encode(op, add_special_tokens=False)[0]
for op in options
]
option_scores = outputs.scores[-1][0][option_token_ids]
option_scores += state
option_with_max_score = options[torch.argmax(option_scores)]
return option_with_max_score
def get_response(
self,
conv_dict: Dict[str, Any],
id2entity: Dict[int, str],
options: Tuple[str, Dict[str, str]],
state: List[float],
movie_token: str = "<pad>",
) -> Tuple[str, List[float]]:
"""Generates a response given a conversation context.
The method is based on the logic of the ask mode (i.e., see
`scripts/ask.py`). It consists of two steps: (1) choose to either
recommend items or generate a response, and (2) execute the chosen
step. Slightly deviates from the original implementation by not using
templates.
Args:
conv_dict: Conversation context.
id2entity: Mapping from entity ID to entity name.
options: Prompt with options and dictionary of options.
state: State of the option choices.
movie_token: Mask token for the movie. Defaults to "<pad>".
Returns:
Generated response and updated state.
"""
generated_inputs, generated_response = self.get_conv(conv_dict)
options_letter = list(options[1].keys())
# Get the choice between recommend and generate
choice = self.get_choice(generated_inputs, options_letter, state)
# Generate a recommendation
recommended_items, _ = self.get_rec(conv_dict)
if choice == options_letter[-1]:
recommended_items_str = ""
for i, item_id in enumerate(recommended_items[0][:3]):
recommended_items_str += f"{i+1}: {id2entity[item_id]} \n"
response = (
"I would recommend the following items: \n"
f"{recommended_items_str}"
)
else:
# Original : Generate a response to ask for preferences. The
# fallback is to use the generated response.
# response = (
# options[1].get(choice, {}).get("template", generated_response)
# )
generated_response = generated_response[
generated_response.rfind("System:") + len("System:") + 1 :
]
generated_response = generated_response.replace(
"<|endoftext|>", ""
)
for i in range(str.count(generated_response, movie_token)):
try:
generated_response = generated_response.replace(
movie_token, id2entity[recommended_items[0][i]], 1
)
except IndexError as e:
logging.error(e)
generated_response = generated_response.replace(
movie_token, "", 1
)
response = generated_response.strip()
# Update the state. Hack: penalize the choice to reduce the
# likelihood of selecting the same choice again
state[options_letter.index(choice)] += -1e5
return response, state