import json
import logging
import sys
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoModel, AutoTokenizer

sys.path.append("..")

from src.model.unicrs.config import get_special_tokens_dict
from src.model.unicrs.kg_unicrs import KGForUniCRS
from src.model.unicrs.model_gpt2 import PromptGPT2forCRS
from src.model.unicrs.model_prompt import KGPrompt
from src.model.utils import padded_tensor


class UNICRS:
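    """Interactive wrapper around UniCRS for conversational recommendation.

    Holds the GPT-2 backbone, the text prompt encoder, and two KG-based
    prompt encoders (one for recommendation, one for response generation),
    and exposes `get_rec`, `get_conv`, `get_choice`, and `get_response`.
    """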
    def __init__(
        self,
        seed,
        kg_dataset,
        debug,
        tokenizer_path,
        context_max_length,
        entity_max_length,
        resp_max_length,
        text_tokenizer_path,
        model,
        text_encoder,
        num_bases,
        rec_model,
        conv_model,
    ):
        if seed is not None:
            set_seed(seed)
        self.debug = debug
        self.accelerator = Accelerator(
            device_placement=False, mixed_precision="fp16"
        )
        self.device = self.accelerator.device
        self.context_max_length = context_max_length
        self.entity_max_length = entity_max_length
        self.resp_max_length = resp_max_length
        self.padding = "max_length"
        self.pad_to_multiple_of = 8
        self.tokenizer_path = tokenizer_path
        self.text_tokenizer_path = text_tokenizer_path
        self.text_encoder_path = text_encoder
        self.model_path = model
        self.rec_model_path = rec_model
        self.conv_model_path = conv_model
        # config
        gpt2_special_tokens_dict, prompt_special_tokens_dict = (
            get_special_tokens_dict(kg_dataset)
        )
        # backbone
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
        self.tokenizer.add_special_tokens(gpt2_special_tokens_dict)
        self.tokenizer.padding_side = "left"
        self.model = PromptGPT2forCRS.from_pretrained(self.model_path)
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model = self.model.to(self.device)
        # text prompt encoder
        self.prompt_tokenizer = AutoTokenizer.from_pretrained(
            self.text_tokenizer_path
        )
        self.prompt_tokenizer.add_special_tokens(prompt_special_tokens_dict)
        self.text_encoder = AutoModel.from_pretrained(self.text_encoder_path)
        self.text_encoder.resize_token_embeddings(len(self.prompt_tokenizer))
        self.text_encoder = self.text_encoder.to(self.device)
        # kg prompt
        self.kg_dataset = kg_dataset
        self.kg = KGForUniCRS(
            kg=self.kg_dataset, debug=self.debug
        ).get_kg_info()
        self.item_ids = torch.as_tensor(
            self.kg["item_ids"], device=self.device
        )
        self.kg_dataset_path = f"data/{self.kg_dataset}"
        with open(
            f"{self.kg_dataset_path}/entity2id.json", "r", encoding="utf-8"
        ) as f:
            self.entity2id = json.load(f)
        self.entity_pad_id = self.kg["pad_entity_id"]
        self.num_bases = num_bases
        # prompt for rec
        self.rec_prompt_encoder = KGPrompt(
            self.model.config.n_embd,
            self.text_encoder.config.hidden_size,
            self.model.config.n_head,
            self.model.config.n_layer,
            2,
            n_entity=self.kg["num_entities"],
            num_relations=self.kg["num_relations"],
            num_bases=self.num_bases,
            edge_index=self.kg["edge_index"],
            edge_type=self.kg["edge_type"],
        )
        if rec_model is not None:
            self.rec_prompt_encoder.load(self.rec_model_path)
        self.rec_prompt_encoder = self.rec_prompt_encoder.to(self.device)
        self.rec_prompt_encoder = self.accelerator.prepare(
            self.rec_prompt_encoder
        )
        # prompt for conv
        self.conv_prompt_encoder = KGPrompt(
            self.model.config.n_embd,
            self.text_encoder.config.hidden_size,
            self.model.config.n_head,
            self.model.config.n_layer,
            2,
            n_entity=self.kg["num_entities"],
            num_relations=self.kg["num_relations"],
            num_bases=self.num_bases,
            edge_index=self.kg["edge_index"],
            edge_type=self.kg["edge_type"],
        )
        if conv_model is not None:
            self.conv_prompt_encoder.load(self.conv_model_path)
        self.conv_prompt_encoder = self.conv_prompt_encoder.to(self.device)
        self.conv_prompt_encoder = self.accelerator.prepare(
            self.conv_prompt_encoder
        )

    def get_rec(self, conv_dict):
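        """Ranks items for the current conversation context.

        Encodes the dialogue context with both tokenizers, passes the
        mentioned entities through the recommendation prompt encoder, and
        scores all items with the backbone. Returns the top-50 item IDs
        per example and, when ground-truth recommendations are present in
        `conv_dict`, the matching labels (otherwise None).
        """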
        text_list = []
        turn_idx = 0
        for utt in conv_dict["context"]:
            if utt != "":
                text = ""
                if turn_idx % 2 == 0:
                    text += "User: "
                else:
                    text += "System: "
                text += utt
                text_list.append(text)
            turn_idx += 1
        context = f"{self.tokenizer.eos_token}".join(text_list)
        context += f"{self.tokenizer.eos_token}"
        prompt_context = f"{self.prompt_tokenizer.sep_token}".join(text_list)
        self.tokenizer.truncation_side = "left"
        context_ids = self.tokenizer.encode(
            context, truncation=True, max_length=self.context_max_length
        )
        self.prompt_tokenizer.truncation_side = "left"
        prompt_ids = self.prompt_tokenizer.encode(
            prompt_context, truncation=True, max_length=self.context_max_length
        )
        self.data_list = []
        if "rec" not in conv_dict or not conv_dict["rec"]:
            # Interactive mode: the ground truth is not provided
            data_dict = {
                "context": context_ids,
                "prompt": prompt_ids,
                "entity": [
                    self.entity2id[ent]
                    for ent in conv_dict["entity"][-self.entity_max_length :]
                    if ent in self.entity2id
                ],
            }
            self.data_list.append(data_dict)
        else:
            for rec in conv_dict["rec"]:
                if rec in self.entity2id:
                    data_dict = {
                        "context": context_ids,
                        "prompt": prompt_ids,
                        "entity": [
                            self.entity2id[ent]
                            for ent in conv_dict["entity"][
                                -self.entity_max_length :
                            ]
                            if ent in self.entity2id
                        ],
                        "rec": self.entity2id[rec],
                    }
                    self.data_list.append(data_dict)
        context_dict = defaultdict(list)
        prompt_dict = defaultdict(list)
        entity_list = []
        label_list = []
        for data in self.data_list:
            context_dict["input_ids"].append(data["context"])
            prompt_dict["input_ids"].append(data["prompt"])
            entity_list.append(data["entity"])
            if "rec" in data:
                label_list.append(data["rec"])
        context_dict = self.tokenizer.pad(
            context_dict,
            max_length=self.context_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        if len(label_list) > 0:
            context_dict["rec_labels"] = label_list
        for k, v in context_dict.items():
            if not isinstance(v, torch.Tensor):
                context_dict[k] = torch.as_tensor(v, device=self.device)
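        # With left padding, real tokens receive positions 0..n-1 via the
        # cumulative sum of the attention mask; padded slots get a dummy
        # position of 1.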
        position_ids = context_dict["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(context_dict["attention_mask"] == 0, 1)
        context_dict["position_ids"] = position_ids
        input_batch = {}  # for model
        input_batch["context"] = context_dict
        prompt_dict = self.prompt_tokenizer.pad(
            prompt_dict,
            max_length=self.context_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        for k, v in prompt_dict.items():
            if not isinstance(v, torch.Tensor):
                prompt_dict[k] = torch.as_tensor(v, device=self.device)
        input_batch["prompt"] = prompt_dict
        entity_list = padded_tensor(
            entity_list,
            pad_id=self.entity_pad_id,
            pad_tail=True,
            device=self.device,
            debug=self.debug,
            max_length=self.entity_max_length,
        )
        input_batch["entity"] = entity_list
        # infer
        token_embeds = self.text_encoder(
            **input_batch["prompt"]
        ).last_hidden_state
        prompt_embeds = self.rec_prompt_encoder(
            entity_ids=input_batch["entity"],
            token_embeds=token_embeds,
            output_entity=True,
        )
        input_batch["context"]["prompt_embeds"] = prompt_embeds
        input_batch["context"][
            "entity_embeds"
        ] = self.rec_prompt_encoder.get_entity_embeds()
        outputs = self.model(**input_batch["context"], rec=True)
        logits = outputs.rec_logits[:, self.item_ids]
        ranks = torch.topk(logits, k=50, dim=-1).indices
        preds = self.item_ids[ranks].tolist()
        if "rec_labels" in input_batch["context"]:
            labels = input_batch["context"]["rec_labels"].tolist()
        else:
            labels = None
        return preds, labels

    def get_conv(self, conv_dict):
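        """Generates a system response for the current conversation context.

        Mirrors the batching in `get_rec`, appends a "System:" prompt to
        the context, and decodes a reply with the conversation prompt
        encoder. Returns the model inputs (reused by `get_choice`) and
        the decoded string.
        """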
        # dataset
        text_list = []
        turn_idx = 0
        for utt in conv_dict["context"]:
            if utt != "":
                text = ""
                if turn_idx % 2 == 0:
                    text += "User: "
                else:
                    text += "System: "
                text += utt
                text_list.append(text)
            turn_idx += 1
        context = f"{self.tokenizer.eos_token}".join(text_list)
        context += f"{self.tokenizer.eos_token}"
        prompt_context = f"{self.prompt_tokenizer.sep_token}".join(text_list)
        self.tokenizer.truncation_side = "left"
        context_ids = self.tokenizer.encode(
            context, truncation=True, max_length=self.context_max_length
        )
        self.prompt_tokenizer.truncation_side = "left"
        prompt_ids = self.prompt_tokenizer.encode(
            prompt_context, truncation=True, max_length=self.context_max_length
        )
        self.tokenizer.truncation_side = "right"
        if turn_idx % 2 == 0:
            user_str = "User: "
        else:
            user_str = "System: "
        resp = user_str + conv_dict["resp"]
        resp_ids = self.tokenizer.encode(
            resp, truncation=True, max_length=self.resp_max_length
        )
        resp_ids.append(self.tokenizer.eos_token_id)
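        # Note: resp_ids is not consumed below; generation is free-running
        # in interactive mode, so the encoded reference response is unused.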
        entity_list = [
            self.entity2id[ent]
            for ent in conv_dict["entity"][-self.entity_max_length :]
            if ent in self.entity2id
        ]
        data_dict = {
            "context": context_ids,
            "prompt": prompt_ids,
            "entity": entity_list,
        }
        # dataloader
        context_dict = defaultdict(list)
        context_len_list = []
        prompt_dict = defaultdict(list)
        entity_list = []
        label_dict = defaultdict(list)
        bot_prompt = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.tokenize("System:")
        )
        context = data_dict["context"] + bot_prompt
        context_len_list.append(len(data_dict["context"]))
        context_dict["input_ids"] = context
        prompt_dict["input_ids"] = data_dict["prompt"]
        entity_list.append(data_dict["entity"])
        context_max_length = self.context_max_length + len(bot_prompt)
        context_dict = self.tokenizer.pad(
            context_dict,
            max_length=context_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        for k, v in context_dict.items():
            if not isinstance(v, torch.Tensor):
                context_dict[k] = torch.as_tensor(
                    v, device=self.device
                ).unsqueeze(0)
        input_batch = {}
        position_ids = context_dict["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(context_dict["attention_mask"] == 0, 1)
        context_dict["position_ids"] = position_ids
input_batch["conv_labels"] = label_dict["input_ids"] | |
input_batch["context_len"] = context_len_list | |
input_batch["context"] = context_dict | |
prompt_dict = self.prompt_tokenizer.pad( | |
prompt_dict, | |
max_length=self.context_max_length, | |
padding=self.padding, | |
pad_to_multiple_of=self.pad_to_multiple_of, | |
) | |
for k, v in prompt_dict.items(): | |
if not isinstance(v, torch.Tensor): | |
prompt_dict[k] = torch.as_tensor( | |
v, device=self.device | |
).unsqueeze(0) | |
input_batch["prompt"] = prompt_dict | |
entity_list = padded_tensor( | |
entity_list, | |
pad_id=self.entity_pad_id, | |
pad_tail=True, | |
device=self.device, | |
debug=self.debug, | |
max_length=self.entity_max_length, | |
) | |
input_batch["entity"] = entity_list | |
        # infer
        self.conv_prompt_encoder.eval()
        token_embeds = self.text_encoder(
            **input_batch["prompt"]
        ).last_hidden_state
        prompt_embeds = self.conv_prompt_encoder(
            entity_ids=input_batch["entity"],
            token_embeds=token_embeds,
            output_entity=False,
            use_conv_prefix=True,
        )
        input_batch["context"]["prompt_embeds"] = prompt_embeds
        gen_args = {
            "max_new_tokens": self.resp_max_length,
            "no_repeat_ngram_size": 3,
        }
        gen_seqs = self.model.generate(**input_batch["context"], **gen_args)
        gen_str = self.tokenizer.decode(gen_seqs[0], skip_special_tokens=True)
        return input_batch, gen_str

    def get_choice(self, gen_inputs, options, state, conv_dict=None):
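        """Picks one option by scoring its first token at the next position.

        Generates exactly one new token, reads the logits of each option's
        first token ID, adds the running `state` penalties, and returns the
        highest-scoring option.
        """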
        state = torch.as_tensor(state, device=self.device)
        outputs = self.accelerator.unwrap_model(self.model).generate(
            **gen_inputs["context"],
            min_new_tokens=1,
            max_new_tokens=1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        option_token_ids = [
            self.tokenizer.encode(op, add_special_tokens=False)[0]
            for op in options
        ]
        option_scores = outputs.scores[-1][0][option_token_ids]
        option_scores += state
        option_with_max_score = options[torch.argmax(option_scores)]
        return option_with_max_score

    def get_response(
        self,
        conv_dict: Dict[str, Any],
        id2entity: Dict[int, str],
        options: Tuple[str, Dict[str, str]],
        state: List[float],
        movie_token: str = "<mask>",
    ) -> Tuple[str, List[float]]:
        """Generates a response given a conversation context.

        The method is based on the logic of the ask mode (i.e., see
        `scripts/ask.py`). It consists of two steps: (1) choose to either
        recommend items or generate a response, and (2) execute the chosen
        step. Slightly deviates from the original implementation by not
        using templates.

        Args:
            conv_dict: Conversation context.
            id2entity: Mapping from entity ID to entity name.
            options: Prompt with options and dictionary of options.
            state: State of the option choices.
            movie_token: Mask token for the movie. Defaults to "<mask>".

        Returns:
            Generated response and updated state.
        """
        generated_inputs, generated_response = self.get_conv(conv_dict)
        options_letter = list(options[1].keys())
        # Get the choice between recommend and generate
        choice = self.get_choice(generated_inputs, options_letter, state)
        # Generate a recommendation
        recommended_items, _ = self.get_rec(conv_dict)
        if choice == options_letter[-1]:
            recommended_items_str = ""
            for i, item_id in enumerate(recommended_items[0][:3]):
                recommended_items_str += f"{i+1}: {id2entity[item_id]} \n"
            response = (
                "I would recommend the following items: \n"
                f"{recommended_items_str}"
            )
        else:
            # Original: Generate a response to ask for preferences. The
            # fallback is to use the generated response.
            # response = (
            #     options[1].get(choice, {}).get("template", generated_response)
            # )
            generated_response = generated_response[
                generated_response.rfind("System:") + len("System:") + 1 :
            ]
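            # Replace each movie placeholder token in the generated text
            # with the next top-ranked recommendation.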
            for i in range(generated_response.count(movie_token)):
                try:
                    # get_rec returns a batch of ranked lists, so index
                    # into the first (and only) example.
                    generated_response = generated_response.replace(
                        movie_token, id2entity[recommended_items[0][i]], 1
                    )
                except IndexError as e:
                    logging.error(e)
                    generated_response = generated_response.replace(
                        movie_token, "", 1
                    )
            response = generated_response.strip()
        # Update the state. Hack: penalize the choice to reduce the
        # likelihood of selecting the same choice again
        state[options_letter.index(choice)] += -1e5
        return response, state
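

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: all paths,
    # the dataset name, and the checkpoint locations below are assumed
    # placeholders and must point to real artifacts before running.
    unicrs = UNICRS(
        seed=42,
        kg_dataset="redial",  # assumed dataset name
        debug=False,
        tokenizer_path="microsoft/DialoGPT-small",  # assumed backbone
        context_max_length=200,
        entity_max_length=32,
        resp_max_length=128,
        text_tokenizer_path="roberta-base",  # assumed prompt tokenizer
        model="microsoft/DialoGPT-small",  # assumed backbone weights
        text_encoder="roberta-base",  # assumed text encoder
        num_bases=8,
        rec_model="path/to/rec_prompt_encoder",  # hypothetical checkpoint
        conv_model="path/to/conv_prompt_encoder",  # hypothetical checkpoint
    )
    conv_dict = {
        "context": ["Hi! Can you recommend me a movie like Super Troopers?"],
        "entity": [],
        "resp": "",
    }
    preds, _ = unicrs.get_rec(conv_dict)
    _, reply = unicrs.get_conv(conv_dict)
    print(preds[0][:3], reply)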