import math
import os
import sys
from typing import Optional, List, Union, Tuple

import torch
from loguru import logger
from torch import nn
from torch.nn import functional as F, CrossEntropyLoss
from torch_geometric.nn import RGCNConv
from transformers import BartPretrainedModel, BartConfig, BartModel
from transformers.modeling_outputs import Seq2SeqLMOutput

sys.path.append("..")

from src.model.utils import SelfAttention, shift_tokens_right


class KBRDforRec(nn.Module):
    """KBRD recommender: an R-GCN over the knowledge graph plus self-attention
    over the entities mentioned in a dialog, scored against all KG nodes."""

    def __init__(self, hidden_size, num_relations, num_bases, num_entities):
        super(KBRDforRec, self).__init__()
        # kg encoder
        self.kg_encoder = RGCNConv(
            hidden_size, hidden_size, num_relations=num_relations, num_bases=num_bases
        )
        # Entity embeddings with Xavier-uniform initialization.
        self.node_embeds = nn.Parameter(torch.empty(num_entities, hidden_size))
        stdv = math.sqrt(6.0 / (self.node_embeds.size(-2) + self.node_embeds.size(-1)))
        self.node_embeds.data.uniform_(-stdv, stdv)
        # Frozen all-zero row appended after the KG nodes (e.g. for padding ids).
        self.special_token_embeddings = nn.Parameter(
            torch.zeros(1, hidden_size), requires_grad=False
        )
        self.attn = SelfAttention(hidden_size)

    def get_node_embeds(self, edge_index, edge_type):
        node_embeds = self.kg_encoder(self.node_embeds, edge_index, edge_type)
        node_embeds = torch.cat([node_embeds, self.special_token_embeddings], dim=0)
        return node_embeds

    def forward(
        self,
        entity_embeds=None,
        entity_ids=None,
        edge_index=None,
        edge_type=None,
        node_embeds=None,
        entity_mask=None,
        labels=None,
        reduction="none",
    ):
        if node_embeds is None:
            node_embeds = self.get_node_embeds(edge_index, edge_type)
        if entity_embeds is None:
            entity_embeds = node_embeds[entity_ids]  # (bs, seq_len, hs)
        user_embeds = self.attn(entity_embeds, entity_mask)  # (bs, hs)
        logits = user_embeds @ node_embeds.T  # (bs, n_node)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels, reduction=reduction)
        return {"loss": loss, "logit": logits, "user_embeds": user_embeds}

    def save(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "model.pt")
        torch.save(self.state_dict(), save_path)

    def load(self, load_dir):
        load_path = os.path.join(load_dir, "model.pt")
        missing_keys, unexpected_keys = self.load_state_dict(
            torch.load(load_path, map_location=torch.device("cpu"))
        )


class KBRDforConv(BartPretrainedModel):
    """BART-based conversation module: mirrors BartForConditionalGeneration,
    adding a per-user bias on the vocabulary logits via ``rec_proj``."""

    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]

    def __init__(self, config: BartConfig, user_hidden_size):
        super().__init__(config)
        self.model = BartModel(config)
        self.register_buffer(
            "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
        )
        self.lm_head = nn.Linear(
            config.d_model, self.model.shared.num_embeddings, bias=False
        )
        # Projects the recommender's user embedding into a vocabulary-sized bias.
        self.rec_proj = nn.Linear(user_hidden_size, self.model.shared.num_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        decoder_user_embeds=None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if labels is not None:
            if use_cache:
                logger.warning(
                    "The `use_cache` argument is changed to `False` since `labels` is provided."
                )
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Vocabulary logits = LM head + static bias + user-embedding bias
        # (the user bias is broadcast over the sequence dimension).
        lm_logits = (
            self.lm_head(outputs[0])
            + self.final_logits_bias
            + self.rec_proj(decoder_user_embeds).unsqueeze(1)
        )

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        decoder_user_embeds=None,
        **kwargs
    ):
        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
            "decoder_user_embeds": decoder_user_embeds,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past[:2]
                )
                + layer_past[2:],
            )
        return reordered_past
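

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes
    # that src.model.utils.SelfAttention pools (bs, seq_len, hs) entity
    # embeddings with a per-position mask into (bs, hs) user embeddings, and
    # that shift_tokens_right follows the usual HuggingFace BART semantics.
    # All sizes below are toy values chosen only to exercise the forward passes.
    num_entities, hidden_size = 10, 8

    rec = KBRDforRec(
        hidden_size=hidden_size,
        num_relations=2,
        num_bases=2,
        num_entities=num_entities,
    )
    # Toy KG: edge_index is (2, num_edges) with source/target rows,
    # edge_type is (num_edges,) with one relation id per edge.
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long)
    edge_type = torch.tensor([0, 1, 0], dtype=torch.long)
    # Two dialogs, each mentioning up to three entities; index `num_entities`
    # points at the appended all-zero row and is masked out via entity_mask
    # (the mask dtype may need to match the SelfAttention implementation).
    entity_ids = torch.tensor([[0, 1, num_entities], [2, 3, 4]])
    entity_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool)
    rec_out = rec(
        entity_ids=entity_ids,
        edge_index=edge_index,
        edge_type=edge_type,
        entity_mask=entity_mask,
        labels=torch.tensor([5, 6]),
        reduction="mean",
    )
    print(rec_out["logit"].shape)  # (2, num_entities + 1)

    # Conversation side: a deliberately tiny BartConfig keeps the sketch cheap;
    # in practice the model would start from a pretrained BART checkpoint.
    config = BartConfig(
        vocab_size=32,
        d_model=16,
        encoder_layers=1,
        decoder_layers=1,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=64,
    )
    conv = KBRDforConv(config, user_hidden_size=hidden_size)
    input_ids = torch.randint(3, config.vocab_size, (2, 5))
    labels = torch.randint(3, config.vocab_size, (2, 4))
    conv_out = conv(
        input_ids=input_ids,
        labels=labels,
        # The recommender's user embedding biases every vocabulary logit.
        decoder_user_embeds=rec_out["user_embeds"].detach(),
    )
    print(conv_out.loss)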