import json
import sys
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoTokenizer, BartConfig

# Make the parent directory importable so the `src` package resolves.
sys.path.append("..")

from src.model.kbrd.kbrd_model import KBRDforConv, KBRDforRec
from src.model.kbrd.kg_kbrd import KGForKBRD
from src.model.utils import padded_tensor


class KBRD:
    def __init__(
        self,
        seed,
        kg_dataset,
        debug,
        hidden_size,
        entity_hidden_size,
        num_bases,
        rec_model,
        conv_model,
        context_max_length,
        tokenizer_path,
        encoder_layers,
        decoder_layers,
        text_hidden_size,
        attn_head,
        resp_max_length,
        entity_max_length,
    ):
        self.seed = seed
        if self.seed is not None:
            set_seed(self.seed)

        self.kg_dataset = kg_dataset

        # Model details.
        self.debug = debug
        self.hidden_size = hidden_size
        self.entity_hidden_size = entity_hidden_size
        self.num_bases = num_bases
        self.context_max_length = context_max_length
        self.entity_max_length = entity_max_length

        # Model checkpoints.
        self.rec_model = rec_model
        self.conv_model = conv_model

        # Conversation module.
        self.tokenizer_path = tokenizer_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.text_hidden_size = text_hidden_size
        self.attn_head = attn_head
        self.resp_max_length = resp_max_length
        self.padding = "max_length"
        self.pad_to_multiple_of = 8

        self.kg_dataset_path = f"data/{self.kg_dataset}"
        with open(
            f"{self.kg_dataset_path}/entity2id.json", "r", encoding="utf-8"
        ) as f:
            self.entity2id = json.load(f)

        # Initialize the accelerator.
        self.accelerator = Accelerator(
            device_placement=False, mixed_precision="fp16"
        )
        self.device = self.accelerator.device

        self.kg = KGForKBRD(
            kg_dataset=self.kg_dataset, debug=self.debug
        ).get_kg_info()
        self.pad_id = self.kg["pad_id"]

        # Recommendation model.
        self.crs_rec_model = KBRDforRec(
            hidden_size=self.hidden_size,
            num_relations=self.kg["num_relations"],
            num_bases=self.num_bases,
            num_entities=self.kg["num_entities"],
        )
        if self.rec_model is not None:
            self.crs_rec_model.load(self.rec_model)
        self.crs_rec_model = self.crs_rec_model.to(self.device)
        self.crs_rec_model = self.accelerator.prepare(self.crs_rec_model)

        # Conversation model.
        config = BartConfig.from_pretrained(
            self.conv_model,
            encoder_layers=self.encoder_layers,
            decoder_layers=self.decoder_layers,
            hidden_size=self.text_hidden_size,
            encoder_attention_heads=self.attn_head,
            decoder_attention_heads=self.attn_head,
            encoder_ffn_dim=self.text_hidden_size,
            decoder_ffn_dim=self.text_hidden_size,
            forced_bos_token_id=None,
            forced_eos_token_id=None,
        )
        self.crs_conv_model = KBRDforConv(
            config, user_hidden_size=self.entity_hidden_size
        ).to(self.device)
        if self.conv_model is not None:
            self.crs_conv_model = KBRDforConv.from_pretrained(
                self.conv_model, user_hidden_size=self.entity_hidden_size
            ).to(self.device)
        self.crs_conv_model = self.accelerator.prepare(self.crs_conv_model)

    def get_rec(self, conv_dict):
        data_dict = {
            "item": [
                self.entity2id[rec]
                for rec in conv_dict["rec"]
                if rec in self.entity2id
            ],
        }
        # Keep the most recent entities; the trailing comma wraps the list in
        # a tuple so padded_tensor sees a batch of size 1.
        entity_ids = (
            [
                self.entity2id[ent]
                for ent in conv_dict["entity"][-self.entity_max_length :]
                if ent in self.entity2id
            ],
        )

        if "dialog_id" in conv_dict:
            data_dict["dialog_id"] = conv_dict["dialog_id"]
        if "turn_id" in conv_dict:
            data_dict["turn_id"] = conv_dict["turn_id"]
        if "template" in conv_dict:
            data_dict["template"] = conv_dict["template"]

        # Knowledge graph edges.
        edge_index = torch.as_tensor(self.kg["edge_index"], device=self.device)
        edge_type = torch.as_tensor(self.kg["edge_type"], device=self.device)
        entity_ids = padded_tensor(
            entity_ids,
            pad_id=self.pad_id,
            pad_tail=True,
            max_length=self.entity_max_length,
            device=self.device,
            debug=self.debug,
        )
        data_dict["entity"] = {
            "entity_ids": entity_ids,
            "entity_mask": torch.ne(entity_ids, self.pad_id),
        }

        # Inference.
        self.crs_rec_model.eval()
        with torch.no_grad():
            data_dict["entity"]["edge_index"] = edge_index
            data_dict["entity"]["edge_type"] = edge_type
            outputs = self.crs_rec_model(
                **data_dict["entity"], reduction="mean"
            )
            # Restrict the scores to item entities and keep the top 50.
            logits = outputs["logit"][:, self.kg["item_ids"]]
            ranks = torch.topk(logits, k=50, dim=-1).indices.tolist()
            preds = [
                [self.kg["item_ids"][rank] for rank in rank_list]
                for rank_list in ranks
            ]
            labels = data_dict["item"]

        return preds, labels

    def get_conv(self, conv_dict):
        # Truncate the context from the left so the most recent turns survive.
        self.tokenizer.truncation_side = "left"
        context_list = conv_dict["context"]
        context = self.tokenizer.sep_token.join(context_list)
        context_ids = self.tokenizer.encode(
            context, truncation=True, max_length=self.context_max_length
        )

        context_batch = defaultdict(list)
        context_batch["input_ids"] = context_ids
        context_batch = self.tokenizer.pad(
            context_batch,
            max_length=self.context_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        # Truncate the response from the right.
        self.tokenizer.truncation_side = "right"
        resp = conv_dict["resp"]
        resp_batch = defaultdict(list)
        resp_ids = self.tokenizer.encode(
            resp, truncation=True, max_length=self.resp_max_length
        )
        resp_batch["input_ids"] = resp_ids
        resp_batch = self.tokenizer.pad(
            resp_batch,
            max_length=self.resp_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        context_batch["labels"] = resp_batch["input_ids"]

        # Move everything to the device and add a batch dimension.
        for k, v in context_batch.items():
            if not isinstance(v, torch.Tensor):
                context_batch[k] = torch.as_tensor(
                    v, device=self.device
                ).unsqueeze(0)

        # Batch of size 1 with the most recent entities.
        entity_list = (
            [
                self.entity2id[ent]
                for ent in conv_dict["entity"][-self.entity_max_length :]
                if ent in self.entity2id
            ],
        )
        entity_ids = padded_tensor(
            entity_list,
            pad_id=self.pad_id,
            pad_tail=True,
            device=self.device,
            debug=self.debug,
            max_length=self.entity_max_length,
        )
        entity = {
            "entity_ids": entity_ids,
            "entity_mask": torch.ne(entity_ids, self.pad_id),
        }
        data_dict = {"context": context_batch, "entity": entity}

        # Encode the mentioned entities into a user embedding with the
        # recommender's knowledge graph encoder.
        edge_index = torch.as_tensor(self.kg["edge_index"], device=self.device)
        edge_type = torch.as_tensor(self.kg["edge_type"], device=self.device)
        node_embeds = self.crs_rec_model.get_node_embeds(edge_index, edge_type)
        user_embeds = self.crs_rec_model(
            **data_dict["entity"], node_embeds=node_embeds
        )["user_embeds"]

        gen_inputs = {
            **data_dict["context"],
            "decoder_user_embeds": user_embeds,
        }
        gen_inputs.pop("labels")
        gen_args = {
            "min_length": 0,
            "max_length": self.resp_max_length,
            "num_beams": 1,
            "no_repeat_ngram_size": 3,
            "encoder_no_repeat_ngram_size": 3,
        }
        gen_seqs = self.accelerator.unwrap_model(self.crs_conv_model).generate(
            **gen_inputs, **gen_args
        )
        gen_str = self.tokenizer.decode(gen_seqs[0], skip_special_tokens=True)

        return gen_inputs, gen_str

    def get_choice(self, gen_inputs, options, state, conv_dict=None):
        state = torch.as_tensor(state, device=self.device)
        outputs = self.accelerator.unwrap_model(self.crs_conv_model).generate(
            **gen_inputs,
            min_new_tokens=2,
            max_new_tokens=2,
            num_beams=1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Score each option by the logit of its first token at the last
        # generation step, plus the accumulated state offset.
        option_token_ids = [
            self.tokenizer.encode(op, add_special_tokens=False)[0]
            for op in options
        ]
        option_scores = outputs.scores[-1][0][option_token_ids]
        option_scores += state
        option_with_max_score = options[torch.argmax(option_scores)]

        return option_with_max_score
    def get_response(
        self,
        conv_dict: Dict[str, Any],
        id2entity: Dict[int, str],
        options: Tuple[str, Dict[str, str]],
        state: List[float],
    ) -> Tuple[str, List[float]]:
        """Generates a response given a conversation context.

        The method follows the logic of the ask mode (see `scripts/ask.py`).
        It consists of two steps: (1) choose to either recommend items or
        generate a response, and (2) execute the chosen step. It slightly
        deviates from the original implementation by not using templates.

        Args:
            conv_dict: Conversation context.
            id2entity: Mapping from entity id to entity name.
            options: Prompt with options and dictionary of options.
            state: State of the option choices.

        Returns:
            Generated response and updated state.
        """
        generated_inputs, generated_response = self.get_conv(conv_dict)
        options_letter = list(options[1].keys())

        # Choose between recommending and generating a response.
        choice = self.get_choice(generated_inputs, options_letter, state)

        if choice == options_letter[-1]:
            # Generate a recommendation from the top-ranked items.
            recommended_items, _ = self.get_rec(conv_dict)
            recommended_items_str = ""
            for i, item_id in enumerate(recommended_items[0][:3]):
                recommended_items_str += f"{i + 1}: {id2entity[item_id]} \n"
            response = (
                "I would recommend the following items: \n"
                f"{recommended_items_str}"
            )
        else:
            # Original: generate a response to ask for preferences; the
            # fallback is the generated response.
            # response = (
            #     options[1].get(choice, {}).get("template", generated_response)
            # )
            response = generated_response

        # Update the state. Hack: penalize the chosen option to reduce the
        # likelihood of selecting it again.
        state[options_letter.index(choice)] += -1e5
        return response, state


if __name__ == "__main__":
    kbrd = KBRD(
        seed=42,
        kg_dataset="redial",
        debug=False,
        hidden_size=128,
        num_bases=8,
        rec_model="/mnt/tangxinyu/crs/eval_model/redial_rec/best",
        conv_model="/mnt/tangxinyu/crs/eval_model/redial_conv/final/",
        encoder_layers=2,
        decoder_layers=2,
        attn_head=2,
        resp_max_length=128,
        text_hidden_size=300,
        entity_hidden_size=128,
        context_max_length=200,
        entity_max_length=32,
        tokenizer_path="../utils/tokenizer/bart-base",
    )

    context_dict = {
        "dialog_id": "20001",
        "turn_id": 1,
        "context": ["Hi I am looking for a movie like Super Troopers (2001)"],
        "entity": ["Super Troopers (2001)"],
        "rec": ["Police Academy (1984)"],
        "resp": "You should watch Police Academy (1984)",
        "template": [
            "Hi I am looking for a movie like ",
            "You should watch ",
        ],
    }

    preds, labels = kbrd.get_rec(context_dict)
    gen_inputs, gen_str = kbrd.get_conv(context_dict)
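
    # Illustrative only: a minimal sketch of driving the two-step ask mode
    # through `get_response`. The prompt text, the option letters, and the
    # option descriptions below are assumptions for demonstration; in
    # practice they come from the evaluation script (see `scripts/ask.py`).
    # The one structural requirement `get_response` imposes is that the
    # *last* option letter denotes "recommend".
    id2entity = {idx: entity for entity, idx in kbrd.entity2id.items()}
    options = (
        "Choose an action:",  # Hypothetical prompt.
        {
            "A": "ask about genre preferences",  # Hypothetical option.
            "B": "ask about favorite actors",  # Hypothetical option.
            "C": "recommend items",  # Last option triggers get_rec.
        },
    )
    state = [0.0] * len(options[1])  # One additive score offset per option.
    response, state = kbrd.get_response(context_dict, id2entity, options, state)
    print(response)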